【Newcoder】2020牛客暑期多校训练营(第二场)A - All with Pairs | KMP、字符串Hash

题目链接:https://ac.nowcoder.com/acm/contest/5667/A

题目大意:

给出N个串,你需要求出每个串和其他串的最长相同前后缀

 \large ans = \sum_i{\sum_j{f(i,j)^2}}

f(i,j)表示i的前缀与j的后缀最大公共长度

题目思路:

考虑到后缀的数量是1e6级别,所以可以先把所有后缀的哈希值储存起来

之后考虑遍历每个串的前缀获得答案

但是此时有一个问题:

aba如果被匹配意味着a也会被匹配一次,所以此时的贡献需要去重

如何去重?

就是kmp的fail树了(此时应该说是链)

考虑只保留最长的所以说随着长度的增加能匹配的数量绝对会减少,匹配到pos位置的最大值真正的数量其实要减去他的下一位匹配的数量,例如:

aba ,aba就对a产生了贡献,但此时ababa也会对a产生贡献,怎么办呢?直接减去aba的贡献即可,因为长度的原因,ababa肯定会对aba产生贡献,并且对aba和a的贡献相同,所以a只需要减去aba贡献即可。

所以说:\large xt[nxt[k]] -= xt[k]

最后计算贡献即可

Note:被莫名卡了哈希

Code:

/*** keep hungry and calm CoolGuang!***/
#pragma GCC optimize(3)
#include <bits/stdc++.h>
#include<stdio.h>
#include<queue>
#include<algorithm>
#include<string.h>
#include<iostream>
#define rep(i,n) for(int i=1;i<=(n);i++)
#define debug(x) cout<<#x<<":"<<x<<endl;
#define _CRT_SECURE_NO_WARNINGS
#pragma GCC optimize("Ofast","unroll-loops","omit-frame-pointer","inline")
#pragma GCC option("arch=native","tune=native","no-zero-upper")
#pragma GCC target("avx2")
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pp;
const ll INF=1e17;
const int Maxn=2e7+10;
const int maxn =1e6+10;
const int mod=998244353;
const int Mod = 1e9+7;
///const double eps=1e-10;
inline bool read(ll &num)
{char in;bool IsN=false;
    in=getchar();if(in==EOF) return false;while(in!='-'&&(in<'0'||in>'9')) in=getchar();if(in=='-'){ IsN=true;num=0;}else num=in-'0';while(in=getchar(),in>='0'&&in<='9'){num*=10,num+=in-'0';}if(IsN) num=-num;return true;}
ll n,m,p;
map<ull,ll>mp;
char str[maxn];
string s[maxn];
int nxt[maxn];
ll xt[maxn];
void getnxt(string p){
    nxt[0] = -1;
    int sz = p.size();
    for(int i=1,j=-1;i<sz;i++){///j代表匹配到的位置
        while(j!=-1&&p[i] != p[j+1]) j = nxt[j];
        if(p[i] == p[j+1]) j++;
        nxt[i] = j;
    }
}
ll cal(ll x){
    return (x*x)%mod;
}
int main(){
    ull base = 97;
    read(n);
    for(int i=1;i<=n;i++){
        scanf("%s",str+1);
        int len = strlen(str+1);
        for(int k=1;k<=len;k++) s[i].push_back(str[k]);
        ull p = 1,now = 0;
        for(int k=len;k>=1;k--){
            now += p*(str[k]-'a'+1);
            p*=base;mp[now]++;
        }
    }
    ll ans = 0;
    for(int i=1;i<=n;i++){
        getnxt(s[i]);
        int sz = s[i].size();
        ull now = 0;
        for(int k=0;k<sz;k++){
            now = now*base + s[i][k]-'a'+1;
            xt[k] = mp[now];
            if(nxt[k]!=-1) xt[nxt[k]]-=xt[k];
        }
        for(ll k=0;k<sz;k++){
            ans += (cal(k+1)*xt[k])%mod;
            ans%=mod;
        }
    }
    printf("%lld\n",ans);
    return 0;
}
/**
1 1 3 4 5 9

1 1 3 4 5 9
**/


猜你喜欢

转载自blog.csdn.net/qq_43857314/article/details/107442747
今日推荐