LeetCode 3367. Maximize Sum of Weights after Edge Removals
weekly contes 425。
個人覺得想變數名稱比做法還難。
題目
有個 n 節點的無向樹,節點編號從 0 到 n - 1。
輸入長度 n - 1 的二維整數陣列 edges,其中 edges[i] = [ui, vi, wi],代表節點 ui 和 vi 之間存在一條權重為 wi 的邊。
你的目標是刪除零或更多的邊,使得:
- 每個節點至多有 k 個連向其他節點的邊。
- 剩餘的邊權總和最大化。
求進行必要的刪除之後,剩餘邊的邊權最大和。
解法
隨便舉個簡單的例子:
edges = [[0,1,1], [1,2,10], [2,3,1]], k = 1
每個節點最多只能連一條邊。
很明顯要留 [1,2,10] 這條,他比其他兩邊更大。
換個例子:
edges = [[0,1,10], [1,2,1], [2,3,10]], k = 1
同樣最多一條邊。
但是要改成留 [0,1,10] 和 [2,3,10] 這兩條。
樹的結構相同,卻沒有固定選法,無法直接判斷要選刪 (或保留) 哪條邊。
既然沒有選擇規律,那就只能枚舉每條邊選或不選。
對 edges 中的每條邊枚舉連或不連,先前邊的不同選法可能會剩下相同的子樹與連接限制,有重疊的子問題,考慮 dp。
節點連接數受限於 k,勢必需要額外的狀態來計數各節點的連接數。
而節點數和邊數的上限 N = 10^5,會產生 10^10 個狀態,明顯不可行。
本題測資是樹,又想考慮 dp,那就是樹狀 dp。
樹就是沒有循環的圖。選擇任意節點做為根,從根出發都能夠完整遍歷整棵樹的所有節點各一次,不重不漏。
利用這個特性,在 dfs 節點的過程中,可以一次枚舉所有子樹連或不連的邊權最大和,並選擇最佳的 k 個子結果更新當前子樹的答案。
dfs 時,參數需要紀錄當前節點 i。為了防止往回走,需紀錄父節點 fa。
本題限制節點的邊數至多為 k0,有兩種情形:
- i 與 fa 不連邊,所以 i 可以找 k0 = k 條邊。
- i 與 fa 連邊,所以 i 可以再找 k0 = k-1 條邊。
以參數 fa_conn=0/1 表示 i 是否與父節點相連。
那麼假設當前節點 i 有兩個子節點,但只能選一個連。
edges = [[i,j1,w1], [i,j2,w2]], k0 = 1
討論子節點 j 連不連的情況:
- 連 j,有 take = w + dp(j, fa_conn=1)
- 不連 j,有 notake = dp(j, fa_conn=0)
連 j1,不連 j2。答案為:
= take1 + notake2
不連 j1,連 j2。答案為:
= notake1 + take2
如何決定選或不選好?討論選 j 會產生的損益。
相似題 2611. mice and cheese。
- 不選 j,得到 notake
- 選 j,得到 take
因此選 j 的損益為 delta = take - notake。
根據損益遞減排序,貪心地選前 k0 大收益的節點連邊,加 take;剩餘的都不連,加 notake。
注意:損益有可能是負數,就是虧錢,別考慮了千萬別選。
最後依照上述內容實作樹狀 dp。
定義 dp(i, fa_conn=0/1):以 i 為根的子樹中,每個節點至多 k 條邊的邊權最大和。fa_conn 為 1 代表與父節點相連。
轉移:dp(i) = sum(前 k0 損益連邊) + sum(其餘不連邊)。
選 0 當根節點,根無父節點,答案入口 dp(0, -1, 0)。
時間複雜度 O(N log N),瓶頸在排序。
空間複雜度 O(N)。
class Solution:
def maximizeSumOfWeights(self, edges: List[List[int]], k: int) -> int:
N = len(edges) + 1
g = [[] for _ in range(N)]
for a, b, w in edges:
g[a].append([b, w])
g[b].append([a, w])
@cache
def dp(i, fa, fa_conn):
cand = []
res = 0
for j, w in g[i]:
if j == fa:
continue
take = dp(j, i, 1) + w
notake = dp(j, i, 0)
if notake >= take: # must no take
res += notake
continue
delta = take - notake
cand.append([take, notake, delta])
# sort by delta, take first k0
k0 = k - fa_conn # only take k-1 when connected to fa
cand.sort(reverse=True, key=itemgetter(2))
for i, x in enumerate(cand):
if i < k0:
res += x[0]
else:
res += x[1]
return res
return dp(0, -1, 0)