[BZOJ3072]-[Pa2012]Two Cakes-dp有效状态+记搜

说在前面

为什么这么热啊!!!简直受不了啊QAQ
舍不得开空调,感觉浪费电,又没有风扇,要死……


题目

BZOJ3072传送门

题目大意

给出两个长度为 n 排列,现在需要把这两个排列按顺序抄一遍
你可以左手右手分别写一个排列,但是同一时刻左右手写的数字不能相同(如果相同了,就只能先写其中一个)
每写一个数字消耗 1 单位时间,询问最快多久可以写完
范围: n 10 6

输入输出格式

输入格式:
第一行一个整数 n
接下来 n 个整数表示第一个排列
再接下来 n 个整数表示第二个排列

输出格式:
输出一行一个数字,表示答案


解法

感觉还是比较有意思的
首先 n 2 d p 是显然的: d p [ i ] [ j ] m i n ( d p [ i 1 ] [ j ] , d p [ i ] [ j 1 ] ) + 1
如果 a i b j ,那么 d p [ i ] [ j ] d p [ i 1 ] [ j 1 ] + 1

考虑对这个 d p 进行简化
如果遇见相同的数字,肯定有一个排列延后一步,这个延后我们完全没有必要提前进行(也就是只有恰好 a i = b j 的时候才延后),可以证明是不影响答案的

那么,如果 a i = b j ,那么就像上面的第一个式子那样转移
不然就 d p [ i ] [ j ] d p [ i t ] [ j t ] + t ,其中 a i t + k b j t + k , 1 k t a i t = b j t
可以发现有用的状态就只有 n 个,也就是当 a i = b j 时的 d p [ i ] [ j ]

那么我们可以记忆化有用状态,对于其它的状态我们只需要快速得知 t 是多少就可以了
那么如何快速计算 t 呢?我们可以采用二分答案!我们可以找出所有的「在两个序列中位置差为 i j 的数字」,然后按照在 a 中出现的位置排序。因为 a i t = b j t ,那么 i t 也就是在「位置差为 i j 的数组」中,第一个小于 i 的数字

对于每个有用状态,需要用两个转移去计算,每个转移需要 log 2 n 的时间去二分 t ,所以总时间复杂度 Θ ( n log 2 n )
然后这道题就做完了


下面是自带大常数的代码

#include <vector>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std ;

int N , a[1000005] , b[1000005] , arcB[1000005] , dp[1000005] ;
vector<int> dif[2000005] ;

void read_( int &x ){
    x = 0 ; char ch = getchar() ;
    while( ch < '0' || ch > '9' ) ch = getchar() ;
    while( ch >='0' && ch <='9' ) x = (x<<1) + (x<<3) + ch - '0' , ch = getchar() ;
}

inline int getPre( int dt , int pos ){
    int lf = 0 , rg = dif[dt].size() - 1 , rt = 0 ;
    while( lf <= rg ){
        int mid = ( lf + rg ) >> 1 ;
        if( dif[dt][mid] <= pos ) rt = dif[dt][mid] , lf = mid + 1 ;
        else rg = mid - 1 ;
    } return rt ;
}

int dfs( int i , int j ){
    if( !i || !j ) return i|j ;
    if( a[i] == b[j] ){
        if( !dp[i] ) dp[i] = min( dfs( i - 1 , j ) , dfs( i , j - 1 ) ) + 1 ;
        return dp[i] ;
    } int pre = getPre( i - j + N , i ) ;
    return pre ? dfs( pre , j - i + pre ) + i - pre : max( i , j ) ;
}

int main(){
    scanf( "%d" , &N ) ;
    for( int i = 1 ; i <= N ; i ++ ) read_( a[i] ) ;
    for( int i = 1 ; i <= N ; i ++ )
        read_( b[i] ) , arcB[ b[i] ] = i ;
    for( int i = 1 ; i <= N ; i ++ )
        dif[ i - arcB[a[i]] + N ].push_back( i ) ;
    printf( "%d" , dfs( N , N ) ) ;
}

猜你喜欢

转载自blog.csdn.net/izumi_hanako/article/details/80339273