雙周賽122。資料結構題,本身並不是太難。但是前一題 Q3 太燒腦筋,根本沒時間寫了。

題目

輸入長度 n 的整數陣列 nums,還有兩個正整數 k 和 dist。

一個陣列的成本等於他的第一個元素。例如 [1,2,3] 的成本是 1。

你要將 nums 分割成 k 個的獨立子陣列,且第二個子陣列的開頭索引和第 k 個子陣列的開頭索引的差值必須小於等於 dist。
也就是說,nums 分割成子陣列 nums[0..(i1 - 1)], nums[i1..(i2 - 1)], …, nums[ik-1..(n - 1)],滿足 ik-1 - i1 <= dist。

求分割 k 個陣列的最小總成本。

解法

和原題差不多,只是變成要分割成 k 個子陣列,所以要選 nums[0] 加上額外自選 k-1 個最小的元素。
以下所稱第_個元素都是自選的 k-1 個,並非必選的 nums[0]。

除此之外,第一個元素索引和最後一個元素索引的差必須要在 dist 以內。其實就是一個滑動窗口的概念,在大小為 dist+1 的窗口中選擇 k-1 個最小值。
也就是說,當第一個索引是 left 時,則最後一個索引 right 必須滿足 left < right <= min(left + dist, N-1);
或是當最後一個索引是 right 時,第一個索引 left 必須滿足 max(1, right - dist) < left < right,因為 nums[0] 不可以是自選的。


綜上所述,我們需要枚舉窗口的右邊界 right,同時將 right 作為最後一個索引,然後在窗口內剩下的 dist 個元素中,選 k-2 個最小值。測資很好心的保證k - 2 <= dist,所以一定有答案。
在 k 很大的情況下,每次暴力計算總和肯定會超時,需要其他技巧來維護前 k-2 小元素的和。

我想到的是類似480. sliding window median的作法,維護兩個有序容器 s1, s2,其中 s1 只裝 k-2 個最小值,並維護其總和 cost;其餘的丟到 s2 當候補,等到 s1 中的最小值過期、不足 k-2 個時,才從 s2 中挑最小的進去補。

每次窗口右移,加入新值 x 時,總是先加入 s1:

  • 如果 s1 超過 k-2 個,則刪掉最大的值,丟到 s2 去,並從 cost 扣除
  • 不超過,則維持不變

窗口右移,同時受到 dist 的限制的左邊界也會右移。需要刪掉出界的元素 out:

  • 如果 out 小於等於 s1 的最大值,則代表他一定在 s1 裡面。從 s1 刪除並從 cost 扣除
    這時 s1 會不足 k-2 個:
    • 如果 s2 有候補,就拉 s2 的最小值進去 s1,並增加 cost
    • s2 是空的就算了,反正之後擴展邊界一樣會加到 s1
  • 否則一定在 s2 裡面,從 s2 中刪除

雖然說是維護大小為 dist 的窗口,總共有 N 個,但還要分兩邊保持順序,有序容器每次操作O(log dist)。

時間複雜度 O(N log dist)。
空間複雜度 O(dist)。

from sortedcontainers import SortedList as SL

class Solution:
    def minimumCost(self, nums: List[int], k: int, dist: int) -> int:
        N = len(nums)
        ans = inf
        s1 = SL() # k-2 smallest 
        s2 = SL() # other candidates
        cost = nums[0]
        
        # we need k-1 smallest 
        # enumerate last(k-th) element
        # and keep k-2 other smallest
        for right in range(1, N):
            x = nums[right]
            # enough for total k-1 elements
            # update ans and pop expired element
            if len(s1) == k-2:
                ans = min(ans, cost + x)
                # pop expired
                left = right - dist
                if left >= 1:
                    out = nums[left]
                    if out <= s1[-1]: # in s1
                        s1.remove(out)
                        cost -= out
                        # add cand from s2
                        if s2:
                            cand = s2[0]
                            s2.remove(cand)
                            s1.add(cand)
                            cost += cand
                        
                    else: # in s2
                        s2.remove(out)
                    left += 1
            
            # expand right
            s1.add(x)
            cost += x
            if len(s1) > k-2:
                out = s1[-1]
                s1.remove(out)
                cost -= out
                s2.add(out)
            
        return ans