LeetCode 3041. Maximize Consecutive Elements in an Array After Modification
雙周賽124。根本沒想到又是 dp,想著二分罰坐一小時。
題目
輸入正整數陣列 nums。
最初,你可以選擇陣列中任意個元素,並將其值加 1。
修改後,你必須選擇一或多個元素,這些元素排序後,是滿足相鄰遞增的。
例如 [3, 4, 5] 滿足,但 [3, 4, 6] 和 [1, 1, 2, 3] 不滿足。
求最多可以選擇幾個元素。
解法
在陣列中選擇任意元素後排列,相當於選擇子序列。
子序列順序不影響答案,而且檢查相鄰遞增也要維持有序。總之先把 nums 排序。
對於 nums 中的每個元素 x,有兩種使用方案:
- 加 1,把 x+1 連接在 x 結尾的子序列上
- 保持不變,把 x 連接在 x-1 結尾的子序列上
我們在乎的是以某元素 x 結尾的最大長度,因此是以值域作為 dp 的狀態,而非普遍的使用的元素索引。
另外還有一個不同點的,大多數 dp 都是從多個來源轉移到一個狀態,稱為填表法;這次是用一個值 x 去更新多個狀態,稱為刷表法。
定義 dp(i):以 i 作為結尾的最大子序列長度。
轉移:dp[i+1] = dp[i];dp[i] = dp[i-1]
需要注意的是,對於元素 x,一定要先更新 dp[x+1] 後才更新 dp[x],否則會得到錯誤的答案。
時間複雜度 O(N log N),瓶頸在於排序。
空間複雜度 O(N)。
class Solution:
def maxSelectedElements(self, nums: List[int]) -> int:
nums.sort()
dp = Counter()
for x in nums:
dp[x+1] = dp[x] + 1
dp[x] = dp[x-1] + 1
return max(dp.values())
其實要用熟悉的填表法也可以做,但是真的要寫一長串。
一樣排序後,試著求出以 nums[i] 結尾的最大長度。
只不過 nums[i] 可以選擇是否加 1,所以需要額外的變數 inc_i 來表示此狀態。
所以實際結尾的元素是 target = nums[i] + inc_i。
我們要在 nums[0..i-1] 之間找到等於 target - 1 元素 nums[j],並把 target 接在後面。
同樣的,nums[j] 也可以增加值,所以對於 nums[j] 和 nums[j] + 1 兩種結尾的子序列都要找。
但相同的 nums[j] 可能存在好幾個,總不可能每個都遍歷。而實際上只要取最後一個就行。
因為陣列是有序的,相同的元素越靠右邊,越有可能使得子序列變長。
實際上相同元素 x 出現第三次之後都沒有用。試想以下例子:
nums = [1,2,2]
nums[0] = 1 結尾 = [1]
nums[0] + 1 = 2 結尾 = [2]
nums[1] + = 2 結尾 = [1,2]
nums[1] + 1 = 3 結尾 = [1,2]
nums[2] + = 2 結尾 = [1,2]
nums[2] + 1 = 3 結尾 = [1,2,3]
相同的元素 x 出現第二次,頂多當作 x+1 來用,把上次 x 結尾的子序列變長一格。
之後在出現根本沒意義。
在有序的陣列中找到特定值,很明顯就是靠二分搜了。
定義dp(i, inc_i):以 nums[i] + inc_i 結尾的最長子序列。
轉移:dp(i, inc_i) = max( dp(j, inc_j) + 1 ) FOR ALL 0 <= j < i AND nums[j] + inc_j + 1 = nums[i] + inc_i
邊界:當 i = 0 時,子序列只有自己一個元素,回傳 1。
時間複雜度 O(N log N)。
空間複雜度 O(N)。
class Solution:
def maxSelectedElements(self, nums: List[int]) -> int:
nums.sort()
N = len(nums)
@cache
def dp(i, inc_i):
if i == 0:
return 1
res = 1
target = nums[i] + inc_i
for inc_j in range(2):
# find last nums[j] + inc_j + 1 < target
# lo = 0
# hi = i - 1
# while lo < hi:
# mid = (lo + hi + 1) // 2
# if nums[mid] + inc_j >= target:
# hi = mid - 1
# else:
# lo = mid
# j = lo
j = bisect_left(nums, target, hi = i, key=lambda x:x+ inc_j) - 1
if nums[j] + inc_j + 1 == target:
res = max(res, dp(j, inc_j) + 1)
return res
ans = 0
for i in range(N):
for inc_i in range(2):
ans = max(ans, dp(i, inc_i))
return ans
改成遞推寫法。
注意:dp[0] 需要特判設成 1,不然就是要檢查二分找出來的索引 j 是否界於合法範圍 [0, i-1],否則會得到錯誤答案。
class Solution:
def maxSelectedElements(self, nums: List[int]) -> int:
nums.sort()
N = len(nums)
dp = [[1, 1] for _ in range(N)]
for i in range(1, N): # dp[0] is base case!!
for inc_i in range(2):
res = 1
target = nums[i] + inc_i
for inc_j in range(2):
j = bisect_left(nums, target, hi = i, key=lambda x:x+ inc_j) - 1
if nums[j] + inc_j + 1 == target:
res = max(res, dp[j][inc_j] + 1)
dp[i][inc_i] = res
ans = 0
for i in range(N):
for inc_i in range(2):
ans = max(ans, dp[i][inc_i])
return ans