每日題。可以優化超級多次dp練習題。單純通過不難,難的是找出最佳解。

題目

輸入三個整數n, m和k。
下列演算法用來來找出正整數陣列的最大元素:

img

你必須構造出滿足以下條件的出陣列arr:

  • arr擁有n個整數
  • 1 <= arr[i] <= m 其中 (0 <= i < n)
  • 使用上述的演算法後,search_cost會等於k

求有多少種構造出arr的方案。答案很大,先模10^9+7後回傳。

解法

search_cost以下簡稱成本
題目給定的演算法,就是遍歷陣列找到最大值。最大值初始為-1,成本就是最大值更新的次數。

構造出長度為n的陣列,只能由1~m的元素組成,而成本必須是k。
先前最大值會影響成本的變化。
目前知道有三個變數:陣列長度、先前最大值、成本。

定義dp(i,j,k):長度為i的陣列,先前最大值為j,且成本為k的陣列的構造方案數

我們可以任選1~m的任意元素,則會縮減問題的規模,形成較小的子問題dp(i-1,j’,k’)。
而新的最大值j’取max(j,x)。所有小於等於j的元素x都不會使成本增加,因此子問題為dp(i-1,j,k);
大於j的元素x則會改變最大值,且成本增加1,因此子問題為dp(i-1,x,k-1)。
對於dp(i,j,k)來說,共有1~j共j個元素小於等於j,其他都大於j。
轉移方程式:dp(i,j,k) = dp(i-1,j,k)*j + sum( dp(i-1,x,k-1) FOR ALL j<x<=m )
base cases:剩餘長度i=0,且需求成本k=0時,順利達成要求,答案為1;若長度或成本不足0,代表不合法的狀態,答案為0。

狀態有三個參數,共有nmk種狀態。每個狀態轉移需要O(m)時間。

時間複雜度O(n * m^2 * k)。
空間複雜度O(n * m * k)。

class Solution:
    def numOfArrays(self, n: int, m: int, k: int) -> int:
        MOD=10**9+7
        
        @cache
        def dp(i,j,k):
            if i==0 and k==0:
                return 1
            if i<0 or k<0:
                return 0
            res=dp(i-1,j,k)*j
            for x in range(j+1,m+1):
                res+=dp(i-1,x,k-1)
            return res%MOD
    
        return dp(n,0,k)

我們發現i和k這兩個變數會出現負數,為了方便改寫成遞推版本,將狀態上的i和k都增加位移量1,使所有狀態都不為負。
當然,入口函數也要一起增加。

class Solution:
    def numOfArrays(self, n: int, m: int, k: int) -> int:
        MOD=10**9+7
        
        @cache
        def dp(i,j,k):
            if i==1 and k==1:
                return 1
            if i<1 or k<1:
                return 0
            res=dp(i-1,j,k)*j
            for x in range(j+1,m+1):
                res+=dp(i-1,x,k-1)
            return res%MOD
    
        return dp(n+1,0,k+1)

迴圈使用到的k會和輸入的k相同,原本的k記作k0,避免數值被汙染。

初始化狀態將三個狀態參數轉換成迴圈、然後return改成continue,剩餘照搬就可以。
dp(i,j,k)是從dp(i-1,j’,k’)轉移過來,因此只需要確保i是最外層迴圈,從小到大枚舉,剩餘兩者隨意。

class Solution:
    def numOfArrays(self, n: int, m: int, k: int) -> int:
        MOD=10**9+7
        k0=k
        dp=[[[0]*(k+2) for _ in range(m+1)] for _ in range(n+2)]
        
        for i in range(n+2):
            for j in range(m+1):
                for k in range(k0+2):
                    if i==1 and k==1:
                        dp[i][j][k]=1
                        continue
                    if i<1 or k<1:
                        continue
                    res=dp[i-1][j][k]*j
                    for x in range(j+1,m+1):
                        res+=dp[i-1][x][k-1]
                    dp[i][j][k]=res%MOD
                    
        return dp[n+1][0][k0+1]

又發現當i或j等於0時,方案數一定是0,根本不用處理。
因此i和j可從1開始枚舉。

