雙周賽 132。

題目

輸入整數陣列 nums 和非負整數 k。

若一個整數序列 seq 滿足在索引範圍 [0, seq.length - 2] 中,存在最多 k 個索引滿足 seq[i] != seq[i + 1],則稱其為好的序列。

求 nums 的好的子序列的最大長度。

解法

經典的相鄰相關子序列 dp。
除了當前第 i 個元素選或不選之外,還需要紀錄上次選的元素 prev,以及相鄰不同的次數 j。

定義 dp(i, j, prev):在 nums[i..N-1] 的子陣列中,找出的最大好的子序列長度,且當前不同次數為 j,前一個元素為 prev。
轉移:dp(i, j, prev) = max(選, 不選)

  • 選,根據 nums[i] 和 prev 的關係判斷:
    • 若 prev = -1 或 nums[i] = prev,則 dp(i + 1, j, nums[i]) + 1
    • 否則若 j < k,則 dp(i + 1, j + 1, nums[i]) + 1
  • 不選:dp(i + 1, j, prev)

base:當 i = N 時,代表沒元素可選,回傳 0。

答案入口為 dp(0, 0, -1)。

時間複雜度 O(N^2 * k)。
空間複雜度 O(N^2 * k)。

class Solution:
    def maximumLength(self, nums: List[int], k: int) -> int:
        N = len(nums)
        
        @cache
        def dp(i, j, prev):
            if i == N:
                return 0
            # no take
            res = dp(i + 1, j, prev)
            # take
            if nums[i] == prev or prev == -1: # same or frist
                res = max(res, dp(i + 1, j, nums[i]) + 1)
            elif j < k: # different
                res = max(res, dp(i + 1, j + 1, nums[i]) + 1)
            return res
        
        ans = dp(0, 0, -1) 
        dp.cache_clear() # prevent MLE
        
        return ans

對於更大的測資範圍,則需要更佳的時間複雜度。
先改寫成遞推,看看什麼地方可以優化。

nums[i] 的上限高達 10^9,但受限於 nums 的大小,實際上最多也只會有 M = N 種數字。
先把 nums 離散化,dp 陣列狀態數為 N * k * M。

class Solution:
    def maximumLength(self, nums: List[int], k: int) -> int:
        N = len(nums)
        mp = {x:i for i, x in enumerate(set(nums))}
        a = [mp[x] for x in nums]
        M = len(mp)
        
        ans = 0
        dp = [[[0] * M  for _ in range(k + 1)] for _ in range(N + 1)]
        for i in reversed(range(N)):
            for j in range(k + 1):
                for prev in range(M):
                    # no take
                    res = dp[i + 1][j][prev]
                    # take
                    if a[i] == prev:
                        res = max(res, dp[i + 1][j][prev] + 1)
                    elif j < k:
                        res = max(res, dp[i + 1][j + 1][a[i]] + 1)
                    dp[i][j][prev] = res
                    ans = max(ans, res)
                    
        return ans

仔細觀察 dp[i][j][prev] 的轉移來源,除了共通的 dp[i + 1][j][prev] 以外,還有:

  • prev = nums[i] 時,dp[i + 1][j][prev] + 1
  • prev != nums[i] 且 j < k 時,dp[i + 1][j + 1][prev] + 1

設 x = nums[i],在不選 x 的情況下,dp[i][j] 會直接繼承 dp[i][j + 1] 既有的結果。
若選 x 的情況下,也只有 dp[i][j][x] 會改變,並從 dp[i + 1][j][x] 和所有 dp[i + 1][j + 1][x != prev] 之中取最大值後加 1。

基於繼承上次結果的特性,且 dp[i][j] 只依賴於 dp[i][j + 1],確保從小到大枚舉 j,就可以複用上次的結果,空間優化掉一個維度。
並且又只需要對 dp[i][j][x] 進行單點更新,枚舉 prev 的第三個迴圈也被優化掉了。


