周賽366。完全沒想到是dp,而且竟然有三種dp作法,真的是好題。

題目

輸入兩個二進位字串s1和s2,長度都是n,還有一個正整數x。

你可以對s1執行以下操作任意次:

  • 選擇兩個索引i和j,並反轉s1[i]和s1[j]。操作成本為x
  • 選擇小於n-1的索引i,並反轉s1[i]和s1[i+1]。操作成本為1

求使得s1和s2相等的最小成本,若不可能相等則回傳-1。

注意:反轉指的是將0變成1,或是1變成0。

解法

我們只要考慮s1[i]和s2[i]不相同的索引i,原本相同的話再去反轉也沒意義。

選擇索引i, j,操作後的變化只有四種情形:00變11、11變00、01變10、10變01。
字串中1的的數量一定是偶數增減,若不同的索引有奇數個,則不存在答案。

試想以下例子:

s1 = 100, s2 = 001
如果使用第一種操作,直接反轉s1[0]和s1[2],成本x
也可以用第二種操作,先反轉s1[0]和s1[1],s1 = 010
然後再反轉s1[1]和s1[2],s1 = 001,成本2

發現想要同時反轉i, j兩個索引,可以選擇成本較低的方法,也就是min(x, j-i)。
那要怎樣配對才能保證成本最小?

如果循序倆倆配對,那麼碰到s1 = 10011001, s2 = 00000000 這種情形就會算錯。
試著往dp去考慮。

定義dp(l,r):需要反轉的索引陣列idx中,反轉子陣列idx[l, r]的最小成本。
枚舉與idx[l]配對的索引idx[i],並計算反轉idx[l+r, l]和idx[l+1, r]兩個子字串。
轉移方程式:dp(l,r) = min( cost(idx[l],idx[i]) + dp(l+1, i-1) + dp(i+1, r) ) FOR ALL l<i<r
base cases:若l>r代表反轉結束,回傳0;若r-i+1為奇數,則不可能有答案,回傳inf。

狀態共有N^2個,每個狀態需要轉移N次。
時間複雜度O(N^3)。
空間複雜度O(N^2)。

class Solution:
    def minOperations(self, s1: str, s2: str, x: int) -> int:
        idx=[]
        for i,c in enumerate(s1):
            if c!=s2[i]:
                idx.append(i)
        
        @cache
        def dp(l,r):
            if (l-r+1)%2==1:
                return inf
            if l>r:
                return 0
            res=inf
            for i in range(l+1,r+1):
                cost=min(x,idx[i]-idx[l])
                res=min(res,cost+dp(l+1,i-1)+dp(i+1,r))
            return res
        
        ans=dp(0,len(idx)-1)
        
        if ans==inf:
            return -1
        
        return ans

按照最初說的,反轉的個數必須是偶數,乾脆在一開始就過濾。
dp轉移時也只枚舉奇偶性不同的索引,省略一堆無效的計算。

複雜度不變,但是常數小很多。

class Solution:
    def minOperations(self, s1: str, s2: str, x: int) -> int:
        idx=[i for i,c in enumerate(s1) if c!=s2[i]]
        
        if len(idx)%2==1:
            return -1
        
        @cache
        def dp(l,r):
            if l>r:
                return 0
            res=inf
            for i in range(l+1,r+1,2):
                cost=min(x,idx[i]-idx[l])
                res=min(res,cost+dp(l+1,i-1)+dp(i+1,r))
            return res
        
        return dp(0,len(idx)-1)

另種思路是把第一種操作拆成兩半

定義dp(i):需要反轉的索引陣列idx中,反轉前i個索引的最小成本。
可以選擇idx[i]要使用第一種操作,配對的idx[j]之後再找;或是idx[i]和idx[i-1]一起使用第二種操作。
轉移方程式:dp(i) = min( dp(i-1) + x/2, dp(i-2) + idx[i]-idx[i-1] )
base case:當i等於-1,代表反轉完成,回傳0。

注意:只有在剩下至少兩個索引的時候才能使用第二種操作,而且先前也過濾掉奇數情況,保證索引一定會成對。
然後x可能是奇數,所以dp結果必須是浮點數,最後才轉回整數。

時間複雜度O(N)。
空間複雜度O(N)。

