周賽 396。小小吐槽一下,答案好像沒必要模 10^9 + 7。

題目

輸入整數字串 nums 和兩個整數 cost1, cost2。
你可以執行以下兩種操作任意次

  • 選擇索引 i,以成本為 cost1 將 nums[i] 增加 1
  • 選擇不同的索引 i, j,以成本為 cost2 將 nums[i], nums[j] 都增加 1

求使得陣列中所有元素相同的最小成本

答案可能很大,先模 10^9 + 7 後回傳。

解法

操作一選一個數,操作二選兩個數。
先考慮特殊情況:

  • N = 1,只能用操作一
  • N = 2,操作二沒意義,也只能用操作一
  • cost1 * 2 <= cost2,只用操作一更省錢

除此之外的一般情況,則優先使用操作二,剩下的才使用操作一。


本題的操作都只能把值增加,要把所有元素都變成目標值 T。
對於 nums 中每個元素 x,各需 T - x 次操作。將各元素所需操作次數記做陣列 diff。

問題等價轉換成:

有 N 種不同顏色的石頭,每個顏色的數量為 diff[i]
每次可以挑兩個不同顏色的石頭一組,最多能挑幾組
組數就是操作二,剩下的石頭就是操作一

設 D = max(diff), S = sum(diff)。分類討論以下情況:

  • 情況一:S - D >= D
    可以選 S / 2 組
    • 若 S 是偶數,剩下 0 個石頭
    • 若 S 是奇數,剩下 1 個石頭
  • 情況二:S - D < D
    可以選 S - D 組,剩下 S - (S - D) * 2 個石頭

按照公式代入 cost,大概會變成這樣子。

def f(D, S): 
    if S - D >= D:
        op1 = S % 2
        op2 = S // 2
    else: # S - D < D
        op1 = S - (S - D) * 2
        op2 = S - D
    return cost1 * op1 + cost2 * op2

雖然會覺得 mx = max(nums) 是 T 的最佳選擇,然而並不是。
在 cost1 > cost2 的情況下,有時候更大的 T 反而更划算。
例如:

nums = [1,14,14,15], cost1 = 2, cost2 = 1
T = 15, diff = [14,1,1,0], D = 14, S = 16
總成本 cost = 2 * 12 + 1 * 2 = 26
T = 18, diff = [17,4,4,3], D = 17, S = 28
總成本 cost = 2 * 8 + 1 * 8 = 24
T = 21, diff = [20,7,7,6], D = 20, S = 40
總成本 cost = 2 * 0 + 1 * 20 = 20

那我怎麼知道 T 最大會到多少?

首先釐清增加 T 是為了改變 S 和 D,至於 T 本身並不重要。
在一般情況下,保證 N 至少為 3。每當 T 增加 1,會使得 D 增加 1,然後 (S - D) 增加 N - 1。

情況一的成本只和 S 的值有關,因此 T 繼續增大必定使總成本上升,沒有必要繼續增加。
但情況二會因為 T 增加逐漸降低成本,最後變成情況一
成本曲線應該是類似 V 或是 U 型圖。最粗糙的判斷方式,是在成本上升之時停止。


為了計算複雜度,還是要大概知道枚舉的上界。
在 N >= 3 的一般情況下,每次增加上界只會使得 D 和 (S - D) 的差減少 1。
而最壞情況下 nums = [1000000, 1000000, 1]:

diff = [0, 0, 999999]
D = 999999, S = 999999, (S - D) = 0
大概需要枚舉到兩倍的 D。
但還要考慮到奇偶數的出現順序不同,搞不好情況二的第二個值更小,還是把上界設更大一些比較保險。

時間複雜度 O(N + base_d),其中 base_d = max(nums) - min(nums)。
空間複雜度 O(1)。

MOD = 10 ** 9 + 7
class Solution:
    def minCostToEqualizeArray(self, nums: List[int], cost1: int, cost2: int) -> int:
        N = len(nums)
        mx, mn = max(nums), min(nums)
        base_d = mx - mn
        base_s = mx * N - sum(nums)
        
        # use cost1 only
        if N <= 2 or cost1 * 2 <= cost2: 
            return cost1 * base_s % MOD
        
        # make all elements to target
        def f(D, S): 
            if S - D >= D:
                op1 = S % 2
                op2 = S // 2
            else: # S - D < D
                op1 = S - (S - D) * 2
                op2 = S - D
            return cost1 * op1 + cost2 * op2
        
        ans = inf
        for d in range(base_d, base_d * 2 + 1):
            s = base_s + (d - base_d) * N
            res = f(d, s)
            if res < ans:
                ans = res
            else: # cost increasing
                break
        
        return ans % MOD