Codeforces Round #846 (Div. 2) G. Delicious Dessert(后缀自动机sam/后缀数组sa)

题目

给定长为n(n<=1e6)的纯小写字母串s,

若s的子串t满足其在s中的出现次数cnt(t),能被t的长度len(t)整除,即cnt(t)%len(t)=0,

则称t是美味的,求s中所有美味子串的数量

注意如果美味串出现了多次,则所有出现位置都被计算在内

即,若美味t出现了cnt(t)次,则其贡献为cnt(t)

思路来源

superguymj代码

洛谷P3804 【模板】后缀自动机 (SAM) (后缀自动机)_Code92007的博客-CSDN博客_洛谷p3804 【模板】后缀自动机

题解

赛中用sa+并查集配合height数组乱搞搞过去了,

赛后发现是sam模板题,来补一下sam

sam做法是统计因数,sa做法是统计倍数

sam题解

与P3804类似,sam中每个节点对应串,其出现次数为parent树上子节点的sz大小之和

先求出每个节点t对应串的出现次数cnt(t),由于节点t对应的串长在(fa(t).len,t.len]之间

即求这段区间内有多少数是cnt(t)的约数,前缀和作差化作[1,len]-[1,fa(t).len]

预处理约数数组,对于[1,x]的询问,在cnt(t)对应约数数组中,二分x的位置

sa题解

height求和前一个排名的lcp的长度,按值降序加入height值

维护并查集,每加入一个height就合并下相邻项

并查集实际维护的是若干条链,链长即串出现的次数

外层for按值域i(也是串长度)降序加入,

内层for枚举i的倍数j,找串出现次数j对应的链有cur[j]条,其贡献是j*cur[j]

cur[i]表示当前长度为i的链,也就是出现次数为i的串有cur[i]个

链长及cur数组,只有在并查集连的时候,会发生变化,最多变化n次

所以,一边合并链,一边枚举倍数,统计答案即可

sam代码

#include<iostream>
#include<vector>
#include<cstring>
#include<algorithm>
using namespace std;
struct SAM{
    static const int N=1e6+10;
    struct NODE{
        int ch[26]; // 每个节点添加一个字符后到达的节点
        int len;  // 每个节点代表的状态中的最大串长
        int fa; // 每个节点在parent树上的父亲节点, 即该节点link边指向的节点
        int sz; // 
        NODE(){memset(ch,0,sizeof(ch));len=0;}
    }dian[N<<1]; // 节点数开串长的两倍
    int n; // 串长 
    char s[N]; // 串
    int las=1; // las: 上一个用到的节点编号
    int tot=1; // tot: 当前用到的节点编号
    // 向SAM中插入一个字符c 
    void add(int c){
        int p=las; // 上一个状态的节点
        int np=las=++tot; // 要加入的状态的节点
        dian[np].sz=1; // 叶子节点endpos大小为1
        dian[np].len=dian[p].len+1; // 新状态比上一个装填多一个首字符
        for(;p&&!dian[p].ch[c];p=dian[p].fa)dian[p].ch[c]=np; // 指针p沿link边回跳,直至找到一个节点包含字符c的出边,无字符c的出边则将出边指向新状态的节点
        if(!p)dian[np].fa=1;// 以上为case 1,指针p到SAM的起点的路径上的节点都无字符c的出边,将新节点作为SAM的起点的一个儿子节点
        else{ // 节点p包含字符c的出边
            int q=dian[p].ch[c]; // 节点p的字符c的出边指向的节点
            if(dian[q].len==dian[p].len+1)dian[np].fa=q;// 以上为case 2,节点p和q代表的状态的最大串长相差1
            else{ // 节点p和q代表的状态的最大串长相差>1
                int nq=++tot; // 新建一个节点nq
                dian[nq]=dian[q]; // 节点nq克隆节点q的信息
                dian[nq].sz=0; // nq产生时,是一个分支节点,需要从后续儿子节点里更新获取sz
                dian[nq].len=dian[p].len+1; // 保证节点p和nq代表的状态的最大串长相差1
                dian[q].fa=dian[np].fa=nq; 
                for(;p&&dian[p].ch[c]==q;p=dian[p].fa)dian[p].ch[c]=nq;// 以上为case 3,将节点p到SAM的起点的路径上的所有节点的字符c的出边指向的节点替换为nq
            }
        }
    }

    vector<int>fac[N];
    void get_fac(){
        for(int i=1;i<N;++i){
            for(int j=i;j<N;j+=i){
                fac[j].push_back(i);
            }
        }
    }
    // 统计<=x的d的约数的个数
    int cal(int d,int x){
        return upper_bound(fac[d].begin(),fac[d].end(),x)-fac[d].begin();
    }

