weekly contest 422。
本篇題解寫得不太好,老實說我也不太確定正確性,建議隨便看看就好。

題目Permalink

輸入字串 num。
如果字串偶數索引數字的和等於奇數索引數字的和,則稱為平衡的

求 nums 的不同排列中,有多少是平衡的

答案可能很大,先模 10^9 + 7 後回傳。

解法Permalink

延續 Q1 的平衡字串

奇偶索引是交替出現。所以偶數索引若有 x 個,則奇數索引只可能是 x 或 x-1 個。
問題轉換成:把 N 個數字分成兩堆,各 sz1 + sz2 個,且兩堆的和相同
因為要平分成兩堆,所以 sum(num) 必須是偶數,若非偶數可直接回傳 0。

在 N 個數中找 sz 個數,使總和正好為 target,我馬上想到 494. Target Sum
只是多加了一個選擇數量限制而已。


再來想想怎麼求排列。

一個長度為 x 的字串,全排列共有 x! 種。例如 “112” 的全排列為 3! 種。
若要去除重複,則必須除去每個元素數量的階乘。例如 “112” 中有 2 個 1、1個 2,所以要除 1! 和 2!。

舉個實際例子:

構造字串長度為 5,且和為 20
全排列共有 5! 種
從 0~9 依序枚舉要選幾個
假設選 1 個 1,子問題變成字串長度 4,和為 19,除去重複的 1!
假設選 2 個 1,子問題變成字串長度 3,和為 18,除去重複的 2!
以此類推

不同選法也可能造成相同的限制,有重疊的子問題,因此考慮 dp。

定義 dp(i, cnt, val):在 i~9 的數字之中,選擇 cnt 個數字且總合為 val 的不重複方案數。
轉移:dp(i, cnt, val) = sum( (dp(i+1, cnt-j, val-i*j) / j!) FOR ALL 0 <= j <= min(freq[i], cnt))。
base:當 i = 10 時,所有數字都選完了,只有在 cnt = val = 0 時才滿足要求,回傳字串的全排列,即 sz!。


根據乘法原理,答案應是兩堆的方案數相乘。
但上面的 dp 狀態只能處理其中一堆的方案數,要如何知道另一堆的方案數?

這邊我卡了很久沒想通,因為 dp 不知道到底選了哪些數,一度以為不是正確做法,但只是我想多了。
舉個例子:

num = “112112”
很明顯兩堆分別是 “112” 和 “112”
“112” 的方案數是 3! / 2! / 1! = 3
因此答案是 3 * 3 = 9 種

注意每個數都必須在其中一堆,也就是說對於這 4 個 1 來說,若在第一堆用了 2 個,剩餘的 4-2 個必定在第二堆
因此兩堆 sz1 和 sz2 的全排列方案數共有 sz! * sz2! 種,而在 dp 轉移時,除了扣除當前選擇的 j! 種排列以外,也要順便扣除另一堆的 (freq[i] - j)! 種排列。


剩下最後一個問題,由於取模後的方案數涉及除法,必須使用乘法逆元,才能確保答案的正確性。
答案入口為 dp(0, sz1, target),其中 target = sum(num) / 2。

時間複雜度不太確定,就不亂寫了。
空間複雜度 O(10 * N * S)。

MOD = 10 ** 9 + 7
MX = 85
f = [0]*(MX+1)
finv = [0]*(MX+1)
f[0] = finv[0] = 1
f[1] = finv[1] = 1

for i in range(2, MX+1):
    f[i] = (f[i-1]*i) % MOD
    finv[i] = pow(f[i], -1, MOD)

class Solution:
    def countBalancedPermutations(self, num: str) -> int:
        N = len(num)

        d = Counter()
        tot = 0
        for c in num:
            d[int(c)] += 1
            tot += int(c)

        if tot%2 == 1:
            return 0

        target = tot // 2
        sz1 = N // 2
        sz2 = N - sz1

        @cache
        def dp(i, cnt, val):
            if i == 10:
                if cnt == 0 and val == 0:
                    return f[sz1] * f[sz2] # factorial(sz1) * factorial(sz2) 
                return 0

            res = 0
            for j in range(d[i] + 1):
                if j > cnt or i*j > val:
                    break
                t = dp(i+1, cnt-j, val - i*j)
                t *= finv[j] # /= factorial(j)
                t *= finv[d[i]-j] # /= factorial(d[i]-j)
                res += t
            return res % MOD

        return dp(0, sz1, target)