weekly contest 455。

題目

https://leetcode.com/problems/minimum-time-to-transport-all-individuals/description/

解法

有 n 個人要划船過河,船只有一艘且一次最多裝 k 個人。若還有人沒過,需要有 1 人從對岸把船開回來。
每次過河的時間按照船上的人的最低速度乘上流速倍率 mul。


看到 n <= 12,直覺想到 bitmask 表示剩餘人口。
過河人口與其最大值也可以用 bitmask 優化。

本以為是狀壓 dp,但是人可以來回過河,剩餘人口會出現循環。河流狀態也是循環,無法轉換成更小的子問題,故不適用 dp。


已知的狀態有:剩餘人口、船的位置、流速倍率。分別以 ppl, dir, speed 表示。

在三個狀態相同時,應避免重複計算,又要求最小時間。
時間只增不減,可用 dijkstra 最短路,將三個狀態合併看成一個點。

剩下就是分類討論:

  • 船在出發點,從剩餘人口枚舉不超過 k 個人過河
  • 船在終點,從抵達以口枚舉正好 1 個人把船開回去

複雜度不會算。

class Solution:
    def minTime(self, n: int, k: int, m: int, time: List[int], mul: List[float]) -> float:
        FULL = (1 << n) - 1

        mask_time = [0] * (1 << n)
        for mask in range(1 << n):
            for i in range(n):
                if mask & (1 << i):
                    mask_time[mask] = max(mask_time[mask], time[i])

        dist = [[[inf]*m for _ in range(2)] for _ in range(1 << n)]
        h = []
        heappush(h, [0, FULL, 0, 0])  # cost, ppl, dir=0/1, spd
        dist[FULL][0][0] = 0
        while h:
            cost, ppl, dir, spd = heappop(h)
            if cost > dist[ppl][dir][spd]:
                continue

            if dir == 0:  # at start
                for mask in range(1, 1 << n):
                    if mask.bit_count() <= k and mask | ppl == ppl:  # at most k ppl
                        d = mask_time[mask]*mul[spd]
                        new_cost = cost + d
                        new_ppl = ppl ^ mask
                        new_spd = (spd+floor(d)) % m
                        if new_cost < dist[new_ppl][1][new_spd]:
                            dist[new_ppl][1][new_spd] = new_cost
                            heappush(h, [new_cost, new_ppl, 1, new_spd])
            else:  # at end
                if ppl == 0:  # all arrived
                    return cost
                for i in range(n):
                    mask = 1 << i
                    if mask & ppl == 0:
                        d = mask_time[mask]*mul[spd]
                        new_cost = cost + d
                        new_ppl = ppl ^ mask
                        new_spd = (spd+floor(d)) % m
                        if new_cost < dist[new_ppl][0][new_spd]:
                            dist[new_ppl][0][new_spd] = new_cost
                            heappush(h, [new_cost, new_ppl, 0, new_spd])

        return -1

至多 k 個人的子集可以先預處理,會加速很多。
添加至 heap 的邏輯也可以封裝起來。

class Solution:
    def minTime(self, n: int, k: int, m: int, time: List[int], mul: List[float]) -> float:
        FULL = (1 << n) - 1

        mask_time = [0] * (1 << n)
        for mask in range(1 << n):
            for i in range(n):
                if mask & (1 << i):
                    mask_time[mask] = max(mask_time[mask], time[i])

        g = [[] for _ in range(1 << n)]
        for ppl in range(1 << n):
            for sub in range(1, 1 << n):
                if sub.bit_count() <= k and sub | ppl == ppl:
                    g[ppl].append(sub)

        dist = [[[inf]*m for _ in range(2)] for _ in range(1 << n)]
        h = []

        def push(cost, ppl, dir, mul_state):
            if cost < dist[ppl][dir][mul_state]:
                dist[ppl][dir][mul_state] = cost
                heappush(h, [cost, ppl, dir, mul_state])

        push(0, FULL, 0, 0)

        while h:
            cost, ppl, dir, spd = heappop(h)
            if cost > dist[ppl][dir][spd]:
                continue

            if dir == 0:  # at start
                for mask in g[ppl]:
                    d = mask_time[mask]*mul[spd]
                    new_spd = (spd+floor(d)) % m
                    push(cost+d, ppl ^ mask, dir ^ 1, new_spd)
            else:  # at end
                if ppl == 0:  # all arrived
                    return cost
                for i in range(n):
                    mask = 1 << i
                    if mask & ppl == 0:
                        d = mask_time[mask]*mul[spd]
                        new_spd = (spd+floor(d)) % m
                        push(cost+d, ppl ^ mask, dir ^ 1, new_spd)

        return -1