题解:
这道题主要用到的几个性质(具体证明可以看题解):
1.弱双回文串
的某个双回文划分
,满足
是
的最长回文前缀,或者
是
的最长回文后缀。
2.弱双回文串
若有两个弱回文划分,则
为整周期串。
3.弱双回文串
的周期为
,则其有
个不同的弱回文划分。
我们先计算本质不同的双回文串划分,然后减去算重的。
首先考虑如何计算本质不同的双回文串划分,我们先正反建出两颗回文树,然后相当于是有 个点对 ,然后要统计合法的 ,满足 在 的子树中且 在 的子树中,这个可以线段树合并+ lca做到 统计。
然后考虑怎么减去算重的,一个双回文串 被重复统计,则由 得其为整周期串,假设其周期为 ,则会被多算 或 次(分最小循环节是否为回文串讨论)。
跟这道题类似,由Runs Theorem,我们可以 提取出所有的本原平方串,且本质不同的本原平方串只有 个,可以直接暴力+哈希存下来,这个时候我们发现对每个本原平方串只需存他最多被循环多少次就行了。
然后考虑对于每个本原平方串减去其算重的部分,首先如果他循环两次不是双回文串的话,那么这个本原平方串就没用了,具体怎么判断需要用性质 。否则由性质 ,设其最多循环 次,他会被重复计算 次(如果自己本身是回文串,则是 次),这个可以 计算。
可以发现,我们在 的时间复杂度内解决了本题。
#include <bits/stdc++.h>
#include<ext/pb_ds/assoc_container.hpp>
#include<ext/pb_ds/hash_policy.hpp>
using namespace std;
typedef long long LL;
typedef pair <int,int> pii;
const int N=5e5+50, L=21;
int n,lg[N*2];
char ch[N];
LL ans;
struct PAM {
int dep[N],fail[N],son[N][26],len[N],pos[N],last,tot;
vector <int> edge[N];
int dfn[N],st[L][N*2],id[N],ocr[N],fa[L][N],ind,cnt;
PAM() {
fail[0]=1, len[1]=-1, last=tot=1;
}
inline void dfs(int x,int f) {
fa[0][x]=f;
for(int i=1;i<L;i++) {
if(~fa[i-1][x]) fa[i][x]=fa[i-1][fa[i-1][x]];
else fa[i][x]=-1;
}
id[dfn[x]=++ind]=x;
st[0][ocr[x]=++cnt]=ind;
for(auto v:edge[x])
dfs(v,x), st[0][++cnt]=dfn[x];
}
inline int get_fail(int p,int i) {
while(ch[i-len[p]-1]!=ch[i]) p=fail[p];
return p;
}
inline void init() {
for(int i=1;i<=n;i++) {
int c=ch[i]-'a';
int p=get_fail(last,i);
if(!son[p][c]) {
++tot;
len[tot]=len[p]+2;
fail[tot]=son[get_fail(fail[p],i)][c];
dep[tot]=dep[fail[tot]]+1;
son[p][c]=tot;
}
last=son[p][c];
pos[i]=last;
}
edge[1].push_back(0);
for(int i=2;i<=tot;i++)
edge[fail[i]].push_back(i);
for(int i=0;i<L;i++) fa[i][1]=-1;
dfs(1,-1);
for(int i=1;i<L;i++)
for(int j=1;j+(1<<i)-1<=cnt;++j)
st[i][j]=min(st[i-1][j],st[i-1][j+(1<<(i-1))]);
}
inline int lca(int x,int y) {
x=ocr[x], y=ocr[y];
if(x>y) swap(x,y);
int l=lg[y-x+1];
return id[min(st[l][x],st[l][y-(1<<l)+1])];
}
inline int up(int x,int l) {
for(int i=L-1;~i;i--)
if(~fa[i][x] && len[fa[i][x]]>l) x=fa[i][x];
if(!~fa[0][x]) return 0;
return len[fa[0][x]];
}
} pam_ori,pam_rev;
struct SA {
int m,sa[N],a[N],b[N],*rk=a,*sa2=b,c[N],h[L][N];
inline void Rsort() {
for(int i=1;i<=m;i++) c[i]=0;
for(int i=1;i<=n;i++) ++c[rk[i]];
for(int i=1;i<=m;i++) c[i]+=c[i-1];
for(int i=n;i>=1;i--) sa[c[rk[sa2[i]]]--]=sa2[i];
}
inline void init() {
for(int i=1;i<=n;i++) sa2[i]=i, rk[i]=ch[i]-'a'+1;
m=26; Rsort();
for(int w=1,p=0;w<=n;w<<=1) {
for(int i=n-w+1;i<=n;i++) sa2[++p]=i;
for(int i=1;i<=n;i++) if(sa[i]>w) sa2[++p]=sa[i]-w;
Rsort(); swap(rk,sa2); rk[sa[1]]=p=1;
for(int i=2;i<=n;i++)
rk[sa[i]]=(sa2[sa[i]]==sa2[sa[i-1]] && sa2[sa[i]+w]==sa2[sa[i-1]+w]) ? p : ++p;
if(p==n) break; m=p; p=0;
}
for(int i=1,k=0,j;i<=n;h[0][rk[i++]]=k)
for(k?k--:k,j=sa[rk[i]-1];ch[i+k]==ch[j+k];++k);
for(int i=1;i<L;i++)
for(int j=1;j+(1<<i)-1<=n;++j)
h[i][j]=min(h[i-1][j],h[i-1][j+(1<<(i-1))]);
}
inline int lcp(int x,int y) {
x=rk[x], y=rk[y];
if(x>y) swap(x,y);
++x; int l=lg[y-x+1];
return min(h[l][x],h[l][y-(1<<l)+1]);
}
} sa_ori,sa_rev;
const int md1=1e9+7, md2=1e9+9;
struct hval {
int x,y;
hval() {}
hval(int x,int y) : x(x),y(y) {}
friend inline bool operator ==(const hval &a,const hval &b) {return a.x==b.x && a.y==b.y;}
friend inline hval operator +(const hval &a,const hval &b) {return hval((a.x+b.x)%md1,(a.y+b.y)%md2);}
friend inline hval operator -(const hval &a,const hval &b) {return hval((a.x-b.x+md1)%md1,(a.y-b.y+md2)%md2);}
friend inline hval operator *(const hval &a,const hval &b) {return hval((LL)a.x*b.x%md1,(LL)a.y*b.y%md2);}
} pw[N];
const hval base=hval(31,37);
struct HASH {
hval val[N];
inline void init() {
for(int i=1;i<=n;i++)
val[i]=val[i-1]*base+hval(ch[i]-'a'+1,ch[i]-'a'+1);
}
inline LL gv(int l,int r) {
hval t=val[r]-val[l-1]*pw[r-l+1];
return (LL)md2*t.x*t.y;
}
} hs_ori,hs_rev;
inline bool pld(int l,int r) {
int len=(r-l+1)/2;
return hs_ori.gv(l,l+len-1)==hs_rev.gv(n-r+1,n-(r-len+1)+1);
}
inline bool valid(int l,int r) {
int len=pam_ori.up(pam_ori.pos[r],r-l);
if(len && pld(l,r-len)) return true;
len=pam_rev.up(pam_rev.pos[n-l+1],r-l);
if(len && pld(l+len,r)) return true;
return false;
}
namespace ST {
const int M=N*30;
int rt[N],lc[M],rc[M],mn[M],mx[M],tot; LL s[M];
vector <int> qry[N];
inline void upt(int k) {
if(!lc[k] || !rc[k]) {
int t=lc[k]+rc[k];
mn[k]=mn[t];
mx[k]=mx[t];
s[k]=s[t];
} else {
mn[k]=mn[lc[k]];
mx[k]=mx[rc[k]];
s[k]=s[lc[k]]+s[rc[k]]-pam_rev.dep[pam_rev.lca(pam_rev.id[mx[lc[k]]],pam_rev.id[mn[rc[k]]])];
}
}
inline void ins(int &k,int l,int r,int p) {
if(!k) k=++tot;
if(l==r) {
s[k]=pam_rev.dep[pam_rev.id[p]];
mn[k]=mx[k]=p;
return;
} int mid=(l+r)>>1;
(p<=mid) ? ins(lc[k],l,mid,p) : ins(rc[k],mid+1,r,p);
upt(k);
}
inline void merge(int &x,int y,int l,int r) {
if(!x) {x=y; return;}
if(!y) return;
if(l==r) return;
int mid=(l+r)>>1;
merge(lc[x],lc[y],l,mid);
merge(rc[x],rc[y],mid+1,r);
upt(x);
}
inline void dfs(int x,int f) {
for(auto v:pam_ori.edge[x])
dfs(v,x), merge(rt[x],rt[v],1,pam_rev.ind);
for(auto v:qry[x])
ins(rt[x],1,pam_rev.ind,pam_rev.dfn[v]);
if(pam_ori.len[x]>0) ans+=s[rt[x]];
}
inline LL solve() {
for(int i=1;i<n;i++)
qry[pam_ori.pos[i]].push_back(pam_rev.pos[n-i]);
dfs(1,0);
}
}
namespace PR {
const int T=1e7+50;
int si[T],cov[T];
struct atom{
int l,t,d;
atom() {}
atom(int l,int t,int d) : l(l),t(t),d(d) {}
};
__gnu_pbds::gp_hash_table <LL,atom> pr;
inline void solve() {
for(int i=1;i<=n;i++) si[i]=si[i-1]+n/i+1;
for(int i=1;i<=n;i++) {
int l=0, r=0;
for(int j=i;j+i<=n;j+=i) if(j>r) {
int L=sa_rev.lcp(n-j+1,n-(j+i)+1), R=sa_ori.lcp(j,j+i);
l=j-L+1, r=j+R-1;
if(r-l+1<i) continue;
if(cov[si[i-1]+j/i]) continue;
for(int x=i+i;l+x+x-1<=r+i;x+=i)
cov[si[x-1]+(l+x-1)/x]=1;
for(int k=l;k+i-1<=r && k-l<i;k++) {
LL hv=hs_ori.gv(k,k+i-1);
int rep=(r+1-k)/i+1;
if(pr.end()==pr.find(hv)) pr[hv]=atom(k,rep,i);
else if(pr[hv].t<rep) pr[hv]=atom(k,rep,i);
}
}
}
for(auto v:pr) {
atom &p=v.second;
if(!valid(p.l,p.l+p.d*2-1)) continue;
LL tp=(LL)p.t*(p.t-1)/2;
if(pld(p.l,p.l+p.d-1)) tp-=p.t-1;
ans-=tp;
}
}
}
int main() {
scanf("%s",ch+1); n=strlen(ch+1);
pw[0]=hval(1,1);
for(int i=1;i<=n;i++) pw[i]=pw[i-1]*base;
for(int i=2;i<=n*2;i++) lg[i]=lg[i>>1]+1;
pam_ori.init(), sa_ori.init(), hs_ori.init();
reverse(ch+1,ch+n+1);
pam_rev.init(), sa_rev.init(), hs_rev.init();
ST::solve();
PR::solve();
cout<<ans<<'\n';
}