LeetCode 2846. Minimum Edge Weight Equilibrium Queries in a Tree
周賽361。上週才考過倍增,這週馬上就考進階用法,真變態。
雖說是進階版,但LCA倍增其實算是競賽的常見題目,網路上隨便都找得到模板可以套用。可能因此通過人數比上次還多。
題目
有一個n節點的無向樹,節點編號分別為0到n-1。
輸入二維整數陣列edges,其中edges[i] = [ui, vi, wi],代表節點ui和vi之間存在一條權重為wi邊。
另外還輸入長度m的二維整數陣列queries,其中queries[i] = [ai, bi]。
對於每個查詢queries[i],要求出使得ai到bi路徑上的權重相等所需最少操作次數。
每次操作,你可以將任意邊上的權重改變成任意值。
注意:
- 每次操作都是獨立的。意味著每次操作前,所有邊的權重都會恢復成初始值
- ai到bi的路徑是由不同的節點序列所組成,從ai開始,bi結束。且序列中相鄰的節點都共享一條邊
回傳長度m的陣列answer,其中answer[i]代表第i次查詢的答案。
解法
雖然說是一棵樹,但沒指定根節點,方便起見都把節點0當作根。
要使得路徑上的權重相等,又要操作次數最小,那就只能把所有權重都改成出現次數最多的那個。
但是怎麼求a和b之間的路徑?先找他們的LCA(最近公共祖先),從LCA到a的路徑加上LCA到b的路徑,就是完整的路徑。
傳統的LCA算法是先使ab深度相等,然後兩者同時向父節點移動,直到ab相等,當前節點就是LCA。
但是本題中節點數量高達10^4,最壞形況下是linked list,每次找LCA都要10^4次移動。查詢也是10^4量級,必須要想辦法優化LCA的算法。
找LCA分成兩個步驟:平衡深度、同時上移。
我們知道a和b深度的差為diff,將diff分解成數個2^j次移動,就是基本款的倍增應用。
但是又怎麼知道要跳多少步才到LCA?這時又要導入二分思想:
- 跳x步後,a和b相等,代表已經找到LCA或是LCA的祖先。可以保證LCA在當前節點,或是在下方
- 若a和b不相等,則代表還沒找到LCA。可以保證LCA一定在上方
只要找到深度最低,且不是a!=b的位置,再往上一步,就是LCA了。有點類似bisect_right或是upper_bound之後再減1的概念。
接著剛才講的,如果x步會跳到祖先節點,那目標最多只需要x-1步。所以從最大的2^j開始往下檢查,如果跳2^j會相同,就不跳;不會相同就跳。
最後會停在LCA的下方,再往上跳一步就大功告成。
注意:在a或b本身就是LCA的情況下,倍增平衡完深度就可以直接回傳。
最後處理每個查詢[a, b]:
- 先找到ab的lca
- 求a到lca,再從lca到b的路徑
- 找到最高的權重出現次數max_w,把總邊數tot扣掉max_w就是答案
預處理倍增O(n log n),每次查詢O(log n),整體時間複雜度O( n log n + Q log n )。
空間複雜度O(n log n)。
class Solution:
def minOperationsQueries(self, n: int, edges: List[List[int]], queries: List[List[int]]) -> List[int]:
m=n.bit_length()
g=[[] for _ in range(n)]
for a,b,w in edges:
g[a].append([b,w-1])
g[b].append([a,w-1])
pos=[[-1]*m for _ in range(n)]
dep=[0]*n
weight=[None]*n
weight[0]=[0]*26
# build depth and weights
def dfs(i,fa,d):
pos[i][0]=fa
dep[i]=d
for j,w in g[i]:
if j==fa:
continue
weight[j]=weight[i][:]
weight[j][w]+=1
dfs(j,i,d+1)
dfs(0,-1,0)
# build binary lifting
for j in range(1,m):
for i in range(n):
fa=pos[i][j-1]
if fa!=-1:
pos[i][j]=pos[fa][j-1]
def get_lca(x,y):
if dep[x]>dep[y]:
x,y=y,x
# make x and y same depth
# by move diff steps from y
diff=dep[y]-dep[x]
for j in range(m):
if diff&(1<<j):
y=pos[y][j]
# found lca
if x==y:
return x
# find lowest non-lca
for j in reversed(range(m)):
if pos[x][j]!=pos[y][j]:
x=pos[x][j]
y=pos[y][j]
# now x and y are below lca 1 step
# one more step to lca
return pos[x][0]
ans=[]
for a,b in queries:
lca=get_lca(a,b)
tot=dep[a]+dep[b]-dep[lca]*2
w_sum=[x1+x2-x3*2 for x1,x2,x3 in zip(weight[a],weight[b],weight[lca])]
ans.append(tot-max(w_sum))
return ans