周賽 402。又是線段樹題,這次我有把樹搞出來,但是題目查詢的地方沒想通,又繞了大遠路去搞 sorted list。
除了思路有點障礙之外,寫得還很醜,真的差點沒寫出來。

題目

若陣列 arr 中的某個元素大於其前後的元素,則稱為峰值

輸入整數陣列 nums 以及二維整數陣列 queries。
你必須執行以下兩種操作:

  • queries[i] = [1, li, ri]
    查詢子陣列 nums[li..ri] 中的峰值數量
  • queries[i] = [2, indexi, vali]
    將 nums[indexi] 改成 vali

回傳一個陣列,依序代表每次查詢的結果。

注意:子陣列中第一個最後一個元素都不是峰值。

解法

陣列中的每個元素只有是峰值不是峰值兩種狀態。
因此可以使用有序容器 sorted list 維護峰值的索引,並搭配二分搜找到區間 [l..r] 之間的個數。
首先遍歷 nums 並將所有峰值加入容器中。


先來看查詢:
題目有說到,子陣列中第一和最後一個元素不是峰值,因此查詢 [l..r] 實際上只要找 [(l+1)..(r-1)] 之間的索引。
先二分找到第一個大於等於 l+1 的索引 i,然後找 最後一個小於等於 r-1 的索引 j,峰值個數為 j-i+1。
為了避免 l+1 > r-1 而造成負值,記得和 0 取最大值。


再來看更新:
想想看,改變 nums[i] 的值會有什麼影響?

  1. 讓 nums[i] 變成峰值 / 不是峰值
  2. 讓 nums[i-1] 變成峰值 / 不是峰值
  3. 讓 nums[+1] 變成峰值 / 不是峰值

因此修改 nums[i] 後,要重新判斷以上三個位置的峰值狀態。
實際上只要處理 峰值=>不是峰值不是峰值=>峰值 兩種變化,依照狀態從有序容器中增刪索引。

時間複雜度 O((N + Q) log N)。
空間複雜度 O(N),答案空間不計入。

from sortedcontainers import SortedList as SL
class Solution:
    def countOfPeaks(self, nums: List[int], queries: List[List[int]]) -> List[int]:
        N = len(nums)
        sl = SL()  # peaks
        is_peak = [False] * N

        for i in range(1, N - 1):
            if nums[i - 1] < nums[i] and nums[i] > nums[i + 1]:
                sl.add(i)
                is_peak[i] = True

        ans = []
        for q in queries:
            if q[0] == 1:
                _, l, r = q
                i = sl.bisect_left(l + 1)
                j = sl.bisect_right(r - 1) - 1
                ans.append(max(0, j - i + 1))
            else:
                _, i, val = q
                nums[i] = val
                for j in range(max(1, i - 1), min(N - 2, i + 1) + 1):
                    to_peak = nums[j - 1] < nums[j] and nums[j] > nums[j + 1]
                    if is_peak[j] and not to_peak:  # remove peak
                        is_peak[j] = False
                        sl.remove(j)
                    elif not is_peak[j] and to_peak:  # add peak
                        is_peak[j] = True
                        sl.add(j)

        return ans

處理峰值變化那塊長的有夠醜,又長又難寫,不小心就寫錯了。
有種技巧可以搞得更簡潔:恢復現場
若受到影響的索引 j 原本是峰值,則先把狀態標記取消,待 nums[i] 更新值後再重算一次。

from sortedcontainers import SortedList as SL
class Solution:
    def countOfPeaks(self, nums: List[int], queries: List[List[int]]) -> List[int]:
        N = len(nums)
        sl = SL()  # peaks

        def update(i):
            if nums[i - 1] < nums[i] and nums[i] > nums[i + 1]:
                sl.add(i)
            
        def reset(i):
            if nums[i - 1] < nums[i] and nums[i] > nums[i + 1]:
                sl.remove(i)
                
        for i in range(1, N - 1):
            if nums[i - 1] < nums[i] and nums[i] > nums[i + 1]:
                update(i)

        ans = []
        for q in queries:
            if q[0] == 1:
                _, l, r = q
                i = sl.bisect_left(l + 1)
                j = sl.bisect_right(r - 1) - 1
                ans.append(max(0, j - i + 1))
            else:
                _, i, val = q
                # reset peak state
                for j in range(max(1, i - 1), min(N - 2, i + 1) + 1):
                    reset(j)
                # update peak state
                nums[i] = val
                for j in range(max(1, i - 1), min(N - 2, i + 1) + 1):
                    update(j)

        return ans

