20181229模拟 T1 palindrome

20181229模拟 T1 palindrome

题意 :

\(S\)是字符串\(s\)的子串可重集,求\(\sum\limits_{x\in S}\sum\limits_{y\in S}(|x|+|y|)\times [xy\ is \ palidrome]mod\ 2013265921\)

分析:

\(2013265921\)的原根是\(31\),所以这道题我使用后缀自动机+回文树来解决。

注意到一个由两个字符串所组成的回文串\(xy\),不妨设\(|x|<|y|\)\(y\)显然是由一个回文串和一个\(x\)的反串组成。

于是我们可以枚举回文串的结尾\(i\),显然向右能被反串匹配的是一段区间,向左能匹配的所有回文串就是一直跳回文树上\(fail\)能到达的那些结点,求出此时向右匹配反串的种类\(c1\)和总长度\(s1\)向左匹配回文串的个数\(c2\)和总长度\(s2\),那么答案就是$\sum\limits_{i=1}^{n-1}c1_{i+1}\times s2_i+c2_i\times s1_{i+1} $。 然后这两个用后缀自动机+回文树即可完美解决。

然后不要忘记处理回文串在右反串在左的情况,我的做法是将整个字符串反过来重新求一遍。

最后需要加上回文串长度为\(0\)的方案。

code:

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
using namespace std;
#define N 1000050
#define mem(x) memset(x,0,sizeof(x))
#define mod 2013265921
#define db(x) cerr<<#x<<" = "<<x<<endl
typedef long long ll;
int n;
char w[N];
ll s1[N],c1[N],c2[N],s2[N];
ll ss(ll l,ll r) {
    return (l+r)*(r-l+1)/2%mod;
}
struct Sam {
    int ch[N][26],fa[N],len[N],lst,cnt,ke[N],ro[N],siz[N];
    ll sum[N],sd[N];
    void init() {
        lst=cnt=1;
    }
    void insert(int x) {
        int p=lst,np=++cnt,q,nq; lst=np;
        len[np]=len[p]+1;
        for(;p&&!ch[p][x];p=fa[p]) ch[p][x]=np;
        if(!p) fa[np]=1;
        else {
            q=ch[p][x];
            if(len[q]==len[p]+1) fa[np]=q;
            else {
                nq=++cnt; len[nq]=len[p]+1;
                memcpy(ch[nq],ch[q],sizeof(ch[q]));
                fa[nq]=fa[q]; fa[q]=fa[np]=nq;
                for(;p&&ch[p][x]==q;p=fa[p]) ch[p][x]=nq;
            }
        }
        siz[lst]++;
    }
    void kero() {
        int i;
        for(i=1;i<=cnt;i++) ke[len[i]]++;
        for(i=1;i<=cnt;i++) ke[i]+=ke[i-1];
        for(i=cnt;i;i--) ro[ke[len[i]]--]=i;
        for(i=cnt;i>=1;i--) siz[fa[ro[i]]]+=siz[ro[i]];
        for(i=2;i<=cnt;i++) {
            int p=ro[i];
            sd[p]=(sd[fa[p]]+siz[p]*(len[p]-len[fa[p]]))%mod;
            sum[p]=(sum[fa[p]]+siz[p]*ss(len[fa[p]]+1,len[p]))%mod;
        }
    }
    void pipei() {
        int p=1,now=0,i;
        for(i=n;i;i--) {
            int x=w[i];
            if(ch[p][x]) {
                p=ch[p][x]; now++;
            }else {
                for(;p&&!ch[p][x];p=fa[p]) ;
                if(!p) {
                    p=1; now=0;
                }else {
                    now=len[p]+1; p=ch[p][x];
                }
            }
            s1[i]=(sum[fa[p]]+siz[p]*ss(len[fa[p]]+1,now))%mod,c1[i]=(siz[p]*(now-len[fa[p]])+sd[fa[p]])%mod;
        }
    }
    void clear() {
        mem(ch);mem(fa);mem(len);mem(ke);mem(ro);mem(siz);mem(sum);mem(sd);
        init();
    }
}sam;
struct Pam {
    int ch[N>>1][26],fail[N],len[N],cnt,lst,dep[N];
    ll sum[N];
    void init() {
        len[1]=-1; fail[0]=fail[1]=1; cnt=1; lst=0;
    }
    void insert(int i,int x) {
        int p=lst,np;
        for(;w[i-len[p]-1]!=x;p=fail[p]) ;
        if(!ch[p][x]) {
            np=++cnt;
            len[np]=len[p]+2;
            int q=fail[p];
            for(;w[i-len[q]-1]!=x;q=fail[q]) ;
            fail[np]=ch[q][x];
            ch[p][x]=np;
            
            dep[np]=dep[fail[np]]+1;
            sum[np]=(sum[fail[np]]+len[np])%mod;
        }
        lst=ch[p][x];
    }
    void wk() {
        int i;
        for(i=1;i<=n;i++) {
            insert(i,w[i]);
            c2[i]=dep[lst];
            s2[i]=sum[lst];
        }
    }
    void clear() {
        mem(ch);mem(fail);mem(len);mem(dep);mem(sum);
        init();
    }
}pam;
int main() {
    scanf("%s",w+1); n=strlen(w+1);
    w[0]=29;
    int i;
    ll ans=0;
    for(i=1;i<=n;i++) w[i]-='a';
    sam.init(); pam.init();
    for(i=1;i<=n;i++) sam.insert(w[i]);
    sam.kero();
    sam.pipei();
    pam.wk();
    for(i=1;i<=n;i++) s1[i]*=2;
    for(i=1;i<=n;i++) {
        ans+=(s1[i]*c2[i-1]+c1[i]*s2[i-1])%mod;
    }
    sam.clear();
    pam.clear();
    mem(s1);mem(s2);mem(c1);mem(c2);
    
    reverse(w+1,w+n+1);
    for(i=1;i<=n;i++) sam.insert(w[i]);
    sam.kero();
    sam.pipei();
    pam.wk();
    for(i=1;i<=n;i++) s1[i]*=2;
    for(i=1;i<=n;i++) {
        ans+=(s1[i]*(c2[i-1]+1)+c1[i]*s2[i-1])%mod;
    }
    printf("%lld\n",(ans+mod)%mod);
}

猜你喜欢

转载自www.cnblogs.com/suika/p/10240431.html