複習以前寫過比較難搞的回溯題。加了新測資,舊方法直接變成TLE,只好找找別的出路。

題目

輸入整數陣列nums和整數k,如果可以把nums分成k個總和相等的非空子集合,則回傳true。

解法

先講舊方法,複雜度為O(k^N),現在已經TLE,無法使用。

要想平分成k個子集合,得先確認nums的總和是否能被k整除,若無法整除直接回傳false。
建立長度k的陣列sub,代表各個子集合,使用回溯法暴力搜尋,試著將每個數字nums[i]加到各個子集合中。
nums長度為N,從i=0開始進入,直到i等於N時,代表所有數字都已經分配到不同子集合中,這時檢查k個子集合是否都為平均值avg,若是則更新答案為true。

class Solution:
    def canPartitionKSubsets(self, nums: List[int], k: int) -> bool:
        if sum(nums) % k:
            return False
        avg = sum(nums)//k
        N = len(nums)
        sub = [0]*k
        flag = False

        def bt(n):
            nonlocal flag
            if n == N:
                for x in sub:
                    if x != avg:
                        return
                flag = True
            else:
                for i in range(k):
                    if sub[i]+nums[n] <= avg:
                        sub[i] += nums[n]
                        bt(n+1)
                        sub[i] -= nums[n]

        bt(0)

        return flag

上面的回溯一次嘗試組成k個子集合,指數上去成長太快了。
參考了這篇的第一個方法,改成一次只組成一個總合為avg的子集合,成功後才嘗試下一個,直到滿足k個為止。

我們不在乎子集合的組成與順序,只要他總和等於avg就行,那麼可以先將nums以遞減排序,盡可能先使用較大的數,減少無謂的分支。
其實這有點像是top down dp的感覺了。 定義回溯函數bt(i,curr,k):i為nums中的起始索引,curr為當前子集合總和,k為剩下需要的子集合。
base case:k為1時,剩下的數字剛好可以湊成avg,直接回傳true。
否則從索引i開始,試著將i~N-1的未使用數字加入當前子集合中。子集合總和剛好為avg時,則遞迴建立下一個子集合。

和其他回溯題型一樣,本題也有關鍵的剪枝加速點:
因為我們是從nums[0]開始循序嘗試到nums[N-1],若nums[i]和nums[i-1]數值都是5,且5可以用來組成子集合的話,nums[i-1]一定會先被使用掉。
所以碰到nums[i]等於nums[i-1],且nums[i-1]沒被使用的情況下,可以知道nums[i-1]無法加入子集合,那麼和nums[i-1]同值的nums[i]當然也無法被加入,所以直接不處理nums[i]。

class Solution:
    def canPartitionKSubsets(self, nums: List[int], k: int) -> bool:
        sm=sum(nums)
        if sm%k!=0:
            return False
        N=len(nums)
        nums.sort(reverse=1)
        used=[False]*N
        avg=sm//k
        
        def bt(i,curr,k):
            if k==1:
                return True
            if curr==avg:
                return bt(0,0,k-1)
            for j in range(i,N):
                if used[j] or curr+used[j]>avg:
                    continue
                if j>0 and nums[j]==nums[j-1] and not used[j-1]:
                    continue
                used[j]=True
                if bt(j+1,curr+nums[j],k):
                    return True
                used[j]=False
            return False
                
        return bt(0,0,k)

更新bitmask DP解法。看了中文站才發現遺漏一個重要的剪枝條件:nums中不可以有超過avg的數字。

以bitmask表示那些數字已經用過了,從0開始bottom up,每次挑一個沒用過數字試著加入。
當前子集合的總和記錄在val陣列裡面,對val[mask]的值模avg,可以計算出當前正在組成的子集合總和。

class Solution:
    def canPartitionKSubsets(self, nums: List[int], k: int) -> bool:
        sm=sum(nums)
        if sm%k!=0: # 剪枝,必須能夠整除
            return False
        avg=sm//k
        for x in nums: # 剪枝,數字不可超過avg
            if x>avg:
                return False
        nums.sort()
        N=len(nums)
        dp=[False]*(1<<N)
        dp[0]=True
        val=[0]*(1<<N)
        
        for mask in range(1<<N):
            if not dp[mask]: # 須以合法的組合為基礎
                continue
            for i in range(N):
                newMask=mask|(1<<i)
                if newMask!=mask: # 沒用過的數字
                    subSum=(val[mask]%avg)+nums[i] 
                    if subSum<=avg: # 且可以放入集合
                        dp[newMask]=True
                        val[newMask]=subSum
                    else: # 剪枝,後面元素都不可用了
                        break
        
        return dp[-1]