沒有 sorted list 的語言就只能乖乖用 BIT 或是線段樹了。
先看看 BIT 版本。

一樣是把山峰的位置標作 1,透過前綴和查詢區間山峰總數。

class Solution:
    def countOfPeaks(self, nums: List[int], queries: List[List[int]]) -> List[int]:
        N = len(nums)
        bit = BIT(N)
        for i in range(1, N - 1):
            is_peak = int(nums[i - 1] < nums[i] and nums[i] > nums[i + 1])
            bit.set(i, is_peak)

        ans = []
        for q in queries:
            if q[0] == 1:
                _, l, r = q
                if l + 1 <= r - 1:
                    res = bit.query_range(l + 1, r - 1)
                else:
                    res = 0
                ans.append(res)
            else:
                _, i, val = q
                nums[i] = val
                for j in range(max(1, i - 1), min(N - 2, i + 1) + 1):
                    is_peak = int(nums[j - 1] < nums[j]
                                  and nums[j] > nums[j + 1])
                    bit.set(j, is_peak)

        return ans


class BIT:
    """
    tree[0]代表空區間,不可存值,基本情況下只有[1, n-1]可以存值。
    offset為索引偏移量,若設置為1時正好可以對應普通陣列的索引操作。
    """

    def __init__(self, n, offset=1):
        self.offset = offset
        self.tree = [0]*(n+offset)

    def update(self, pos, val):
        """
        將tree[pos]增加val
        """
        i = pos+self.offset
        while i < len(self.tree):
            self.tree[i] += val
            i += i & (-i)

    def query(self, pos):
        """
        查詢[1, pos]的前綴和
        """
        i = pos+self.offset
        res = 0
        while i > 0:
            res += self.tree[i]
            i -= i & (-i)
        return res

    def query_range(self, i, j):
        """
        查詢[i, j]的前綴和
        """
        return self.query(j)-self.query(i-1)

    def set(self, pos, val):
        """
        將tree[pos]設成val
        """
        old = self.query_range(pos, pos)
        diff = val-old
        self.update(pos, diff)

再來是我當初寫一半放棄的線段樹解法。真可惜。

這兩種方法和 sorted list 的最大差別在於:查詢一的 [l..r] 的範圍必須手動判斷邊界。
還有查詢二不必恢復現場,因為每次更新都會用新值自下而上合併,一定會更新成正確的區間值。

時間複雜度 O((N + Q) log N)。
空間複雜度 O(N),答案空間不計入。

class Solution:
    def countOfPeaks(self, nums: List[int], queries: List[List[int]]) -> List[int]:
        N = len(nums)
        seg = SegmentTree(N, nums)
        for i in range(1, N - 1):
            seg.update(1, 0, N - 1, i)

        ans = []
        for q in queries:
            if q[0] == 1:
                _, l, r = q
                if l + 1 <= r - 1:
                    res = seg.query(1, 0, N - 1, l + 1, r - 1)
                else:
                    res = 0
                ans.append(res)
            else:
                _, i, val = q
                nums[i] = val
                for j in range(max(1, i - 1), min(N - 2, i + 1) + 1):
                    seg.update(1, 0, N - 1, j)

        return ans


class SegmentTree:

    def __init__(self, n, nums):
        self.tree = [0]*(n*4)
        self.nums = nums
        self.n = n

    def op(self, a, b):
        """
        任意符合結合律的運算
        """
        return a+b

    def push_up(self, id):
        """
        以左右節點更新當前節點值
        """
        self.tree[id] = self.op(self.tree[id*2], self.tree[id*2+1])

    def query(self, id, L, R, i, j):
        """
        區間查詢
        回傳[i, j]的總和
        """
        if i <= L and R <= j:  # 當前區間目標範圍包含
            return self.tree[id]
        res = 0
        M = (L+R)//2
        if i <= M:
            res = self.op(res, self.query(id*2, L, M, i, j))
        if M+1 <= j:
            res = self.op(res, self.query(id*2+1, M+1, R, i, j))
        return res

    def update(self, id, L, R, i):
        """
        單點更新
        判斷 nums[i] 是否為峰值
        """
        if L == R:  # 當前區間目標範圍包含
            if self.nums[L - 1] < self.nums[L] and self.nums[L] > self.nums[L + 1]:
                self.tree[id] = 1
            else:
                self.tree[id] = 0
            return
        M = (L+R)//2
        if i <= M:
            self.update(id*2, L, M, i)
        else:
            self.update(id*2+1, M+1, R, i)
        self.push_up(id)