[Nowcoder 2018ACM多校第四场H] Double Palindrome

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u013578420/article/details/81330290

题目大意:
给你一个字符串s, 求有多少对(i, j), ( 1 i < j | s | ) , 使得交换s[i]和s[j]后, s可以被分割成两个回文串。 ( 1 | s | 10 5 , | s | 10 6 )

题目思路:
可以暴力的枚举从哪个点可以分割。
考虑[1, i] 和 [i+1,|s|]这一段。
首先可以用后缀数组, 求出s+rev(s)的后缀数组。
这样就可以O(1)的得出某两个位置往中间匹配的最长公共长度。
如果匹配长度大于了[l,r]的区间长, 说明这个子串是回文串。
如果两个串都是回文串, 有个标记记录一下, 最后的时候加上所有相同字符交换的答案。
否则, 记录下失配的两个端点, 跳过这两个端点继续匹配, 记录下各个失配的端点。
若果最终失配的端点组数大于2, 则是没救的, continue。
如果失配的端点组数等于2, 则必须得靠某组的一个端点与另一个组的一个端点交换。
若果失配的端点组数为1, 则必须靠某一个奇数串的中心来交换。

需要注意的是, 在两个串都是回文串的情况下, 即失配的端点组数为0, 依然有可能是两个奇数串的中心交换。

对于相同字符的交换情况, 可以用组合数直接算, 不同字符交换情况, 是O(n)的, 可以每次丢到set里暴力去重。

Code:

#include <map>
#include <set>
#include <map>
#include <bitset>
#include <cmath>
#include <queue>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>

#define ll long long
#define db double
#define fi first
#define se second
#define mp(x, y) make_pair(x, y)
#define ls (x << 1)
#define rs ((x << 1) | 1)
#define mid ((l + r) >> 1)

using namespace std;

const int N = (int)2e5 + 10;

int Log[N * 2];

namespace SA{
    const int M = 30;
    char str[N * 2]; int n, m;
    int c[N * 2], x[N * 2], y[N * 2], height[N * 2], rk[N * 2], sa[N * 2], st[N * 2][M];
    void build_sa(){
        m = 26;
        for (int i = 1; i <= n; i ++) x[i] = str[i] - 'a' + 1;
        for (int i = 1; i <= n; i ++) c[x[i]] ++;
        for (int i = 2; i <= m; i ++) c[i] += c[i - 1];
        for (int i = 1; i <= n; i ++) sa[c[x[i]] --] = i;
        for (int i = 1; i <= m; i ++) c[i] = 0;

        for (int k = 1; k < n; k <<= 1){

            int p = 0;
            for (int i = n - k + 1; i <= n; i ++) y[++ p] = i;
            for (int i = 1; i <= n; i ++) if (sa[i] > k) y[++ p] = sa[i] - k;
            for (int i = 1; i <= n; i ++) c[x[i]] ++;
            for (int i = 2; i <= m; i ++) c[i] += c[i - 1];
            for (int i = n; i >= 1; i --) sa[c[x[y[i]]] --] = y[i];
            for (int i = 1; i <= m; i ++) c[i] = 0;
            for (int i = 1; i <= n; i ++) y[i] = x[i];
            p = 1;

            x[sa[1]] = 1;
            for (int i = 2; i <= n; i ++)
                x[sa[i]] = y[sa[i]] == y[sa[i - 1]] && y[sa[i] + k] == y[sa[i - 1] + k] ? p : ++ p;
            for (int i = 1; i <= n; i ++) y[i] = 0;

            if (p >= n) break;
            m = p;
        }

    }
    void call_height(){
        for (int i = 1; i <= n; i ++) rk[sa[i]] = i, height[i] = 0;
        int k = 0;
        for (int i = 1; i <= n; i ++){
            if (rk[i] == 1) continue;
            for (k ? k -- : 0; str[i + k] == str[sa[rk[i] - 1] + k] && i + k <= n && sa[rk[i] - 1] + k <= n; k ++);
            height[rk[i]] = k;
        }
    }
    void build_st(){
        for (int i = 1; i <= n; i ++)
            st[i][0] = height[i];
        for (int j = 1; j < M; j ++)
            for (int i = 1; i <= n; i ++){
                if (i + (1 << (j - 1)) > n) break;
                st[i][j] = min(st[i][j - 1], st[i + (1 << (j - 1))][j - 1]);
            }
    }
    int lcp(int l, int r){
        l = rk[l], r = rk[r];
        if (l > r) swap(l, r);
        l ++;
        int k = Log[r - l + 1];
        r = r - (1 << k) + 1;
        return min(st[l][k], st[r][k]);
    }

