[bzoj3238] [AHOI2013]差异

Description

img

Input

一行,一个字符串S

Output

一行,一个整数,表示所求值

Sample Input

cacao

Sample Output

54

Solution

把式子拆成两部分处理,第一部分就是:
\[ \sum_{i=1}^{n}\sum_{j=i+1}^{n}i+j=\frac{(n-1)n(n+1)}{2} \]
直接算就好了。

对于后面一部分,也就是要求任意两个后缀的\(lcp\)的和。

把串翻转,那么就变成了求任意两个前缀的最长公共后缀之和。

那么对翻转后的串建后缀自动机,对于每个前缀在后缀自动机上都是不同的点,显然有一个这样的性质:

  • 对于两个前缀\(T_1,T_2\),设他们代表的点为\(u,v\),那么最长公共后缀就是\(parent\)树上的\(lca(u,v)\)\(maxl\),也就是最大可扩展的长度。

考虑下\(parent\)树的性质就能得到这个东西。

然后问题就变成了对于每个点问子树内\(lca\)为它的点对有多少个,然后乘上这个点的\(maxl\)就是\(lcp\)的和。

这就是一个很基本的问题,随便搞搞就好了。

#include<bits/stdc++.h>
using namespace std;
 
void read(int &x) {
    x=0;int f=1;char ch=getchar();
    for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
    for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
}
 
#define ll long long

void print(ll x) {
    if(x<0) putchar('-'),x=-x;
    if(!x) return ;print(x/10),putchar(x%10+48);
}
void write(ll x) {if(!x) putchar('0');else print(x);putchar('\n');}

const int maxn = 1e6+10;

ll ans;
int tr[maxn][27],fa[maxn],ml[maxn],lstp=1,qs=1,cnt=1,sz[maxn];

void append(int c) {
    int p=lstp,np=++cnt;ml[np]=ml[p]+1;sz[np]=1;lstp=np;
    for(;p&&tr[p][c]==0;p=fa[p]) tr[p][c]=np;
    if(p==0) return fa[np]=qs,void();
    int q=tr[p][c];
    if(ml[p]+1<ml[q]) {
        int nq=++cnt;ml[nq]=ml[p]+1;
        memcpy(tr[nq],tr[q],sizeof tr[nq]);
        fa[nq]=fa[q],fa[q]=fa[np]=nq;
        for(;p&&tr[p][c]==q;p=fa[p]) tr[p][c]=nq;
    }else fa[np]=q;
}

struct edge{int to,nxt;}e[maxn<<1];
int head[maxn],tot;

void add(int u,int v) {e[++tot]=(edge){v,head[u]},head[u]=tot;}
void ins(int u,int v) {add(u,v),add(v,u);}

void dfs(int x,int f) {
    ll res=0,t=sz[x];
    for(int i=head[x];i;i=e[i].nxt)
        if(e[i].to!=f) dfs(e[i].to,x),sz[x]+=sz[e[i].to];
    for(int i=head[x];i;i=e[i].nxt)
        if(e[i].to!=f) res+=1ll*sz[e[i].to]*(sz[x]-sz[e[i].to]-t);
    res>>=1ll,res+=t*(sz[x]-1),ans-=1ll*res*ml[x]*2;
}

char s[maxn];

int main() {
    scanf("%s",s+1);int n=strlen(s+1);
    for(int i=1;i<=n>>1;i++) swap(s[i],s[n-i+1]);
    ans=1ll*(n-1)*n*(n+1)>>1ll;
    for(int i=1;i<=n;i++) append(s[i]-'a'+1);
    for(int i=2;i<=cnt;i++) if(fa[i]) ins(i,fa[i]);
    dfs(1,0);write(ans);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/hbyer/p/10433065.html
今日推荐