雙周賽 129。

題目

輸入三個正整數 zero, one 和 limit。

一個穩定的陣列 arr 滿足:

  • 數字 0 正好出現 zero 次
  • 數字 1 正好出現 one 次
  • 每個長度大於 limit 的子陣列必須擁有 0 和 1

求有多少穩定的二進位陣列。
答案可能很大,先模 10^9 + 7 後回傳。

解法

上一篇提到 dp(i, j, use) 需要枚舉當前數字選多少個,每個狀態需要枚舉 1~ limit。

例如 limit = 2 時:

dp(i, j, 0) 轉移 = dp(i - 1, j, 1) + dp(i - 2, j, 1)
dp(i - 1, j, 0) 轉移 = dp((i - 1) - 1, j, 1) + dp((i - 1) - 2, j, 1)
dp(i - 2, j, 0) 轉移 = dp((i - 2) - 1, j, 1) + dp((i - 2) - 2, j, 1)

可以發現,對於相同的 j 來說,轉移來源會有部分重疊。
這時可以用前綴和,將 O(limit) 的轉移優化成 O(1)。

時間複雜度 O(zero * one)。
空間複雜度 O(zero * one)。

其實這樣複雜度應該是合格,無奈測資很兇,還是會超時。

class Solution:
    def numberOfStableArrays(self, zero: int, one: int, limit: int) -> int:
        MOD = 10 ** 9 + 7
        
        @cache
        def dp(i, j, use):
            if i == 0 and j == 0:
                return 1
            
            if use == 0: # use 0
                # for x in range(1, min(i, limit) + 1):
                #     res += dp(i - x, j, 1)
                res = ps(i - 1, j, 1) - ps(i - (limit + 1), j, 1)
            else: # use 1
                # for x in range(1, min(j, limit) + 1):
                #     res += dp(i, j - x, 0)
                res = ps(i, j - 1, 0) - ps(i, j - (limit + 1), 0)
            return res % MOD
        
        @cache
        def ps(i, j, use):
            if i < 0 or j < 0:
                return 0
            if use == 0: # sum for dp(i, j - x, 0)
                res = dp(i, j, 0) + ps(i, j - 1, 0)
            else: # sum for dp(i - x, j, 1)
                res = dp(i, j, 1) + ps(i - 1, j, 1)
            return res % MOD
        
        ans = dp(zero, one, 0) + dp(zero, one, 1)
        dp.cache_clear()
        ps.cache_clear()
        
        return ans % MOD

上一篇也提到,基於對稱性,可以省略掉參數 use。前綴和也是如此。
雖然還是很慢,但至少能過了。

遞推版本是真的難寫,等哪天我想通再回來補課。

class Solution:
    def numberOfStableArrays(self, zero: int, one: int, limit: int) -> int:
        MOD = 10 ** 9 + 7
        
        @cache
        def dp(i, j): # use i
            if i < 0 or j < 0:
                return 0
            if i == 0 and j == 0:
                return 1
            # res = 0
            # for k in range(1, limit + 1):
            #     res += dp(j, i - k)
            res = ps(j, i - 1) - ps(j, i - (limit + 1)) # reverse (i - x, j) to (j, i - x)
            return res % MOD
          
        @cache
        def ps(i, j): # sum of dp(i, j) .. dp(i, 0)
            if j < 0:
                return 0
            res = dp(i, j) + ps(i, j - 1) 
            return res % MOD
        
        ans = dp(zero, one) + dp(one, zero)
        dp.cache_clear()
        ps.cache_clear()
        
        return ans % MOD

以下提供另一種思路,但解釋的很爛,請注意。

因為受到 limit 的約束,在 dp 時需要帶一個參數 cnt 來代表連續的次數,才能得知那些選擇不合法。
但是當 zeor, one <= 1000,光是這兩個的狀態數就高達 N^2,再加上 cnt 肯定沒戲。得想辦法優化掉。

先來看看不帶 cnt 會多算那些不合法的東西。
根據定義,dp(i, j, use) 是指當前最後一個數選擇 use 的合法方案 (第 i + j 個數選 use)。
既然這些方案是合法的,那他必然要從合法的子問題中轉移而來。也就是說,先前的 (i + j - 1) 個數中,最多連續出現 limit 次的相同字元。

以 limit = 2, dp(3, 1, 0) 為例。
當前必須選擇 0,而先前的數可能是 0 或 1 結尾。轉移來源有:

以 0 結尾的 dp(2, 1, 0) 有:
010, 100
以 1 結尾的 dp(2, 1, 1) 有:
001

當前要選的是 0,所以從 1 結尾的方案轉移過來肯定沒問題,反正兩個數不同。
從 0 轉移就有點問題:

010 變成 0100 合法
100 變成 1000 不合法


想了好久,終於想出個自己能夠接受的解釋。

首先想清楚 dp(i, j, 0) 的定義是什麼?
填 i 個 0 和 j 個 1 ,且最多連續 limit 次的合法方案數。而且最後一個數是選 0。

xxx0

如果從 dp(i - 1, j, k=0/1) 轉移到 dp(i, j, 0) 代表著什麼?
從填 i - 1 個 0 和 j 個 1 的所有合法方案中,在最後面加上一個 0。

xx00
xx10

dp(i - 1, j, k=0/1) 又會從各自來源轉移。以此類推,直到數字用完為止。

x000
x010
x100
x110


dp(i, j) 的定義是合法方案數,這很重要所以一直重複講!!

再次回到 limit = 2, dp(3, 1, 0) 的例子。
其轉移來源 dp(2, 1, 0) 的合法方案有:

010, 100

其中 100 會轉移過去之後會變成非法,因為他連續超過 limit 次。
如果 dp(i - 1, j) 是 dp(i, j) 填了一個 0 的方案數,那麼 dp(i, j) 填了 (limit + 1) 個 0 的方案數就是 dp(i - (limit + 1), j)。
既然已經知道這樣填不合法,那直接把他扣掉就行。

時間複雜度 O(zero * one)。
空間複雜度 O(zero * one)。

class Solution:
    def numberOfStableArrays(self, zero: int, one: int, limit: int) -> int:
        MOD = 10 ** 9 + 7
        
        @cache
        def dp(i, j, use):
            if i < 0 or j < 0:
                return 0
            
            if i == 0:
                if use == 1 and j <= limit:
                    return 1
                else:
                    return 0
                
            if j == 0:
                if use == 0 and i <= limit:
                    return 1
                else:
                    return 0
            
            if use == 0: # use 0 
                res = dp(i - 1, j, 0) + dp(i - 1, j, 1) 
                res -= dp(i - (limit + 1), j, 1) # no more than limit 
            else: # use 1
                res = dp(i, j - 1, 0) + dp(i, j - 1, 1)
                res -= dp(i, j - (limit + 1), 0) # no more than limit 
            return res % MOD
        
        ans = dp(zero, one, 0) + dp(zero, one, 1)
        dp.cache_clear()
        
        return ans % MOD