    void main(char *s, int l){
        for (int i = 1; i <= l; i ++)
            str[i] = s[i];
        for (int i = 1; i <= l; i ++)
            str[l + i] = s[l - i + 1];
        n = 2 * l;
        str[n + 1] = 0;
        build_sa();
        call_height();
        build_st();
    }
}

int T;
char str[N]; int n, cnt[30];
ll ans; bool flag; set<pair<int, int> > S;

int rev(int i){return n * 2 - i + 1;}
void getpos(int l, int r, int &sz, int *lpos, int *rpos){
    while (sz <= 2 && l < r){
        int len = SA::lcp(l, rev(r));
        if (len >= r - l + 1) break;
        sz ++;
        lpos[sz] = l + len, rpos[sz] = r - len;
        l += (len + 1), r -= (len + 1);
    }
}
void add(int i, int j){
    if (i > j) swap(i, j);

    S.insert(mp(i, j));
}

int main(){

    for (int i = 1; (1 << i) < 2 * N; i ++)
        Log[1 << i] = 1;

    for (int i = 1; i < N * 2; i ++)
        Log[i] += Log[i - 1];

    scanf("%d\n", &T);
    while (T --){

        scanf("%s\n", str + 1);
        n = strlen(str + 1);

        flag = 0;
        ans = 0; S.clear();
        memset(cnt, 0, sizeof(cnt));
        for (int i = 1; i <= n; i ++)
            cnt[str[i] - 'a'] ++;

        SA::main(str, n);
        for (int i = 1; i <= 2 * n; i ++)
            printf("%d ", SA::height[i]);
        puts("");

        for (int i = 1; i <= 2 * n; i ++)
            printf("%d ", SA::sa[i]);
        puts("");


        for (int i = 1; i < n; i ++){
            //[1, i] [i + 1, n]
            int sz = 0, lpos[5], rpos[5];

            //printf("%d\n", i);
            getpos(1, i, sz, lpos, rpos);
            getpos(i + 1, n, sz, lpos, rpos);

            if (sz == 0) {
                flag = 1;
                if ((i & 1) && ((n - i) & 1)){
                    if (str[i / 2 + 1] != str[i + (n - i) / 2 + 1])
                        add(i / 2 + 1, i + (n - i) / 2 + 1);
                }
                continue;
            }
            if (sz > 2) continue;

            if (sz == 1){
                if (i & 1){
                    if (str[i / 2 + 1] == str[lpos[1]]) add(i / 2 + 1, rpos[1]);
                    if (str[i / 2 + 1] == str[rpos[1]]) add(i / 2 + 1, lpos[1]);
                }
                if ((n - i) & 1){
                    if (str[i + (n - i) / 2 + 1] == str[lpos[1]]) add(i + (n - i) / 2 + 1, rpos[1]);
                    if (str[i + (n - i) / 2 + 1] == str[rpos[1]]) add(i + (n - i) / 2 + 1, lpos[1]);
                }

                continue;
            }

            if (sz == 2){
                if (str[lpos[1]] == str[lpos[2]] && str[rpos[1]] == str[rpos[2]]){
                    add(lpos[1], rpos[2]); add(rpos[1], lpos[2]);
                }
                if (str[lpos[1]] == str[rpos[2]] && str[rpos[1]] == str[lpos[2]]){
                    add(lpos[1], lpos[2]); add(rpos[1], rpos[2]);
                }
            }

        }

        ans = S.size();
        if (flag){
            for (int i = 0; i < 26; i ++)
                if (cnt[i] > 1) ans += 1ll * cnt[i] * (cnt[i] - 1) / 2;
        }

        printf("%lld\n", ans);
    }

    return 0;
}

猜你喜欢

转载自blog.csdn.net/u013578420/article/details/81330290