题意:
给你n个字符串,任意选两个字符串,答案加上前一个字符串前缀和后一个字符串的后缀最长相同的部分长度的平方。问你答案是多少
题解:
想练一下AC自动机,就用这个写了,有时间的话再用广义后缀自动机写一下。
首先我们肯定是要枚举所有的字符串,问题是将它当做前缀还是后缀。把当前的字符串当成后缀的话,你的fail指针是这么跳的时候,会很难维护是否有的字符串的前缀已经被访问过了:
那么这种请况我一下子还想不出来怎么解决,有可能要用lca什么的吗?然后就换了一种思路:把当前枚举的字符串当成前缀,去匹配所有字符串的后缀,这个时候ac自动机insert的时候,num是要在当前字符串的末尾端点才能+1,表示到了末尾,然后所有字符串insert之后,对fail边建反边,于是就形成了一颗树,此时可以dfs去把num加到前面的后缀中。
然后枚举每个字符串的时候,也是要求出它的nex数组,然后枚举的时候,nex[i+1]就是当前串的最长公共前后缀,后面的情况剪掉前面的重复情况即可。
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N=1e6+5,R=26;
const ll mod=998244353;
bool vis[N];
string s[N];
int nex[N];
void KMP(int x){
int len=s[x].length();
nex[0]=-1,nex[1]=0;
int i=1,j=0;
while(i<len&&j<len){
if(j==-1||s[x][i]==s[x][j])nex[++i]=++j;
else j=nex[j];
}
}
struct Tire{
int nxt[N][R],fail[N],ed[N],num[N],vis[N];
vector<int>son[N];
int rt,tot,cnt;
int newnode(){
for(int i=0;i<R;i++)nxt[tot][i]=-1;
//l[tot]=0;
return tot++;
}
void init(){
memset(num,0,sizeof(num));
tot=cnt=0;
rt=newnode();
}
int insert(int x){
int now=rt,len=s[x].length();
for(int i=0;i<len;i++){
int val=s[x][i]-'a';
if(nxt[now][val]==-1)nxt[now][val]=newnode();
//l[nxt[now][val]]=l[now]+1;
now=nxt[now][val];
}
num[now]++;
return now;
}
void dfs(int x){
for(int i:son[x]){
dfs(i);
num[x]+=num[i];
}
}
void build(){
queue<int>q;
fail[rt]=rt;
for(int i=0;i<R;i++){
if(nxt[rt][i]==-1)nxt[rt][i]=rt;
else {
fail[nxt[rt][i]]=rt;
q.push(nxt[rt][i]);
son[rt].push_back(nxt[rt][i]);
}
}
while(!q.empty()){
int now=q.front();q.pop();
for(int i=0;i<R;i++){
if(nxt[now][i]==-1)nxt[now][i]=nxt[fail[now]][i];
else {
fail[nxt[now][i]]=nxt[fail[now]][i];
son[nxt[fail[now]][i]].push_back(nxt[now][i]);
q.push(nxt[now][i]);
}
}
}
dfs(0);
}
ll query(int x){
int now=rt,len=s[x].length();
ll ans=0;
for(ll i=0;i<len;i++){
now=nxt[now][s[x][i]-'a'];
ans=(ans+1ll*num[now]*(i+1)%mod*(i+1)-1ll*num[now]*nex[i+1]%mod*nex[i+1])%mod;
if(ans<0)ans+=mod;
}
return ans;
}
}ac;
int p[N];
int main()
{
cin.tie(0);
ios::sync_with_stdio(false);
int n;
//cin>>s[1];
//KMP(1);
cin>>n;
ac.init();
for(int i=1;i<=n;i++)
cin>>s[i],ac.insert(i);
ac.build();
ll ans=0;
for(int i=1;i<=n;i++)
KMP(i),ans=(ans+ac.query(i))%mod;
printf("%lld\n",ans);
return 0;
}