版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u013578420/article/details/81330290
题目大意:
给你一个字符串s, 求有多少对(i, j),
, 使得交换s[i]和s[j]后, s可以被分割成两个回文串。
题目思路:
可以暴力的枚举从哪个点可以分割。
考虑[1, i] 和 [i+1,|s|]这一段。
首先可以用后缀数组, 求出s+rev(s)的后缀数组。
这样就可以O(1)的得出某两个位置往中间匹配的最长公共长度。
如果匹配长度大于了[l,r]的区间长, 说明这个子串是回文串。
如果两个串都是回文串, 有个标记记录一下, 最后的时候加上所有相同字符交换的答案。
否则, 记录下失配的两个端点, 跳过这两个端点继续匹配, 记录下各个失配的端点。
若果最终失配的端点组数大于2, 则是没救的, continue。
如果失配的端点组数等于2, 则必须得靠某组的一个端点与另一个组的一个端点交换。
若果失配的端点组数为1, 则必须靠某一个奇数串的中心来交换。
需要注意的是, 在两个串都是回文串的情况下, 即失配的端点组数为0, 依然有可能是两个奇数串的中心交换。
对于相同字符的交换情况, 可以用组合数直接算, 不同字符交换情况, 是O(n)的, 可以每次丢到set里暴力去重。
Code:
#include <map>
#include <set>
#include <map>
#include <bitset>
#include <cmath>
#include <queue>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
#define ll long long
#define db double
#define fi first
#define se second
#define mp(x, y) make_pair(x, y)
#define ls (x << 1)
#define rs ((x << 1) | 1)
#define mid ((l + r) >> 1)
using namespace std;
const int N = (int)2e5 + 10;
int Log[N * 2];
namespace SA{
const int M = 30;
char str[N * 2]; int n, m;
int c[N * 2], x[N * 2], y[N * 2], height[N * 2], rk[N * 2], sa[N * 2], st[N * 2][M];
void build_sa(){
m = 26;
for (int i = 1; i <= n; i ++) x[i] = str[i] - 'a' + 1;
for (int i = 1; i <= n; i ++) c[x[i]] ++;
for (int i = 2; i <= m; i ++) c[i] += c[i - 1];
for (int i = 1; i <= n; i ++) sa[c[x[i]] --] = i;
for (int i = 1; i <= m; i ++) c[i] = 0;
for (int k = 1; k < n; k <<= 1){
int p = 0;
for (int i = n - k + 1; i <= n; i ++) y[++ p] = i;
for (int i = 1; i <= n; i ++) if (sa[i] > k) y[++ p] = sa[i] - k;
for (int i = 1; i <= n; i ++) c[x[i]] ++;
for (int i = 2; i <= m; i ++) c[i] += c[i - 1];
for (int i = n; i >= 1; i --) sa[c[x[y[i]]] --] = y[i];
for (int i = 1; i <= m; i ++) c[i] = 0;
for (int i = 1; i <= n; i ++) y[i] = x[i];
p = 1;
x[sa[1]] = 1;
for (int i = 2; i <= n; i ++)
x[sa[i]] = y[sa[i]] == y[sa[i - 1]] && y[sa[i] + k] == y[sa[i - 1] + k] ? p : ++ p;
for (int i = 1; i <= n; i ++) y[i] = 0;
if (p >= n) break;
m = p;
}
}
void call_height(){
for (int i = 1; i <= n; i ++) rk[sa[i]] = i, height[i] = 0;
int k = 0;
for (int i = 1; i <= n; i ++){
if (rk[i] == 1) continue;
for (k ? k -- : 0; str[i + k] == str[sa[rk[i] - 1] + k] && i + k <= n && sa[rk[i] - 1] + k <= n; k ++);
height[rk[i]] = k;
}
}
void build_st(){
for (int i = 1; i <= n; i ++)
st[i][0] = height[i];
for (int j = 1; j < M; j ++)
for (int i = 1; i <= n; i ++){
if (i + (1 << (j - 1)) > n) break;
st[i][j] = min(st[i][j - 1], st[i + (1 << (j - 1))][j - 1]);
}
}
int lcp(int l, int r){
l = rk[l], r = rk[r];
if (l > r) swap(l, r);
l ++;
int k = Log[r - l + 1];
r = r - (1 << k) + 1;
return min(st[l][k], st[r][k]);
}
void main(char *s, int l){
for (int i = 1; i <= l; i ++)
str[i] = s[i];
for (int i = 1; i <= l; i ++)
str[l + i] = s[l - i + 1];
n = 2 * l;
str[n + 1] = 0;
build_sa();
call_height();
build_st();
}
}
int T;
char str[N]; int n, cnt[30];
ll ans; bool flag; set<pair<int, int> > S;
int rev(int i){return n * 2 - i + 1;}
void getpos(int l, int r, int &sz, int *lpos, int *rpos){
while (sz <= 2 && l < r){
int len = SA::lcp(l, rev(r));
if (len >= r - l + 1) break;
sz ++;
lpos[sz] = l + len, rpos[sz] = r - len;
l += (len + 1), r -= (len + 1);
}
}
void add(int i, int j){
if (i > j) swap(i, j);
S.insert(mp(i, j));
}
int main(){
for (int i = 1; (1 << i) < 2 * N; i ++)
Log[1 << i] = 1;
for (int i = 1; i < N * 2; i ++)
Log[i] += Log[i - 1];
scanf("%d\n", &T);
while (T --){
scanf("%s\n", str + 1);
n = strlen(str + 1);
flag = 0;
ans = 0; S.clear();
memset(cnt, 0, sizeof(cnt));
for (int i = 1; i <= n; i ++)
cnt[str[i] - 'a'] ++;
SA::main(str, n);
for (int i = 1; i <= 2 * n; i ++)
printf("%d ", SA::height[i]);
puts("");
for (int i = 1; i <= 2 * n; i ++)
printf("%d ", SA::sa[i]);
puts("");
for (int i = 1; i < n; i ++){
//[1, i] [i + 1, n]
int sz = 0, lpos[5], rpos[5];
//printf("%d\n", i);
getpos(1, i, sz, lpos, rpos);
getpos(i + 1, n, sz, lpos, rpos);
if (sz == 0) {
flag = 1;
if ((i & 1) && ((n - i) & 1)){
if (str[i / 2 + 1] != str[i + (n - i) / 2 + 1])
add(i / 2 + 1, i + (n - i) / 2 + 1);
}
continue;
}
if (sz > 2) continue;
if (sz == 1){
if (i & 1){
if (str[i / 2 + 1] == str[lpos[1]]) add(i / 2 + 1, rpos[1]);
if (str[i / 2 + 1] == str[rpos[1]]) add(i / 2 + 1, lpos[1]);
}
if ((n - i) & 1){
if (str[i + (n - i) / 2 + 1] == str[lpos[1]]) add(i + (n - i) / 2 + 1, rpos[1]);
if (str[i + (n - i) / 2 + 1] == str[rpos[1]]) add(i + (n - i) / 2 + 1, lpos[1]);
}
continue;
}
if (sz == 2){
if (str[lpos[1]] == str[lpos[2]] && str[rpos[1]] == str[rpos[2]]){
add(lpos[1], rpos[2]); add(rpos[1], lpos[2]);
}
if (str[lpos[1]] == str[rpos[2]] && str[rpos[1]] == str[lpos[2]]){
add(lpos[1], lpos[2]); add(rpos[1], rpos[2]);
}
}
}
ans = S.size();
if (flag){
for (int i = 0; i < 26; i ++)
if (cnt[i] > 1) ans += 1ll * cnt[i] * (cnt[i] - 1) / 2;
}
printf("%lld\n", ans);
}
return 0;
}