雙周賽115。花了好多天才搞懂,這題細節也不少。

題目

輸入非負整數陣列nums,還有兩個整數l和r。

求nums有多少子多重集合,其子集元素總和正好落在區間[l, r]中。
答案很大,先模10^9+7後回傳。

子多重集合指的是一個無序的元素集合,其中元素x最多可以出現occ[x]次,而occ[x]是x在nums中的出現次數。

注意:

  • 若兩個子多重集合排序後相同,則視為同個子多重集合
  • 集合的總和為0

解法

子多重集合看起來很囉嗦,其實就是多重背包問題。
元素x有cnt[x]個,看你要選幾個,最後求總和落在[l, r]中的選法有幾種。

這題測資範圍很奇妙,說nums長度上限2*10^4,然後S=sum(nums)和max(nums)上限也都是2*10^4。
只有在nums[i]全部都是1或0或2的情況下,才能達到nums的長度上限。
那如果nums[i]的元素全都是不同的,必須滿足1+2+…+a <= sum(nums),能有sqrt(sum(nums))種,大概是一百多。

樸素版本的多重背包問題很簡單就能寫出來。

定義dp(i,j):在剩餘前i種元素時,湊出總和為j的選法有幾種。
轉移方程式:dp(i,j) = sum( dp(i-1,j-k*x) FOR ALL 0<=k<=cnt),其中k*x不可超過j。
base case:當i<0時,沒有剩餘元素可選,只有空集合一種選擇。如果總和要求j剛好為0,答案為1;否則不合法,回傳0。

共有min(sqrt(S),N)種元素,總和有min(S,r)種。每個狀態轉移最多N次。
時間複雜度O(min(sqrt(S),N) * min(S,r) * N)。
空間複雜度O(min(sqrt(S),N) * min(S,r))。

計算量隨隨便便都10^8,嚴重TLE。

class Solution:
    def countSubMultisets(self, nums: List[int], l: int, r: int) -> int:
        MOD=10**9+7

        # no more than S
        r=min(r,sum(nums))

        # remaining elements
        d=Counter(nums)
        keys=list(d)
        N=len(d)
        
        @cache
        def dp(i,j):
            if i<0:
                return int(j==0)
            res=0
            x=keys[i]
            for k in range(d[x]+1):
                if x*k>j:
                    break
                res+=dp(i-1,j-x*k)
            return res%MOD
     
        ans=0
        for i in range(l,r+1):
            ans+=dp(N-1,i)

        return ans%MOD

假設在第i種元素為x,共有cnt個:

dp(i,j) = dp(i-1,j) + dp(i-1,j-x) + … + dp(i-1,j-cnt*x)

列出另一項比較:

dp(i,j-x) = dp(i-1,j-x) + dp(i-1,j-x*2) + … + dp(i-1,j-(cnt+1)*x)

發現dp(i,j)相對於dp(i,j-x),多了dp(i-1,j),少了dp(i-1,j-(cnt+1)*x)。
轉移方程式變形成:

dp(i,j) = dp(i,j-x) + dp(i-1,j) - dp(i-1,j-(cnt+1)*x)

但是多出一個例外:如果x=0, dp(i,j-x)會無限遞迴下去。
0不管拿幾個總和都是0。除了空集合以外,還有cnt種選法可以組成總和0。所以最後答案要乘上cnt+1。

注意:LeetCode評測機有點問題,一定要把快取清掉,不然會MLE。

時間複雜度O(min(sqrt(S),N) * min(S,r))。
空間複雜度O(min(sqrt(S),N) * min(S,r))。

class Solution:
    def countSubMultisets(self, nums: List[int], l: int, r: int) -> int:
        MOD=10**9+7
        S=sum(nums)
        
        # less than l
        if S<l:
            return 0
        
        # no more than S
        r=min(r,S)
        
        # special case of 0
        d=Counter(nums)
        zeros=d[0]+1
        del d[0]
        
        # remaining elements
        keys=list(d)
        N=len(d)
        
        @cache
        def dp(i,j):
            if i<0 and j==0:
                return 1
            if i<0 or j<0:
                return 0
            x=keys[i]
            cnt=d[x]
            res=dp(i,j-x)+dp(i-1,j)-dp(i-1,j-x*(cnt+1))
            return res%MOD
     
        ans=0
        for i in range(l,r+1):
            ans+=dp(N-1,i)
            ans%=MOD
        
        dp.cache_clear() # prevent MLE
        
        return ans*zeros%MOD

改成遞推版本。

class Solution:
    def countSubMultisets(self, nums: List[int], l: int, r: int) -> int:
        MOD=10**9+7
        S=sum(nums)
        
        # less than l
        if S<l:
            return 0
        
        # no more than S
        r=min(r,S)
        
        # special case of 0
        d=Counter(nums)
        zeros=d[0]+1
        del d[0]
        
        # remaining elements
        keys=list(d)
        N=len(d)
        
        dp=[[0]*(r+1) for _ in range(N+1)]
        dp[0][0]=1
        
        for i in range(N):
            x=keys[i]
            cnt=d[x]
            for j in range(r+1):
                dp[i+1][j]=dp[i][j]
                if j>=x:
                    dp[i+1][j]+=dp[i+1][j-x]
                if j>=x*(cnt+1):
                    dp[i+1][j]-=dp[i][j-x*(cnt+1)]
                dp[i+1][j]%=MOD
    
        ans=0
        for i in range(l,r+1):
            ans+=dp[N][i]
            ans%=MOD
        
        return ans*zeros%MOD

dp(i,j)只會參考到dp(i-1,j)和左方的dp(i,j-x)、dp(i,j-x*(cnt+1)),因此可以只保留上一列的結果,只使用兩個陣列。

時間複雜度O(min(sqrt(S),N) * min(S,r))。
空間複雜度O(min(S,r))。

class Solution:
    def countSubMultisets(self, nums: List[int], l: int, r: int) -> int:
        MOD=10**9+7
        S=sum(nums)
        
        # less than l
        if S<l:
            return 0
        
        # no more than S
        r=min(r,S)
        
        # special case of 0
        d=Counter(nums)
        zeros=d[0]+1
        del d[0]
        
        # remaining elements
        keys=list(d)
        N=len(d)
        
        dp=[0]*(r+1)
        dp[0]=1
        
        for i in range(N):
            x=keys[i]
            cnt=d[x]
            dp2=[0]*(r+1)
            for j in range(r+1):
                dp2[j]=dp[j]
                if j>=x:
                    dp2[j]+=dp2[j-x]
                if j>=x*(cnt+1):
                    dp2[j]-=dp[j-x*(cnt+1)]
                dp2[j]%=MOD
            dp=dp2
            
        ans=0
        for i in range(l,r+1):
            ans+=dp[i]
            ans%=MOD
        
        return ans*zeros%MOD