LeetCode 3213. Construct String with Minimum Cost
周賽 405。
這題也是很神秘,測資範圍 N = 5e4,依我經驗一看就覺得 python 寫很容易出事。
一般來說測資超過 1e4 之後,O(N^2) 的做法都會超時。
但因為少了最極端的測資,不少人交 O(N^2) 答案竟然過了,甚至賽後看到官方提示也是叫人家用這種作法。
如果說本來就預期 O(N^2) 解,那就是測資範圍設錯,誤導作題者。但是 8 分難度好像又配不上。 如果說測資強度太差,有些人交了 O(N sqrt(N)) 正確答案卻又超時,真的是魔法遊戲。
題目
輸入字串 target,字串陣列 words,還有整數陣列 costs。兩個陣列的長度都相同。
最初存在一個空字串 s。
你可以執行以下操作任意次:
- 選擇 [0, words.length - 1] 之間的索引
- 將 words[i] 加入 s 後方
- 成本增加 costs[i]
求使得 s 等於 target 的最小成本。若不可能則回傳 -1。
解法
在 s 等於 target 之前,我們必須重複決定選哪個 words[i] 來加入。
根據不同的選擇方式,有可能構成相同的結果,故考慮 dp。
定義 dp(i):構造子字串 target[i..N-1] 所需的最小成本。
轉移:dp(i) = max(dp(j + 1) + cost) for ALL target[i..j] in words
base:當 i = N 時,字串構造完畢,回傳 0。
每個 dp(i) 對應 N 個子字串 target[i..j],且每次生成子字串都要 O(N) 時間,因此時間複雜度 O(N^3)。
為了加快字串匹配的速度,先把 words 中所有字串加入字典樹中。
之後枚舉 i 開始的子陣列,每次加入一個新的字元只需要 O(1) 時間匹配。
並且,若樹中不存在 target[i..j] 的前綴,則可以直接剪枝。
時間複雜度 O(N^2 + L),其中 L = sum(words[i].length)。
空間複雜度 O(N + L)。
class Solution:
def minimumCost(self, target: str, words: List[str], costs: List[int]) -> int:
N = len(target)
trie = Trie()
for w, c in zip(words, costs):
trie.add(w, c)
@cache
def dp(i):
if i == N:
return 0
res = inf
curr = trie.root
for j in range(i, N):
c = target[j]
if c not in curr.child:
break
curr = curr.child[c]
res = min(res, curr.val + dp(j + 1))
return res
ans = dp(0)
if ans == inf:
return -1
return ans
class TrieNode:
def __init__(self) -> None:
self.child = defaultdict(TrieNode)
self.val = inf
class Trie:
def __init__(self):
self.root = TrieNode()
def add(self, s, val) -> None:
curr = self.root
for c in s:
curr = curr.child[c]
curr.val = min(curr.val, val)
注意到測資範圍寫著:
The total sum of words[i].length is less than or equal to 5 * 10^4.
總長度 L 限制了 words[i] 所可能出現的長度。
最差情況下 words[i] 長度有 1,2,3,…n,求等差數列和公式 = n *(n + 1) / 2。
滿足 n * (n + 1) / 2 <= L,可得最後一項長度 n 大約為 sqrt(L)。
之前我們枚舉 target[i..j] 時,大概是 O(N) 時間。其實只需要枚舉有出現過的 words[i] 長度,只需要 O(sqrt(L))。
因此 dp 部份的正確時間複雜度是 O(N * sqrt(L))。
但字典樹會一直走到樹的葉節點為止,在 target = “aaa…aaa”, words = [“a”,”aaa……aaa”] 的極端情況下,複雜度同樣會上升到 O(N^2),需要想想其他字串匹配的替代方案。
正確方式是使用 rolling hash 做字串雜湊,預處理所有 words[i] 及 target,即可用 O(1) 的時間查詢子字串是否存在。
時間複雜度 O(N * sqrt(L)),其中 L = sum(words[i].length)。
空間複雜度 O(N)。
MOD = 1000015279
BASE = 87
class Solution:
def minimumCost(self, target: str, words: List[str], costs: List[int]) -> int:
N = len(target)
# hash target for substrings
h_target = [0]
base_pow = [1]
for c in target:
h_target.append((h_target[-1] * BASE + ord(c)) % MOD)
base_pow.append((base_pow[-1] * BASE) % MOD)
# hash all words
# and collect valid sizes
h_words = defaultdict(lambda: inf)
sizes = set()
for word, cost in zip(words, costs):
sizes.add(len(word))
h = 0
for c in word:
h = (h * BASE + ord(c)) % MOD
h_words[h] = min(h_words[h], cost)
sizes = sorted(sizes)
@cache
def dp(i):
if i == N:
return 0
res = inf
for sz in sizes:
j = i + sz - 1
if j >= N:
break
h_sub = (h_target[j + 1] - h_target[i] * base_pow[sz]) % MOD
if h_sub in h_words:
res = min(res, h_words[h_sub] + dp(j + 1))
return res
ans = dp(0)
if ans == inf:
return -1
return ans
然而很可惜,因為 python 嚴格的時間限制,光是這樣會超時。
還得把 min 改成手寫 if 判斷,再把 dp 改成遞推才能勉強通過。
MOD = 1000015279
BASE = 87
class Solution:
def minimumCost(self, target: str, words: List[str], costs: List[int]) -> int:
N = len(target)
# hash target for substrings
h_target = [0]
base_pow = [1]
for c in target:
h_target.append((h_target[-1] * BASE + ord(c)) % MOD)
base_pow.append((base_pow[-1] * BASE) % MOD)
# hash all words
# and collect valid sizes
h_words = defaultdict(lambda: inf)
sizes = set()
for word, cost in zip(words, costs):
sizes.add(len(word))
h = 0
for c in word:
h = (h * BASE + ord(c)) % MOD
# h_words[h] = min(h_words[h], cost)
if cost < h_words[h]:
h_words[h] = cost
sizes = sorted(sizes)
dp = [0] * (N + 1)
for i in reversed(range(N)):
res = inf
for sz in sizes:
j = i + sz - 1
if j >= N:
break
h_sub = (h_target[j + 1] - h_target[i] * base_pow[sz]) % MOD
if h_sub in h_words:
# res = min(res, h_words[h_sub] + dp(j + 1))
t = h_words[h_sub] + dp[j + 1]
if t < res:
res = t
dp[i] = res
ans = dp[0]
if ans == inf:
return -1
return ans