weekly contest 419。
這題又是 sorted list 專場,難度大降。
可惜我寫出來的時後比賽已經結束了。

題目

輸入長度 n 的整數陣列 nums,還有兩個整數 k 和 x。

陣列的 x-sum 計算方式如下:

  • 統計陣列中所有元素的出現頻率。
  • 只保留頻率最高的前 x 種元素。若兩元素頻率相同,則保留數值較大者。
  • 求結果陣列的和。

注意:若陣列中不同的元素少於 x 種,則 x-sum 等於陣列元素和。

回傳長度為 n - k + 1 的陣列 answer,其中 answer[i] 代表子陣列 nums[i..i+k-1] 的 x-sum。

解法

枚舉子陣列很容易想到滑動窗口優化,只保留大小 k 窗口內的所有元素。
難點在於:如何動態維護前 x 大頻率的元素,並計算其總和?


我們可以透過有序容器 sorted list 來維護各元素的出現頻率次序。
元素頻率的排序是先比較頻率降序,然後比較元素數值降序,因此需要保存數對 (freq[val], val)。
並以變數 tot 紀錄前 x 大元素和。

擴展窗口右端點,增加一個元素 val 時,會使得 freq[val] 加 1。
這時需從容器刪除 (freq[val], val),並加入 (freq[val]+1, val)。
分類討論對 tot 造成的影響:

  • 若修改後 val 不為前 x 大,沒有影響。
  • 若修改後 val 為前 x 大:
    • 修改前 val 就是前 x 大,則使 tot 加 val。
    • 修改前 val 並非前 x 大,則使 tot 加 freq[val] * val。
      並且會將修改前第 x 大的元素 t 踢掉,變成第 x+1 大的元素。因此 tot 要扣除 t 的貢獻。

同理,收縮窗口左端點,刪減一個元素 val 時,會使得 freq[val] 減 1。
這時需從容器刪除 (freq[val], val),並加入 (freq[val]-1, val)。
分類討論對 tot 造成的影響:

  • 若修改前 val 不為前 x 大,沒有影響。
  • 若修改前 val 為前 x 大:
    • 修改後 val 還是前 x 大,則使 tot 減 val。
    • 修改後 val 並非前 x 大,則使 tot 減 freq[val] * val (注意此為修改前的 freq[val])。
      並且會將修改前第 x+1 大的元素 t 變成第 x 大。因此 tot 要加上 t 的貢獻。

將以上邏輯封裝成函數,套用滑動窗口即可。

時間複雜度 O(N log k)。
空間複雜度 O(N)。

from sortedcontainers import SortedList as SL
class Solution:
    def findXSum(self, nums: List[int], k: int, x: int) -> List[int]:
        freq = Counter()
        tot = 0
        sl = SL(key=lambda x:(-x[0], -x[1])) # most freq, big val
        for val in set(nums):
            sl.add([0, val])

        def add(val):
            nonlocal tot
            # remove old
            old_pos = sl.bisect_left([freq[val], val])
            sl.pop(old_pos)
            # add new
            freq[val] += 1
            sl.add([freq[val], val])
            new_pos = sl.bisect_left([freq[val], val])
            # compare position
            if new_pos < x:
                if old_pos < x:
                    tot += val
                else: # old sl[x-1] become sl[x]
                    tot += freq[val] * val
                    t = sl[x]
                    tot -= t[0] * t[1]

        def rmv(val):
            nonlocal tot
            # remove old
            old_pos = sl.bisect_left([freq[val], val])
            sl.pop(old_pos)
            # add new
            freq[val] -= 1
            sl.add([freq[val], val])
            new_pos = sl.bisect_left([freq[val], val])
            # compare position
            if old_pos < x:
                if new_pos < x:
                    tot -= val
                else: # old sl[x] become sl[x-1]
                    tot -= freq[val] * val + val # freq[val] before remove
                    t = sl[x-1]
                    tot += t[0] * t[1]

        ans = []
        left = 0
        for right, val in enumerate(nums):
            add(val)
            if right - left + 1 == k:
                ans.append(tot)
                rmv(nums[left])
                left += 1

        return ans