以前只知道一種做法,原來有兩種更快的優化。
趕緊來還債。

題目

https://leetcode.com/problems/rectangle-area-ii/description/

解法

矩形至多 N = 200 個,但是座標值域卻高達 10^9。
先找出可能出現的 x, y 軸座標,進行離散化,依序對應到 0~ 2N-1 的值域。

若去重後的 x 軸座標有 X 個,那麼其構成的區間線段會有 X-1 個;y 軸同理,有 Y-1 個線段。
可視作 (X-1) * (Y-1) 的矩陣 cover

再次遍歷所有矩形,按照離散化後的座標,將對應到的部分標記覆蓋。
最後再遍歷矩陣,若 cover[i][j] 已被覆蓋,則查詢原本對應的座標,將面積加入答案中。

時間複雜度 O(N^3)。
空間複雜度 O(N^2)。

class Solution:
    def rectangleArea(self, rectangles: List[List[int]]) -> int:
        # collect coord
        xs = set()
        ys = set()
        for x1, y1, x2, y2 in rectangles:
            xs.add(x1)
            xs.add(x2)
            ys.add(y1)
            ys.add(y2)

        # discretize
        xs = sorted(xs)
        ys = sorted(ys)
        mp_x = {x: i for i, x in enumerate(xs)}
        mp_y = {y: j for j, y in enumerate(ys)}

        # mark cover
        X = len(xs) - 1
        Y = len(ys) - 1
        cover = [[0] * Y for _ in range(X)]
        for x1, y1, x2, y2 in rectangles:
            for x in range(mp_x[x1], mp_x[x2]):
                for y in range(mp_y[y1], mp_y[y2]):
                    cover[x][y] = 1

        # calc cover area
        ans = 0
        for x in range(X):
            for y in range(Y):
                if cover[x][y]:
                    x_width = xs[x+1] - xs[x]
                    y_height = ys[y+1] - ys[y]
                    ans += x_width * y_height

        return ans % (10 ** 9 + 7)

想像 y 軸有一條掃描線,由下往上移動。
每次移動,統計 x 軸有多少線段被覆蓋。

對於矩形 x1, y1, x2, y2 來說,當掃描線掃到 y1 時,線段 [x1, x2] 從此時開始被覆蓋;
掃到 y2 時,線段 [x1, x2] 的覆蓋結束。

把每個矩形轉換成覆蓋開始 / 結束的事件,以 y 軸排序。
每次 y 軸掃瞄線移動,增加的面積即:

y 軸差值 * x 軸覆蓋長度

我們只需要維護 x 軸被覆蓋的線段,所以只有 x 軸需要離散化。

時間複雜度 O(N^2)。
空間複雜度 O(N)。

class Solution:
    def rectangleArea(self, rectangles: List[List[int]]) -> int:
        # collect coord
        # and turn rect into event
        xs = set()
        ys = set()
        events = []
        for x1, y1, x2, y2 in rectangles:
            xs.add(x1)
            xs.add(x2)
            events.append([y1, x1, x2, 1])
            events.append([y2, x1, x2, -1])

        # discretize
        xs = sorted(xs)
        mp_x = {x: i for i, x in enumerate(xs)}

        # mark cover
        X = len(xs) - 1
        cover = [0] * X 

        # sweep line
        events.sort()
        ans = 0
        for i, (y, x1, x2, val) in enumerate(events):
            if i > 0:
                pre_y = events[i-1][0]
                y_height = y - pre_y
                for j, cnt in enumerate(cover):
                    if cnt > 0:
                        x_width = xs[j+1] - xs[j]
                        ans += x_width * y_height

            for x in range(mp_x[x1], mp_x[x2]):
                cover[x] += val

        return ans % (10 ** 9 + 7)

上述統計 x 軸線段覆蓋次數,是區間修改。不難想到線段樹優化。
難點在於:除了維護覆蓋次數之外,要怎麼維護哪些區間沒被覆蓋


區間修改線段樹的效率優勢在於:如果修改的區間完全包含當前節點區間,則打上懶標記,停止向下遞迴。
但對於本題來說,父節點的覆蓋次數改變時,子節點是否被覆蓋的狀態可能會改變