為了支持最大值的單點更新還有區間查詢,又是線段樹出場了。
建立 k + 1 個線段樹,分別維護第相鄰不同次數為 j時的區間最大值,依序枚舉 nums[i] 及次數 j,最後從所有 dp[j] 中取最大值即可。

時間複雜度 O(N * k * log M),其中 M = nums 中不同元素個數。
空間複雜度 O(k * M)。


線段樹很有用,但是:

551 / 551 test cases passed, but took too long.

複雜度代入 N = M = 5000, k = 50,大概才 3e6,反正 python 不給過,但是 golang 還有尊貴的 C++ 倒是過了。

class Solution:
    def maximumLength(self, nums: List[int], k: int) -> int:
        mp = {x:i for i, x in enumerate(set(nums))}
        M = len(mp)
        
        ans = 0
        dp = [SegmentTree(M) for _ in range(k + 1)]
        for x in nums:
            x = mp[x]
            for j in range(k + 1):
                # prev = x
                res = dp[j].query(1, 0, M - 1, x, x) + 1
                # prev != x
                if j < k:
                    res = max(res, dp[j + 1].tree[1] + 1)
                dp[j].update(1, 0, M - 1, x, res)
                ans = max(ans, res)
                    
        return ans
    
    
class SegmentTree:

    def __init__(self, n):
        self.tree = [0]*4

    def op(self, a, b):
        """
        任意符合結合律的運算
        """
        return max(a, b)

    def push_up(self, id):
        """
        以左右節點更新當前節點值
        """
        self.tree[id] = self.op(self.tree[id*2], self.tree[id*2+1])

    def query(self, id, L, R, i, j):
        """
        區間查詢
        回傳[i, j]的最大值
        """
        if i <= L and R <= j:  # 當前區間目標範圍包含
            return self.tree[id]
        res = 0
        M = (L+R)//2
        if i <= M:
            res = self.op(res, self.query(id*2, L, M, i, j))
        if M+1 <= j:
            res = self.op(res, self.query(id*2+1, M+1, R, i, j))
        return res

    def update(self, id, L, R, i, val):
        """
        單點更新
        對索引i設為val
        """
        if L == R:  # 當前區間目標範圍包含
            self.tree[id] = val
            return
        M = (L+R)//2
        if i <= M:
            self.update(id*2, L, M, i, val)
        else:
            self.update(id*2+1, M+1, R, i, val)
        self.push_up(id)

可能有細心的同學會問:不是從 dp[i + 1][j + 1][x != prev] 轉移而來嗎?怎麼區間查詢包含了 dp[i + 1][j + 1][x = prev]?
其實照理說是不能包含這一塊,應該要分成 x 的左右兩半區間查詢。
但是 dp[i + 1][j + 1][x] 比起 dp[i + 1][j][x] 少了一次不同的機會,不可能得到更好的結果,永遠不會影響答案,所以可以不用管他。

最初我也沒想清楚這點,所以才會選擇線段樹。


再認真想一想,其實還有可以優化的地方。

對於每個 dp[j],真正需要查詢的只有單點最大值整體最大值,並沒有部分區間,也就是說根本不需要線段樹
只需要單獨維護 dp[j][prev] 的值,還有整個 dp[j] 的最大值。

時間複雜度 O(N * k)。
空間複雜度 O(k * M)。

class Solution:
    def maximumLength(self, nums: List[int], k: int) -> int:
        mp = {x:i for i, x in enumerate(set(nums))}
        M = len(mp)
        
        ans = 0
        dp = [[0] * M for _ in range(k + 1)] # dp[j][prev]
        dpj = [0] * (k + 1) # max(dp[j])
        for x in nums:
            x = mp[x]
            for j in range(k + 1):
                # prev = x
                res = dp[j][x] + 1
                # prev != x
                if j < k:
                    res = max(res, dpj[j + 1] + 1)
                dpj[j] = max(dpj[j], res)
                dp[j][x] = res
                ans = max(ans, res)
                    
        return ans