LeetCode 3515. Shortest Path in a Weighted Tree
biweekly contest 154。
好久不見的 dfs 時間戳,大概超過一年沒出現。
本來自己沒線索,但是看到大神提示就知道怎麼做了。
題目
https://leetcode.com/problems/shortest-path-in-a-weighted-tree/description/
解法
難點在如何修改路徑權重。
從特殊到一般,考慮鍊狀樹修改權重的時候會發生什麼事?
設 o1 為根節點。
o1 -> o2 - > o3 -> o4 -> o5 …
若把邊 (o2, o3) 的權重增加 delta:
- 從 o1 到 o1, o2 的路徑和不影響
- 從 o1 到 o2, o3, o4, o5,.. 的權重都會增加 delta
可見邊下方子樹中所有節點的的路徑和都會增加 delta。
如何知道子樹中有那些節點?
利用 dfs 先進後出的特性,在遞迴過程中給訪問到的節點標記進入時間,即 dfs 時間戳 (可以理解成第幾個被訪問)。
同樣地,在退出遞迴時標記離開時間。
則對於節點 i 來說,位於連續區間 [tin[i]..tout[i]] 即子樹所有節點的時間戳。
若通往 i 的的邊權增加 delta,則只需要把區間 [tin[i]..tout[i]] 都增加 delta。
問題轉換成區間修改和單點查詢,可以用線段樹或是樹狀陣列。
此處選用線段樹。
最後剩下小細節,邊權是修改成新值,需要自己維護舊值,並計算增量 delta。
然後 (u, v) 看不出來誰位於下方,需要利用時間戳 tin[u], tin[v] 判斷。
注意:不要把節點的原編號和時間戳搞混!!很重要!!
時間複雜度 O((N + Q) log N)。
空間複雜度 O(N)。
class Solution:
def treeQueries(self, n: int, edges: List[List[int]], queries: List[List[int]]) -> List[int]:
seg = SegmentTree(n+5)
g = [[] for _ in range(n+1)]
edge_weight = Counter()
for a, b, w in edges:
g[a].append([b, w])
g[b].append([a, w])
edge_weight[(a, b)] = w
timestamp = 0
tin = [0] * (n+1)
tout = [0] * (n+1)
def dfs(i, fa, sm):
nonlocal timestamp
timestamp += 1
tin[i] = timestamp
seg.update(1, 1, n, timestamp, timestamp, sm)
for j, w in g[i]:
if j == fa:
continue
dfs(j, i, sm+w)
tout[i] = timestamp
dfs(1, -1, 0)
ans = []
for q in queries:
# query path sum of [1, x]
if q[0] == 2:
x = q[1]
sm = seg.query(1, 1, n, tin[x], tin[x])
ans.append(sm)
continue
# update weight of edge (u, v)
_, u, v, new_w = q
old_w = edge_weight[(u, v)]
delta = new_w - old_w
edge_weight[(u, v)] = new_w
# find which is son
if tin[u] < tin[v]:
son = v
else:
son = u
# apply update to subtree
seg.update(1, 1, n, tin[son], tout[son], delta)
return ans
class SegmentTree:
def __init__(self, n):
self.tree = [0]*(n*4)
self.lazy = [0]*(n*4)
def op(self, a, b):
"""
任意符合結合律的運算
"""
return a+b
def push_down(self, id, L, R, M):
"""
將區間懶標加到答案中
下推懶標記給左右子樹
"""
if self.lazy[id]:
self.tree[id*2] += self.lazy[id]*(M-L+1)
self.lazy[id*2] += self.lazy[id]
self.tree[id*2+1] += self.lazy[id]*(R-M)
self.lazy[id*2+1] += self.lazy[id]
self.lazy[id] = 0
def push_up(self, id):
"""
以左右節點更新當前節點值
"""
self.tree[id] = self.op(self.tree[id*2], self.tree[id*2+1])
def query(self, id, L, R, i, j):
"""
區間查詢
回傳[i, j]的總和
"""
if i <= L and R <= j: # 當前區間目標範圍包含
return self.tree[id]
M = (L+R)//2
self.push_down(id, L, R, M)
res = 0
if i <= M:
res = self.op(res, self.query(id*2, L, M, i, j))
if M+1 <= j:
res = self.op(res, self.query(id*2+1, M+1, R, i, j))
return res
def update(self, id, L, R, i, j, val):
"""
區間更新
對[i, j]每個索引都增加val
"""
if i <= L and R <= j: # 當前區間目標範圍包含
self.tree[id] += val * (R - L + 1)
self.lazy[id] += val
return
M = (L+R)//2
self.push_down(id, L, R, M)
if i <= M:
self.update(id*2, L, M, i, j, val)
if M < j:
self.update(id*2+1, M+1, R, i, j, val)
self.push_up(id)