题目
题解
这里:把两个串用一个很大的字符连接起来,求一个后缀数组。
考虑怎样暴力的算答案。
在 rank r a n k 数组中从前往后枚举起点,对于每个枚举的起点,都暴力的往后扫,扫的过程中维护一个 height h e i g h t 的最小值。每到一个点的时候,如果这个点跟起点不属于一个串,就将答案加上当前的最小值,这样是O(n2)的考虑这个还能怎么算。可以发现我们是维护 height h e i g h t 的最小值。那么我们可以按照 height h e i g h t 从大到小的顺序扫,这样每次需要用的就是当前的 height h e i g h t 。
扫的过程中用并查集维护一下每个串分别对哪些串有贡献的(也就是 height h e i g h t 数组的贡献)。
用乘法原理算一下当前的 height h e i g h t 会有多少贡献。就是用当前的 height h e i g h t 乘上这个串和上一个串分别对于两个两个不同的原串的乘积的和。
代码
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
#define ll long long
const int maxn=1e6;
const int inf=1e9;
int n,len1,m=200,x[maxn],y[maxn],c[maxn],sa[maxn],rnk[maxn],height[maxn],fa[maxn],id[maxn];
int st[maxn],ed[maxn];
char s[maxn],ch[maxn];
ll ans;
void build_sa()
{
for (int i=0; i<m; i++) c[i]=0;
for (int i=0; i<n; i++) c[x[i]=s[i]]++;
for (int i=1; i<m; i++) c[i]+=c[i-1];
for (int i=n-1; i>=0; i--) sa[--c[x[i]]]=i;
for (int k=1; k<=n; k<<=1)
{
int p=0;
for (int i=n-k; i<n; i++) y[p++]=i;
for (int i=0; i<n; i++) if (sa[i]>=k) y[p++]=sa[i]-k;
for (int i=0; i<m; i++) c[i]=0;
for (int i=0; i<n; i++) c[x[i]]++;
for (int i=1; i<m; i++) c[i]+=c[i-1];
for (int i=n-1; i>=0; i--) sa[--c[x[y[i]]]]=y[i];
swap(x,y);
p=1; x[sa[0]]=0;
for (int i=0; i<n; i++)
x[sa[i]] = y[sa[i-1]]==y[sa[i]] && ((sa[i-1]+k>=n?-1:y[sa[i-1]+k])==(sa[i]+k>=n?-1:y[sa[i]+k])) ?p-1:p++;
if (p>n) break;
m=p;
}
}
void build_height()
{
int k=0;
for (int i=0; i<n; i++) rnk[sa[i]]=i;
for (int i=0; i<n; i++)
{
if (!rnk[i]) continue;
if (k) k--;
int j=sa[rnk[i]-1];
while (i+k<n && j+k<n && s[i+k]==s[j+k]) k++;
height[rnk[i]]=k;
}
}
bool cmp(int x,int y) {
return height[x]>height[y];}
int find(int x)
{
if (fa[x]!=x) return fa[x]=find(fa[x]);
return fa[x];
}
void work(int x)
{
int xx=find(x); int yy=find(x-1);
ans+=(ll)(st[xx]*ed[yy]+st[yy]*ed[xx])* (ll)height[x];
// printf("r1 %d r2 %d :(%d*%d + %d*%d) * %d\n",xx,yy,st[xx],ed[yy],st[yy],ed[xx],height[x]);
st[xx]+=st[yy];
ed[xx]+=ed[yy];
// printf("%d %d\nans:%d\n",st[xx],ed[xx],ans);
fa[yy]=xx;
}
int main()
{
scanf("%s%s",s,ch);
len1=strlen(s); n=len1+strlen(ch); n++;
s[len1]='#';
for (int i=0; i<strlen(ch); i++)
s[len1+1+i]=ch[i];
build_sa();
build_height();
for (int i=0; i<n; i++) id[i]=i;
for (int i=0; i<n; i++) fa[i]=i;
for (int i=0; i<n; i++)
{
st[i]=sa[i]<len1;
ed[i]=st[i]^1;
}
sort(id,id+n,cmp);
for (int i=0; i<n; i++) work(id[i]);
printf("%lld",ans);
return 0;
}