Description
定义, \(f(T, s)\) 为字符串 \(s\) 在字符串 \(T\) 中出现的次数,如:\(f(\texttt{"aaabacaa"}, \texttt{"aa"}) = 3, f(\texttt{"ababa"}, \texttt{"aba"}) = 2\)。
给定一个字符串 \(t\), \(n\) 个字符串 \(s_1, s_2, \cdots, s_n\)。求
的值,其中 \(s_i + s_j\) 表示将字符串 \(s_i, s_j\) 拼接在一起。
Hint
- \(1\le |t| \le 2\times 10^5\)
- \(1\le n\le 2\times 10^5\)
- \(1\le \sum_{i = 1}^{n} |s_i| \le 2\times 10^5\)
Solution
注:下文中 \(s[l\cdots r]\) 表示字符串 \(s\) 的第 \(l\) 到第 \(r\) 位形成的子串。
设 \(f_k\) 表示有几个 \(\{s_1, s_2,\cdots s_n\}\) 中的字符串是 \(t[1\cdots k]\) 的后缀 ,设 \(g_k\) 表示有几个 \(\{s_1, s_2,\cdots s_n\}\) 中的字符串是 \(t[k\cdots |t|]\) 的前缀。
若我们确定一个断点 \(x\),将 \(t\) 分为前后两部分 \(t_1 = t[1\cdots x], t_2 = t[x + 1\cdots |t|]\) 。找有几个 \(\{s_1, s_2,\cdots s_n\}\) 中的字符串可以拼在 \(t_1\) 的后面,有几个可以拼在 \(t_2\) 的前面,将这两个个数一乘(乘法原理),就是这个断点对答案的贡献。
于是答案就是:
接下来就是如何求 \(f, g\) 了。
先建 AC自动机
。以 \(f\) 为例, \(g\) 反转后同理。
当文本串 \(t\) 在自动机上走的时候,每走到一个结点,都要往 fail
指针的一直方向跳,沿途记录有几个模式串的结尾作为这一位的答案。但这样的复杂度有问题,如果遇到形如 aaaaa....aa
的就会被卡成 \(O(\sum |s_i| \times |t|)\),绝对 T 飞。
但是我们可以在建 fail
指针的时候就把这个结点 一直通过 fail
连到最上面有几个结尾点 处理出来就不会有这样的问题。具体地,只要在一个结点做好之后,加一句 t[x].cnt += t[t[x].fail].cnt;
即可,类似于前缀和。这样一直累加下去,最后 cnt
就从结尾的个数变成了一路跳 fail
的结尾的个数。
时间复杂度:\(O(\sum|s_i| + |t|)\)。
Code
#include <algorithm>
#include <iostream>
#include <string>
#include <queue>
using namespace std;
const int L = 2e5 + 5;
const int S = 26;
struct AC_Automaton {
struct ACAM_Node {
int ch[S]; // 子结点
int fail; // fail 指针
int cnt;
/* cnt 的前后两种含义 :
* 建fail前:结尾个数
* 建fail后:一路跳fail一共的结尾个数
*/
} t[L];
int total;
AC_Automaton() {
total = 0;
}
inline void insert(string& s) {
int x = 0;
for (string::iterator it = s.begin(); it != s.end(); it++) {
int c = *it - 'a';
if (!t[x].ch[c]) t[x].ch[c] = ++total;
x = t[x].ch[c];
}
t[x].cnt++; // 结尾符标记
}
inline void initFail() { //建 AC自动机
queue <int> Q;
for (register int c = 0; c < S; c++)
if (t[0].ch[c]) Q.push(t[0].ch[c]), t[t[0].ch[c]].fail = 0;
while (!Q.empty()) {
int x = Q.front(); Q.pop();
for (register int c = 0; c < S; c++)
if (t[x].ch[c]) {
Q.push(t[x].ch[c]);
t[t[x].ch[c]].fail = t[t[x].fail].ch[c];
} else t[x].ch[c] = t[t[x].fail].ch[c];
t[x].cnt += t[t[x].fail].cnt; // 在这里加一个前缀和的操作,当前的 cnt 加上 fail 对应的 cnt。
}
}
inline void scan(int *f, string& s) { // 计算 f 和 g
int x = 0;
for (string::iterator it = s.begin(); it != s.end(); it++, f++)
x = t[x].ch[*it - 'a'], *f = t[x].cnt;
// 预处理过,不需要暴跳fail了,直接取值。
}
} A, R; // 正AC自动机,反AC自动机
int f[L], g[L];
signed main() {
ios::sync_with_stdio(false);
string s, t; int n;
cin >> t >> n;
for (; n; --n) {
cin >> s;
A.insert(s);
reverse(s.begin(), s.end());
R.insert(s); // 反转后插入反自动机
}
A.initFail();
A.scan(f + 1, t);
reverse(t.begin(), t.end());
R.initFail(); // 反转后在反自动机上跑
R.scan(g + 1, t);
long long ans = 0ll; // 不开 long long 见祖宗
for (register int i = 1; i <= t.size(); i++)
ans += f[i] * 1ll * g[t.size() - i];
// 乘法原理,注意这里的 g 也反了。原来应为 ans += f[i] * 1ll * g[i + 1]
cout << ans << endl;
return 0;
}