LeetCode 2913. Subarrays Distinct Element Sum of Squares I
雙周賽116。既是Q1又是Q4,測資範圍不同,難度大概差了二十倍。
題目
輸入整數陣列nums。
nums子陣列的不同計數定義為:
- 令nums[i..j]為nums的子陣列,其中包含介於[i, j]之間的所有索引對應的元素
- nums[i..j]的不同計數等於nums[i..j]中不同值的數量
求所有子陣列不同計數的平方和。
答案可能很大,先模10^9+7後回傳。
解法
暴力枚舉所有子陣列,集合去重後得到不同計數,平方後加入答案。
時間複雜度O(N^3)。
空間複雜度O(N)。
class Solution:
def sumCounts(self, nums: List[int]) -> int:
N=len(nums)
ans=0
for i in range(N):
for j in range(i,N):
sub=nums[i:j+1]
s=set(sub)
ans+=len(s)**2
return ans
Q4的nums長度高達10^5,大概得找個小於O(N^2)的方法才行。
以下簡稱不同計數為dis。
先試著枚舉nums中的每個索引i,並找出以i為右邊界的所有子陣列sub的dis值,觀察其變化規律:
nums = [1,3,1]
i = 0
sub = [[1]], dis = [1]
i = 1
sub = [[1,3],[3]], dis = [2,1]
i = 2
sub = [[1,3,1],[3,1],[1]], dis = [2,2,1]
答案正是這6個dis值的平方總合,共15。
其實有點像是2262. total appeal of a string這題。
觀察發現,每次加入元素x後,所有不包含x的的子陣列的dis值都會加1。那麼有哪些個子陣列符合?
若x上次出現的索引為j,則左邊界大於等於j+1子陣列都符合,共i-(j+1)+1 = i-j個。
好吧,到目前為止我們知道dis值的變化規律。但本題求的是dis值的平方和啊!!
延續上例:
i = 1
dis1 = 2^2 + 1^2 i = 2
dis2 = 2^2 + 2^2 + 1^2
將兩式子的總變化量記為delta:
delta = dis2 - dis1
delta = (2^2 - 1^2) + (1^2 - 0^2)
每個有變化的dis值d,其變化量為(d+1)^2 - d^2。
d變化量 = (d+1)^2 - d^2
展開 = (d^2 + 2d + 1) - d^2
相消 = 2d + 1
也就是從左邊界大於等於j+1的所有子陣列,其dis平方值會增加2d+1。
繼續用剛才的例子來加幾個元素驗證看看:
nums = [1,3,1]
i = 2
sub = [[1,3,1],[3,1],[1]], dis = [2,2,1]
dis平方總和 = 4+4+1 = 9
加入沒出現過的新元素2,所有子陣列的dis值都增加1,每個平方值都會增加2d+1。
則新的dis平方值應為[4+(22+1),4+(22+1),1+(12+1),0+(02+1)] = [9,9,4,1]。
列出子陣列看看:
nums = [1,3,1,2]
i = 3
sub = [[1,3,1,2],[3,1,2],[1,2],[2]], dis = [3,3,2,1]
dis平方總合 = 9+9+4+1 = 23
到目前為止,方法簡化成:
- 遍歷每個nums[i]=x
- 找到x上次出現的位置j
- 平方總合 += sum( dis[idx]*2+1 FOR ALL j<idx<=i)
- dis[idx] += 1 FOR ALL j<idx<=i
但這樣還只是O(N^2),還要繼續優化。
假設有j-i=cnt個符合的索引。
可以把dis[idx]*2+1拆成兩部分,+1的部分就是一開始用到的j-i。
dis[idx]加總則交給線段樹來做區間查詢。
時間複雜度O(N log N)。
空間複雜度O(N)。
class Solution:
def sumCounts(self, nums: List[int]) -> int:
# 區間查詢
# 回傳[i, j]的總和
def query(id, L, R, i, j):
if i <= L and R <= j: # 當前區間被目標範圍包含
return tree[id]
push_down(id, L, R)
ans = 0
M = (L+R)//2
if i <= M:
ans += query(id*2, L, M, i, j)
if M+1 <= j:
ans += query(id*2+1, M+1, R, i, j)
return ans
# 區間更新
# 對[i, j]每個索引都增加val
def update(id, L, R, i, j, val):
if i <= L and R <= j: # 當前區間被目標範圍包含
tree[id] += (R-L+1)*val
lazy[id] += val # 標記每個位置都加val
return
push_down(id, L, R)
M = (L+R)//2
if i <= M:
update(id*2, L, M, i, j, val)
if M+1 <= j:
update(id*2+1, M+1, R, i, j, val)
push_up(id)
# 將區間懶標加到答案中
# 下推懶標記給左右子樹
def push_down(id, L, R):
M = (L+R)//2
if lazy[id]:
lazy[id*2] += lazy[id]
tree[id*2] += lazy[id]*(M-L+1)
lazy[id*2+1] += lazy[id]
tree[id*2+1] += lazy[id]*(R-(M+1)+1)
lazy[id] = 0
# 以左右子樹更新答案
def push_up(id):
tree[id] = tree[id*2]+tree[id*2+1]
MOD=10**9+7
N = len(nums)
tree = [0]*(N*4)
lazy = [0]*(N*4)
last={}
ans=0
sub=0
for i,x in enumerate(nums):
j=last[x] if x in last else -1 # last position we saw "x"
# dis[j+1, i] will be increase
# delta = (d+1)^2 - d^2 = 2d + 1
sub+=i-j # (i-j)*1
sub+=query(1,0,N-1,j+1,i)*2 # sum(dis[j+1, i])*2
ans=(ans+sub)%MOD
update(1,0,N-1,j+1,i,1) # dis[j+1, i] increased by 1
last[x]=i
return ans