例如:

原有座標 [0,1,2],兩個線段 [0,1], [1,2]
分成三個節點 [0,2], [0,1], [1,2]
[0,1] 被覆蓋一次,然後 [0,2] 也覆蓋一次
這時 [0,2] 沒被覆蓋的線段是 0

如果 [0,2] 刪除一次,[0,1] 依然被覆蓋,但是 [0,2] 變回沒覆蓋了。
除非繼續向下遞迴檢查,否則節點 [0,2] 沒有辦法知道子節點狀態究竟改變沒有。
但這樣操作退化成 O(N),不如暴力維護。


因此需要稍微改變定義,維護:

  • 最小覆蓋次數,以及
  • 屬於最小覆蓋次數的線段長度

如此一來,就算修改節點的覆蓋次數,也不會改變各線段覆蓋次數的相對關係。

查詢時只需要取根節點,檢查是否被覆蓋即可。
有覆蓋的長度即 x 軸全長減去沒覆蓋的長度。

注意:線段樹節點的區間是離散後的座標,並非原始座標。

時間複雜度 O(N log N)。
空間複雜度 O(N)。

class Solution:
    def rectangleArea(self, rectangles: List[List[int]]) -> int:
        # collect coord
        # and turn rect into event
        xs = set()
        events = []
        for x1, y1, x2, y2 in rectangles:
            xs.add(x1)
            xs.add(x2)
            events.append([y1, x1, x2, 1])
            events.append([y2, x1, x2, -1])

        # discretize
        xs = sorted(xs)
        mp_x = {x: i for i, x in enumerate(xs)}

        # sweep line
        seg = SegmentTree(xs)
        events.sort()
        ans = 0
        for i, (y, x1, x2, val) in enumerate(events):
            if i > 0:
                pre_y = events[i-1][0]
                y_height = y - pre_y
                x_width = xs[-1] - xs[0] - seg.get_uncovered()
                ans += x_width * y_height

            l = mp_x[x1]
            r = mp_x[x2]
            seg.update(1, l, r - 1, val)

        return ans % (10 ** 9 + 7)


class Node:
    def __init__(self):
        self.l = 0
        self.r = 0
        self.min_cnt = 0
        self.min_length = 0
        self.lazy = 0


class SegmentTree:
    def __init__(self, xs):
        N = len(xs) - 1
        self.nodes = [Node() for _ in range(N * 4)]
        self.build(xs, 1, 0, N-1)

    def build(self, xs, id, l, r):
        o = self.nodes[id]
        o.l = l
        o.r = r
        if l == r:
            o.min_length = xs[l+1] - xs[l]
            return

        m = (l + r) // 2
        self.build(xs, id*2, l, m)
        self.build(xs, id*2+1, m+1, r)
        self.push_up(id)

    def push_down(self, id):
        o = self.nodes[id]
        lc = self.nodes[id*2]
        rc = self.nodes[id*2+1]
        if o.lazy:
            lc.lazy += o.lazy
            lc.min_cnt += o.lazy
            rc.lazy += o.lazy
            rc.min_cnt += o.lazy
            o.lazy = 0

    def push_up(self, id):
        o = self.nodes[id]
        lc = self.nodes[id*2]
        rc = self.nodes[id*2+1]
        o.min_cnt = min(lc.min_cnt, rc.min_cnt)
        o.min_length = 0
        if lc.min_cnt == o.min_cnt:
            o.min_length = lc.min_length
        if rc.min_cnt == o.min_cnt:
            o.min_length += rc.min_length

    def update(self, id, i, j, val):
        o = self.nodes[id]
        if i <= o.l and o.r <= j:
            o.min_cnt += val
            o.lazy += val
            return

        m = (o.l + o.r) // 2
        self.push_down(id)
        if i <= m:
            self.update(id*2, i, j, val)
        if m < j:
            self.update(id*2+1, i, j, val)
        self.push_up(id)

    def get_uncovered(self):
        root = self.nodes[1]
        if root.min_cnt > 0:
            return 0
        return root.min_length