HDU - 6194 string string string (后缀数组)

题意:

告诉你一个字符串和k , 求这个字符串中有多少不同的子串恰好出现了k 次。

思路:

肯定先考虑出现了K次的情况,显然是截取长度为K的一段后缀,根据lcp 数组,可以计算得到有多少个后缀至少出现K次,注意是至少。

那么就存在一个问题,就是出现大于K次的字符串也算进来了。

有一个很巧妙的办法,就是查找出现K+1的字符串,所以就是sa[i-1] ~ sa[i + k - 1]  和 sa[i] ~ sa[i + k] 求两次lcp 减去即可。

但是会多减,并且多减的肯定是sa[i-1] ~ sa[i + k] 的lcp。 加上即可。

不过为什么这么做呢

下面这组数据为例:

k=3

后缀字符串为以下:

abc

abcabc

abcabcabc

abcabcabcabc

abcabcabcabcabc

以中间三行为例,上下同时减,是为了看上边最长公共前缀对答案有无贡献,如果有的话,就直接减去,例如第一行的abc在中间三行又出现了,所以不必再进行计算。再看下边最长公共前缀对答案有无贡献,如果有,也是直接减去即可,例如下边最长公共前缀为abcabc,所以也是不必计算,直接剪掉即可,但同时会发现,两部分多减了一个abc,所以再加上即可。

虽然还不知到这东西的理论性,但感觉冥冥之中自有定论,就感觉这么算是对的,一直减去k+1次出现的,正好就是答案。

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>

const int MAXN = 1e5+10;
int r[MAXN];  // r 数组保存了字符串中的每个元素值,除最后一个元素外,每个元素的值在 1..m 之间,最后一个元素的值为 0
int wa[MAXN], wb[MAXN], wv[MAXN], ws[MAXN];  // 这 4 个数组是后缀数组计算时的临时变量,无实际意义
int sa[MAXN]; //  sa[i] 保存第 i 小的后缀在字符串中的开始下标,i 取值范围为 0..n-1
int cmp(int *r, int a, int b, int l) {
    return r[a] == r[b] && r[a + l] == r[b + l];
}
void da(int *r, int *sa, int n, int m) {  // n 为字符串的长度,注意是长度,m 为字符最大值
    int i, j, p, *x = wa, *y = wb;
    for (i = 0; i < m; ++i) ws[i] = 0;
    for (i = 0; i < n; ++i) ws[x[i] = r[i]]++;
    for (i = 1; i < m; ++i) ws[i] += ws[i - 1];
    for (i = n - 1; i >= 0; --i) sa[--ws[x[i]]] = i;
    for (j = 1, p = 1; p < n; j *= 2, m = p) {
        for (p = 0, i = n - j; i < n; ++i) y[p++] = i;
        for (i = 0; i < n; ++i) if (sa[i] >= j) y[p++] = sa[i] - j;
        for (i = 0; i < n; ++i) wv[i] = x[y[i]];
        for (i = 0; i < m; ++i) ws[i] = 0;
        for (i = 0; i < n; ++i) ws[wv[i]]++;
        for (i = 1; i < m; ++i) ws[i] += ws[i - 1];
        for (i = n - 1; i >= 0; --i) sa[--ws[wv[i]]] = y[i];
        for (std::swap(x, y), p = 1, x[sa[0]] = 0, i = 1; i < n; ++i)
            x[sa[i]] = cmp(y, sa[i - 1], sa[i], j) ? p - 1 : p++;
    }
    return;
}

int rank[MAXN];  // rank[i] 表示从下标 i 开始的后缀的排名,值为 1..n
int height[MAXN]; // 下标范围为 1..n,height[1] = 0
void calHeight(int *r, int *sa, int n) {        //n为字符串长度减一
    int i, j, k = 0;
    for (i = 1; i <= n; ++i)
        rank[sa[i]] = i;
    for (i = 0; i < n; height[rank[i++]] = k)
        for (k ? k-- : 0, j = sa[rank[i] - 1]; r[i + k] == r[j + k]; ++k);
    return;
}

char s[MAXN];

int lg[MAXN];
int ST[30][MAXN];

void get_rmq(int n){
    lg[0]=-1;
    for(int i=1;i<=n;i++){
        lg[i]=lg[i/2]+1;
    }
    for(int i=1;i<=n;i++){
        ST[0][i]=height[i];
    }
    for(int i=1;i<=lg[n];i++){
        for(int j=1;j+(1<<i)-1<=n;j++){
            ST[i][j]=std::min(ST[i-1][j],ST[i-1][j+(1<<(i-1))]);
        }
    }
}

int n;

int query(int x,int y){
    if(x==y) return n-sa[x];
    x++;
    int t=lg[y-x+1];
    return std::min(ST[t][x],ST[t][y-(1<<t)+1]);
}

int main(){
    int T;
    scanf("%d",&T);
    while(T--){
        int k;
        scanf("%d%s",&k,s);
        n=strlen(s);
        for(int i=0;i<n;i++){
            r[i]=s[i];
        }
        r[n]=0;
        da(r,sa,n+1,128);
        calHeight(r,sa,n);
        get_rmq(n);
        long long ans=0;
        for(int i=1;i+k-1<=n;i++){
            ans+=query(i,i+k-1);
            if(i-1>0) ans-=query(i-1,i+k-1);
            if(i+k<=n) ans-=query(i,i+k);
            if(i-1>0&&i+k<=n) ans+=query(i-1,i+k);
        }
        printf("%lld\n",ans);
    }
}

猜你喜欢

转载自blog.csdn.net/qq_40679299/article/details/82763711