class Solution:
    def minOperations(self, s1: str, s2: str, x: int) -> int:
        idx=[i for i,c in enumerate(s1) if c!=s2[i]]
        
        if len(idx)%2==1:
            return -1
        
        @cache
        def dp(i):
            if i<0:
                return 0
            res=dp(i-1)+x/2
            if i>0:
                res=min(res,dp(i-2)+idx[i]-idx[i-1])
            return res
        
        return int(dp(len(idx)-1))

如果怕浮點數誤差,可以把成本都先設成兩倍,最後回傳答案時一次除回來。

class Solution:
    def minOperations(self, s1: str, s2: str, x: int) -> int:
        idx=[i for i,c in enumerate(s1) if c!=s2[i]]
        
        if len(idx)%2==1:
            return -1
        
        @cache
        def dp(i):
            if i<0:
                return 0
            res=dp(i-1)+x
            if i>0:
                cost=idx[i]-idx[i-1]
                res=min(res,dp(i-2)+cost*2)
            return res
        
        return dp(len(idx)-1)//2

改寫成遞推版本。

class Solution:
    def minOperations(self, s1: str, s2: str, x: int) -> int:
        idx=[i for i,c in enumerate(s1) if c!=s2[i]]
        N=len(idx)
        
        if N%2==1:
            return -1

        dp=[0]*(N+1)
        for i in range(N):
            dp[i+1]=dp[i]+x
            if i>0:
                cost=idx[i]-idx[i-1]
                dp[i+1]=min(dp[i+1],dp[i-1]+cost*2)
                
        return dp[N]//2

dp[i]只會參考到dp[i-1]和dp[i-2]這前兩項的結果。
可以使用滾動陣列,節省空間。

注意:需要特判s1等於s2的情形。

時間複雜度O(N)。
空間複雜度O(1)。

class Solution:
    def minOperations(self, s1: str, s2: str, x: int) -> int:
        idx=[i for i,c in enumerate(s1) if c!=s2[i]]
        
        N=len(idx)
        if N%2==1:
            return -1
        
        if N==0: # s1==s2
            return 0
        
        pprev=0
        prev=x
        for i in range(1,N):
            cost=idx[i]-idx[i-1]
            t=prev
            prev=min(prev+x,pprev+cost*2)
            pprev=t
        
        return prev//2

其實還有一種O(N^2)的dp,但我覺得比起上面兩種更難想到。
判定當前索引i是否被反轉過,更類似於dfs的暴力搜索。

定義dp(i,free,rev):剩下free次免費反轉機會,使前i個字元相等的最小反轉成本。rev代表s1[i]是否被反轉過。

轉移方程式:
如果s1[i]和s2[i]原本就相同、或經過反轉後相同,則不需多餘操作,直接考慮下一個字元

  • dp(i,free,rev) = dp(i-1,free,False)

使用第一種操作,先計算費用,並記錄一次反轉的機會

  • dp(i,free,rev) = dp(i-1,free+1,False) + x

如果當前有免費機會也可以使用

  • dp(i,free,rev) = dp(i-1,free-1,False) if free>0

否則使用第二種操作,把s1[i]和s1[i-1]一起反轉

  • dp(i,free,rev) = dp(i-1,free,True) + 1

base cases:當i小於0時,所有字元都處理完,這時free應為0、rev也為false,合法則回傳0;否則因操作1需要成對操作,剩餘奇數free代表沒有答案,非零0的偶數代表使用太多的操作次數,兩種都不會是答案,回傳inf。

時間複雜度O(N^2)。
空間複雜度O(N^2)。

class Solution:
    def minOperations(self, s1: str, s2: str, x: int) -> int:
        N=len(s1)
        if s1.count("1")%2!=s2.count("1")%2:
            return -1
        
        @cache
        def dp(i,free,rev):
            if i<0 and free==0 and not rev:
                return 0
            if i<0:
                return inf
            if (s1[i]==s2[i]) == (not rev): # no need flip
                return dp(i-1,free,False)
            res=dp(i-1,free+1,False)+x # op1 
            res=min(res,dp(i-1,free,True)+1) # op2
            if free>0:
                res=min(res,dp(i-1,free-1,False))
            return res
        
        return dp(N-1,0,False)