class Solution:
    def numOfArrays(self, n: int, m: int, k: int) -> int:
        MOD=10**9+7
        k0=k
        dp=[[[0]*(k+2) for _ in range(m+1)] for _ in range(n+2)]
        
        for i in range(1,n+2):
            for j in range(m+1):
                for k in range(1,k0+2):
                    if i==1 and k==1:
                        dp[i][j][k]=1
                        continue
                    res=dp[i-1][j][k]*j
                    for x in range(j+1,m+1):
                        res+=dp[i-1][x][k-1]
                    dp[i][j][k]=res%MOD
                    
        return dp[n+1][0][k0+1]

仔細觀察,又又發現當i和k不變的情況下,只要j越大,轉移的次數會越小。
對dp(i,j,k)來說,x轉移的範圍是j<x<=m;
對dp(i,j-1,k)來說,x轉移的範圍是j-1<x<=m,比起上者,只多出一個來源dp(i-1,j,k)。
也就是說這部分有非常多的重複計算。

這些重複計算怎麼處理?
如同先讓你找1~1的總和,再找1~2的總和,再找1~3的總和同理,重複利用先前的值:正是前綴和
維護一個變數ps,作為先前x來源的前綴和,直接加入ps就可以,不需要枚舉所有m個元素。
計算完dp(i,j,k)後,記得dp(i-1,j,k-1)加入前綴和中,為之後的j作貢獻。

注意:因為上述更動,dp(i,j-1,k)依賴於dp(i,j,k)的轉移前綴和,因此j必須從大到小枚舉。
並且k都是維持同樣的值,所以必須把k的迴圈搬到j的外層。

時間複雜度O(n * m * k)。
空間複雜度O(n * m * k)。

class Solution:
    def numOfArrays(self, n: int, m: int, k: int) -> int:
        MOD=10**9+7
        k0=k
        dp=[[[0]*(k+2) for _ in range(m+1)] for _ in range(n+2)]
        
        for i in range(1,n+2):
            for k in range(1,k0+2):
                ps=0
                for j in reversed(range(m+1)):
                    if i==1 and k==1:
                        dp[i][j][k]=1
                        continue
                    res=dp[i-1][j][k]*j+ps
                    dp[i][j][k]=res%MOD
                    ps+=dp[i-1][j][k-1]
                    
        return dp[n+1][0][k0+1]

這樣就結束了嗎?並沒有。

dp(i,j,k)只依賴於i-1的結果,別忘了遞推空間優化的老朋友:滾動陣列
計算長度為i的方案時,只需要保留前一次i-1的結果,壓縮掉n這個維度的空間。

時間複雜度O(n * m * k)。
空間複雜度O(m * k)。

class Solution:
    def numOfArrays(self, n: int, m: int, k: int) -> int:
        MOD=10**9+7
        k0=k
        dp=[[0]*(k+2) for _ in range(m+1)]
        
        for i in range(1,n+2):
            dp2=[[0]*(k+2) for _ in range(m+1)]
            for k in range(1,k0+2):
                ps=0
                for j in reversed(range(m+1)):
                    if i==1 and k==1:
                        dp2[j][k]=1
                        continue
                    res=dp[j][k]*j+ps
                    dp2[j][k]=res%MOD
                    ps+=dp[j][k-1]
            dp=dp2
                    
        return dp[0][k0+1]

最後再看看k的變化,對於dp(i,j,k)只會依賴於dp(i-1,j,k)和dp(i-1,j,k-1)。
其實連滾動的dp2陣列都不需要,只要把k逆向枚舉,就不會覆蓋到需要的值。

class Solution:
    def numOfArrays(self, n: int, m: int, k: int) -> int:
        MOD=10**9+7
        k0=k
        dp=[[0]*(k+2) for _ in range(m+1)]
        
        for i in range(1,n+2):
            for k in reversed(range(1,k0+2)):
                ps=0
                for j in reversed(range(m+1)):
                    if i==1 and k==1:
                        dp[j][k]=1
                        continue
                    res=dp[j][k]*j+ps
                    dp[j][k]=res%MOD
                    ps+=dp[j][k-1]
                    
        return dp[0][k0+1]

搞這麼多次,總算大功告成。