    void init(){
        scanf("%d%s",&n,s);
        for(int i=0;i<n;i++)add(s[i]-'a');
    }
    int b[N<<1],a[N<<1]; // b: 用于基数排序 a: 用于记录点号 
    // 按长度基数排序,短的在前长的在后
    // 另一种方法是用vector直接建出parent树,对parent树直接dfs
    void base_sort(){
        for(int i=1;i<=tot;++i)b[dian[i].len]++;
        for(int i=1;i<=tot;++i)b[i]+=b[i-1];
        for(int i=1;i<=tot;++i)a[b[dian[i].len]--]=i;
    }
    // 逆拓扑序遍历求出sz,每个节点串长(dian[fa].len,dian[u].len],出现次数dian[u].sz,统计dian[u].sz的因数
    void get_sz(){
        long long ans=0;
        for(int i=tot;i>=1;--i){
            int u=a[i],fa=dian[u].fa;
            dian[fa].sz+=dian[u].sz;
            //cout<<"u:"<<u<<" sz:"<<dian[u].sz<<" len:"<<dian[u].len<<" falen:"<<dian[fa].len<<" cnt:"<<cal(dian[u].sz,dian[u].len)-cal(dian[u].sz,dian[fa].len)<<endl;
            ans+=1ll*dian[u].sz*(cal(dian[u].sz,dian[u].len)-cal(dian[u].sz,dian[fa].len));
        }
        cout<<ans<<endl;
    }
    void solve(){
        init();
        get_fac();
        base_sort();
        get_sz();
    }
}sam;
int main(){
    sam.solve();
    return 0;
}

sa代码

#include<iostream>
#include<cstring>
#include<cstdio>
#include<vector>
using namespace std;
typedef long long ll;
const int maxn=1e6+10;
struct SuffixArray{
    int par[maxn],sz[maxn],cur[maxn];
    char s[maxn];
    ll ans;
    vector<int>add[maxn];
    inline int find(int x){
        return par[x]==x?x:par[x]=find(par[x]);
    }
    inline void un(int x,int y){
        x=find(x),y=find(y);
        if(x==y)return;
        cur[sz[x]]--;
        cur[sz[y]]--;
        par[y]=x;
        sz[x]+=sz[y];
        cur[sz[x]]++;
    }
    typedef long long ll;
    static const int maxn=1e6+10;
    int cnt[maxn],mx,n,rk[maxn],sa[maxn],tmp[maxn],ht[maxn];
    inline void base_sort(){
        memset(cnt,0,sizeof(*cnt)*(mx+1));
        for(int i=1;i<=n;++i)++cnt[rk[i]];
        for(int i=1;i<=mx;++i)cnt[i]+=cnt[i-1];
        for(int i=n;i;--i)sa[cnt[rk[tmp[i]]]--]=tmp[i];
    }
    inline void suffix_sort(){
        mx=0;
        for(int i=1;i<=n;++i)mx=max(mx,rk[i]=s[i]),tmp[i]=i;
        base_sort();
        for(int len=1,dif=0;dif<n;len<<=1,mx=dif){
            int p=0;
            for(int i=n-len+1;i<=n;++i)tmp[++p]=i;
            for(int i=1;i<=n;++i)
                if(sa[i]>len)
                    tmp[++p]=sa[i]-len;
            base_sort();
            swap(rk,tmp);
            rk[sa[1]]=dif=1;
            for(int i=2;i<=n;++i){
                if(tmp[sa[i-1]]!=tmp[sa[i]]||tmp[sa[i-1]+len]!=tmp[sa[i]+len])++dif;
                rk[sa[i]]=dif;
            }
        }
    }
    inline void calc_ht(){
        for(int i=1,h=0;i<=n;++i){
            if(h)--h;
            int j=sa[rk[i]-1];
            while(s[i+h]==s[j+h])++h;
            ht[rk[i]]=h;
        }
    }
    // rank、sa、height下标均为[1,n]
    inline void PR(){
        string p(s+1);
        for(int i=1;i<=n;++i)
        printf("Rank[%d]:%d\n",i,rk[i]);
        for(int i=1;i<=n;++i)
        {
            printf("sa[%d]:%d ",i,sa[i]);
            cout<<p.substr(sa[i]-1)<<endl;
        }
        for(int i=1;i<=n;++i)
        printf("ht[%d]:%d\n",i,ht[i]);
    }
    inline void solve(){
        scanf("%d%s",&n,s+1);
        suffix_sort();
	    calc_ht();
        //PR();
        for(int i=0;i<=n;++i){
            par[i]=i;sz[i]=1;
            if(ht[i]>=1)add[ht[i]].push_back(i);
        }
        cur[1]=n+1;
        for(int i=n;i>1;--i){
            for(auto &x:add[i]){
                un(x-1,x);
            }
            // cout<<"i:"<<i<<endl;
            // for(int j=1;j<=n;++j){
            //     cout<<"cur["<<j<<"]:"<<cur[j]<<endl;
            // }
            for(int j=i;j<=n;j+=i){
                ans+=1ll*j*cur[j];
            }
        }
        cout<<ans+n<<endl;
    }
}sa;
int main(){
    sa.solve();
    return 0;
}

猜你喜欢

转载自blog.csdn.net/Code92007/article/details/128770248