LeetCode 2382. Maximum Segment Sum After Removals
雙周賽85。這次拿到了383名,刷新個人最佳紀錄。雖然客觀來說這題偏難,不過剛好對上我的電波,就輕鬆AC了。
題目
輸入兩個長度為n的整數陣列nums和removeQueries。對於第i次查詢,刪除nums中索引為removeQueries[i]的元素,並將nums分成不同段。
段指的是nums的連續正整數序列。段總和是段中每個元素的總和。
回傳長度為n的整數陣列answer,其中answer[i]是查詢完第i次之後的最大段總合。
注意:每個索引只會被刪除一次。
解法
如果要把某個序列切成兩段,有兩個問題:
- 根據下刀的位置不同,要重新計算兩段的總合
- 不知道當前切開的段是不是最大段總和
感覺太麻煩了,我們試著逆向處理,變成在idx索引處加上某數字,使得左右兩邊的段合併。因為nums[i]只會出現正數,所以段總合永遠不可能變小,只需要以新產生的段來更新最大段總和。
答案陣列ans長度為N,但是N次查詢之後每個元素都被刪除了,ans[N-1]必定為0,所以要從倒數第二格開始往前填答案;同理,removeQueries[0]可以不必處理。
接下來開始寫程式,維護一個叫做sl的SortedList,用來有序地儲存各個段。還有變數mx紀錄當前最大段總合。
遍歷N-1~1中的第i個查詢,執行以下步驟:
- 用二分搜找到插入點j
- 檢查左右兩方的段是否能夠合併,若可合併則更新左右邊界及總合,刪除已合併的段,並插入新的段
- 更新最大段總合mx,填入ans[i-1]中
總共需要處理N-1個查詢,每次查詢最多需要4次二分搜(找插入點+刪除合併段2次+插入新段),整體複雜度為O(N log N)。
from sortedcontainers import SortedList
class Solution:
def maximumSegmentSum(self, nums: List[int], removeQueries: List[int]) -> List[int]:
N=len(nums)
ans=[0]*N
sl=SortedList() # [l,r,sum]
i=N-1
mx=0
while i>0:
l=r=idx=removeQueries[i]
sm=nums[idx]
j=sl.bisect_left([l])
to_remove=[]
# merge left
if j>0 and sl[j-1][1]==idx-1:
l=sl[j-1][0]
sm+=sl[j-1][2]
to_remove.append(sl[j-1])
# merge right
if j<len(sl) and sl[j][0]==idx+1:
r=sl[j][1]
sm+=sl[j][2]
to_remove.append(sl[j])
# remove merged
for x in to_remove:
sl.remove(x)
sl.add([l,r,sm])
mx=max(mx,sm)
i-=1
ans[i]=mx
return ans
要照順序刪除其實也是可以,只是二分搜找區間調整起來很麻煩,一個不小心就出錯。
至於最大區間總和的部分可以搭配前綴和以O(1)計算新的段總合,放到sortedlist裡面就可以快速得到最大值。
from sortedcontainers import SortedList
class Solution:
def maximumSegmentSum(self, nums: List[int], removeQueries: List[int]) -> List[int]:
N=len(nums)
psum=[0]+list(accumulate(nums))
ans=[0]*N
segment_sum=SortedList()
segment_sum.add(psum[-1])
sl=SortedList() # [l,r,sum]
sl.add([0,N-1,psum[-1]])
for i in range(N-1):
idx=removeQueries[i]
j=sl.bisect_left([idx+1])-1
l,r,sm=sl[j]
# remove old
sl.pop(j)
segment_sum.remove(sm)
# add left
if l<idx:
sm=psum[idx]-psum[l]
sl.add([l,idx-1,sm])
segment_sum.add(sm)
# add right
if r>idx:
sm=psum[r+1]-psum[idx+1]
sl.add([idx+1,r,sm])
segment_sum.add(sm)
ans[i]=segment_sum[-1]
return ans
同樣是逆向操作,可以使用併查集,將時間複雜度降低至O(N)。
對於每個要插入的點idx,初始化段總合為nums[i],檢查左右是否存在,若存在則將其root指向idx,並將總和加到idx上。
class UnionFind:
def __init__(self):
self.parent = {}
self.sm = {}
def union(self, x, y):
px = self.find(x)
py = self.find(y)
if px != py:
self.sm[py]+=self.sm[px]
self.parent[px] = py
def find(self, x):
if x != self.parent[x]:
self.parent[x] = self.find(self.parent[x])
return self.parent[x]
class Solution:
def maximumSegmentSum(self, nums: List[int], removeQueries: List[int]) -> List[int]:
N=len(nums)
ans=[0]*N
uf=UnionFind()
mx=0
i=N-1
while i>0:
idx=removeQueries[i]
uf.parent[idx]=idx
uf.sm[idx]=nums[idx]
# merge left
if idx-1 in uf.parent:
uf.union(idx-1,idx)
# merge right
if idx+1 in uf.parent:
uf.union(idx+1,idx)
mx=max(mx,uf.sm[idx])
i-=1
ans[i]=mx
return ans