LeetCode 3585. Find Weighted Median Node in Tree
weekly contest 454。
剩五分鐘前才寫完,但是沒膽交答案。
賽後交完四題全對,好像錯過的上分的機會。
題目
https://leetcode.com/problems/find-weighted-median-node-in-tree/description/
解法
設 weight(x, y):從 x 走到 y 的路徑權重和。
對於每個查詢 x, y:
從 x 開始出發往 y 走,找到第一個點 target 滿足:
weight(x, target) >= weight(x, y) / 2
即 weight(x, target) * 2 >= weight(x, y)
看到樹上求距離,又要查詢好幾次,大概就是倍增。
我的倍增模板有三種功能,這次全都用上:
- 求 x, y 的 lca
- 求 x, y 的距離
- 求 x 跳 k 步後的點
從特殊到一般,先考慮最特殊、最單純的情況:
x, y 呈鍊狀,且 y 是根節點。
我們無法直接知道 target 是誰。
但是可以知道從 x 跳 k 步抵達某點 temp 的路徑權重 w,進而知道是否滿足限制。
若跳 k 步可滿足限制,則 k+1 步肯定也滿足限制;若 k 步不滿足限制,則 k-1 步肯定也不滿足限制。
答案具有單調性,可透過二分答案找到第一個滿足限制的步數 k。
每次二分需要從 x 跳 k 步,成本 O(log N)。
共需二分 O(log N) 次,每次查詢複雜度 O(log N * log N)。
再來是比較麻煩的點:
如果起點 x 是 y 的祖先節點、甚至兩者位於不同的子樹怎麼辦?
例如範例二:根節 0,左右節點 1 和 2。
倍增只能從子節點往上跳,如果查詢 x = 1, y = 2,從 1 跳到 0 就卡住了,沒辦法拐彎繼續跳到 2。
答案很簡單,只是我腦子沒想通卡好久。
拐彎點就是 lca!!
設 x 到 lca 有 x_cnt 步。
設 y 到 lca 有 y_cnt 步。
討論 k 的大小:
- k <= x_cnt,就是普通的跳 k 次
- k > x_cnt,則先從 x 跳到 lca,然後從 lca 再跳 k - x_cnt 次
但還是沒解決怎麼從 lca 往 y 子樹的方向跳?
其實 lca 往下跳 need 次,等價於 y 往上跳 y_cnt - need 次。
兩邊跳越的路徑和加起來即可。
時間複雜度 O(Q * log N * log N)。
空間複雜度 O(N log N)。
class Solution:
def findMedian(self, n: int, edges: List[List[int]], queries: List[List[int]]) -> List[int]:
bl = TreeLCA(edges)
def solve(x, y):
lca = bl.get_LCA(x, y)
tot = bl.get_distance(x, y)
x_cnt = bl.depth[x] - bl.depth[lca]
y_cnt = bl.depth[y] - bl.depth[lca]
# return [path_weight, target]
def jump_k_from_x(k):
if k <= x_cnt:
target = bl.jump_k(x, k)
else:
need = y_cnt - (k-x_cnt)
target = bl.jump_k(y, need)
w = bl.get_distance(x, target)
return [w, target]
# bisect for step k from x
lo = 0
hi = x_cnt + y_cnt
while lo < hi:
mid = (lo+hi) // 2
if jump_k_from_x(mid)[0]*2 < tot:
lo = mid+1
else:
hi = mid
return jump_k_from_x(lo)[1]
return [solve(*q) for q in queries]
class TreeLCA:
def __init__(self, edges):
N = len(edges) + 1 # 有多少點
self.MX = N.bit_length() # 最大跳躍次數取 log
# 建圖
g = [[] for _ in range(N)]
for a, b, w in edges:
g[a].append([b, w])
g[b].append([a, w])
# 建樹 樹上前綴和
self.parent = [-1] * N
self.depth = [0] * N
self.ps = [0] * N
def dfs(i, fa, dep):
self.parent[i] = fa
self.depth[i] = dep
for j, w in g[i]:
if j == fa:
continue
self.ps[j] = self.ps[i] + w
dfs(j, i, dep+1)
dfs(0, -1, 0)
# f[i][jump]: 從 i 跳 2^jump 次的位置
# -1 代表沒有下一個點
self.f = [[-1] * self.MX for _ in range(N)]
# 初始化每個位置跳一次
for i in range(N):
self.f[i][0] = self.parent[i]
# 倍增遞推
for jump in range(1, self.MX):
for i in range(N):
temp = self.f[i][jump-1]
if temp != -1: # 必須存在中繼點
self.f[i][jump] = self.f[temp][jump-1]
def get_LCA(self, x, y):
depth = self.depth
f = self.f
if depth[x] > depth[y]:
x, y = y, x
# 把 y 調整到和 x 相同深度
diff = depth[y] - depth[x]
for jump in range(self.MX):
if diff & (1 << jump):
y = f[y][jump]
# 已經相同
if x == y:
return x
# 否則找最低的非 LCA
for jump in reversed(range(self.MX)):
if f[x][jump] != f[y][jump]:
x = f[x][jump]
y = f[y][jump]
# 再跳一次到 LCA
return f[x][0]
def get_distance(self, x, y):
lca = self.get_LCA(x, y)
return self.ps[x] + self.ps[y] - self.ps[lca]*2
def jump_k(self, x, k):
"""
從 x 跳 k 次
-1 表示不合法
"""
for jump in range(self.MX):
if k & (1 << jump):
x = self.f[x][jump]
if x == -1: # 不能跳
return -1
return x