LeetCode 3007. Maximum Number That Sum of the Prices Is Less Than or Equal to K
周賽380。
題目
輸入整數 k 和 x。
s 是一個整數 num 的二進位表示,索引從 1 開始。 num 的價格是滿足 i % x == 0 且 s[i] 是 1 的個數。
求能夠滿足 1~num 之間所有數的價格總和小於等於 k 的最大 num 值。
注意:
- 二進位是從右向左數。例如 s = 11000, s[4] 是 1,而 s[2] 是 0
解法
題目要求 1~num 的總價小於等於 k。
價格不為負數,總價其實具有單調性,會隨著 num 單調遞增。
若 1~x 的總價超過 k,則答案不可能是 x 以上的值;反之,若 1~x 總價不足 k ,則答案至少會是 x。
因此我們可以透過二分答案來找到適當的 num。
現在的問題變成:怎麼求 1~num 之間能被 x 整除的 1 位元有幾個?
看到這個 1~num 之間,就想到我們的老朋友數位dp,基本上就和 233. number of digit one 作法差不多,只是要計算一下當前是第幾個 bit。
計算複雜度之前,還有另一個問題:num 的可能上界為多少?
我也不知道怎麼算,但是測出來大概是 10^15,跟 k 的上限差不多。大概需要二分 O(log k)次,每次二分都要做一次數位dp。
根據 num 的最大值 k,其二進位長度同為 O(log k)。
同時,一個二進位表示中也最多擁有 O(log k) 個 1。
每次數位dp 總共有 O(log k)^2 個狀態,每個狀態轉移一次。
時間複雜度 O((log k)^3)。
空間複雜度 O((log k)^2)。
class Solution:
def findMaximumNumber(self, k: int, x: int) -> int:
def ok(num):
s = bin(num)[2:]
N = len(s)
@cache
def dp(i, is_limit, cnt1):
if i == N:
return cnt1
bit = N - 1 - i + 1
res = 0
down = 0
up = 1 if not is_limit else int(s[i])
for j in range(down, up + 1):
new_cnt1 = cnt1 + 1 if bit % x == 0 and j == 1 else cnt1
new_limit = is_limit and j == up
res += dp(i + 1, new_limit, new_cnt1)
return res
return dp(0, True, 0) <= k
lo = 1
hi = 10**15
while lo < hi:
mid = (lo + hi + 1) // 2
if not ok(mid):
hi = mid - 1
else:
lo = mid
return lo
這種位元運算類型的題目,通常可以把每個位元分開處理計算,稱作拆位。
先照著範例2 的答案,列出 1~9 的二進位看看:
| num | binary |
| — | —— |
| 0 | 0000 |
| 1 | 0001 |
| 2 | 0010 |
| 3 | 0011 |
| 4 | 0100 |
| 5 | 0101 |
| 6 | 0110 |
| 7 | 0111 |
| 8 | 1000 |
| 9 | 1001 |
發現這些數的第 1 個位元,是由 01 01.. 的規律循環組成。
第 2 個位元是 0011 0011..;第 3 個位元是 000011111..。
可得結論:第 i 個位元由 2^i 個數為一次循環。其中,前半 2^(i-1) 個位元都是 0,後半 2^(i-1) 個位元都是 1。
注意:循環是從 num = 0 開始,總共有 num + 1 個數。
將這 num + 1 個數字分組,每組有 2^i 個數,看能循環幾次。每次循環會貢獻 2^(i-1) 個 1位元。
至於剩下沒分到組的數字,忽略掉前半的 0位元,只取後半的 1位元。
最後判斷 1位元的總數是否小於等於 k。
時間複雜度 O((log k)^2)。
空間複雜度 O(1)。
class Solution:
def findMaximumNumber(self, k: int, x: int) -> int:
def ok(num):
cnt1 = 0
for i in range(1, num.bit_length() + 1):
if i % x != 0:
continue
# "rep" elements for a repetition
# first half are 0s
# and last half are 1s
rep = 1 << i
# [0 ~ num] are "num + 1" elements, group them of "rep"
# there are "rep_cnt" full repetitions and "remain" elements alone
rep_cnt, remain = divmod(num + 1, rep)
cnt1 += rep_cnt * (rep // 2) # full rep
cnt1 += max(0, remain - (rep // 2)) # only last half are 1s
return cnt1 <= k
lo = 1
hi = 10**15
while lo < hi:
mid = (lo + hi + 1) // 2
if not ok(mid):
hi = mid - 1
else:
lo = mid
return lo