给一个串,问这个串里所有本质不同的回文子串,有多少对满足一个串是另一个的子串。
这题现场过的人很少啊,题解也给了个蛮复杂我还没看懂的带log的做法,其实了解回文树的话特别好想,我们现场写了一个O(n)的做法(在牛客跑了72ms)。回文树还算是个新东西,还没有被玩坏,我以前刷的回文树套题基本都算是板子题,最近多校有几道回文树就进入了灵活运用的范畴了,出题人开始准备玩坏这个算法了,以后这都是基操。
见本质不同,想回文自动机,然后就开始梳理在回文树上,怎样的两个节点表示的回文串有子串关系?
首先是next边,很显然,一条next链上,除奇根偶根,所有父亲都是儿子的子串。
其次是fail边,又又又很显然,一条fail链上,父亲是儿子的后缀,所以除奇根偶根所有父亲都是儿子的子串。
好了这题就快做完了。
对于每个节点,答案就是其next链上除根外的父亲个数+fail链上除根外的父亲个数。
然而仅仅这样还是存在问题,因为next链和fail链有时候存在交叉,也就是说一个回文串既是另一个串的后缀,又在那个串中间出现,比如样例的aaaa,next链上的aa与fail链的aa是同一个节点。
于是还要想办法去个重。在dfs并且对每个节点跳fail树统计的时候,打个vis标记(当前节点也要标记,next链和fail链重的时候就可能会跳到这里),如果下次跳的时候发现了vis,就可以停止了,这样就简单的保证了不交叉,并且防止了fail树重复跳带来的复杂度。
最后,对于跑出来的东西做一个很简单的dp:当前节点贡献 = 父亲贡献 + fail链贡献 + 1。
ac代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1e5 + 5;
int _, kase = 1;
char s[maxn];
ll ans = 0;
struct Pam {
int next[maxn][26];
int fail[maxn];
int len[maxn];// 当前节点表示回文串的长度
int S[maxn];
int dp[maxn];
bool vis[maxn];
int last, n, p;
int newNode(int l) {
memset(next[p], 0, sizeof(next[p]));
len[p] = l;
dp[p] = 0;
return p++;
}
void init() {
ans = 0;
n = last = p = 0;
newNode(0);
newNode(-1);
S[n] = -1;
fail[0] = 1;
}
int getFail(int x) {
while (S[n - len[x] - 1] != S[n]) {
x = fail[x];
}
return x;
}
void add(int c) {
S[++n] = c;
int cur = getFail(last);
if (!next[cur][c]) {
int now = newNode(len[cur] + 2);
fail[now] = next[getFail(fail[cur])][c];
next[cur][c] = now;
}
last = next[cur][c];
}
int jump(int x) {
int cnt = 0;
vis[x] = 1;
while (fail[x] != 0 && fail[x] != 1 && !vis[fail[x]]) {
x = fail[x];
vis[x] = 1, ++cnt;
}
return cnt;
}
void clearJump(int x, int cnt) {
vis[x] = 0;
while (cnt--) {
x = fail[x];
vis[x] = 0;
}
}
void dfs(int x, int fa) {
int jp = jump(x);
dp[x] = jp;
if (x != 1 && x != 0 && fa != 0 && fa != 1) {
dp[x] = dp[fa] + jp + 1;
}
ans += dp[x];
for (int i = 0; i < 26; ++i) {
if (next[x][i]) {
dfs(next[x][i], x);
}
}
clearJump(x, jp);
}
void build() {
init();
for (int i = 1; s[i]; i++) {
add(s[i] - 'a');
}
}
} pam;
int main() {
scanf("%d", &_);
while (_--) {
scanf("%s", s + 1);
pam.build();
printf("Case #%d: ", kase++);
pam.dfs(1, 1);
// printf("%lld\n", ans);
pam.dfs(0, 0);
printf("%lld\n", ans);
}
return 0;
}