3473 字符串 - 后缀自动机 - 线段树合并

往上题解很多复杂度其实是和串总长有关的,如果给你一颗Trie就gg了。SAM上一个点的答案是其parent树中所有点的答案的并集的结论是不对的,因为有可能当前这个节点可以表示一条从根出发到某个点的路径,这条路径没有后继,但是一个节点对应的这样的路径显然不超过一个,因此其线段树初始化为相应区间即可,最后跑一边线段树合并即可,线段树合并的时候遇到一个满的区间就return掉。另,这个题求得是重复子串算多次,所以没有问题,否则还要写一个很麻烦的树剖来保证复杂度是和Trie的大小而不是字符串总长有关的复杂度。虽然BB了这么多但是这个题字符串总长能过,以及我并不知道这样线段树复杂度是什么,只知道不会比两个log差,不会比一个log优,跑的还算快。

#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
#include<queue>
#include<vector>
#define is_full(x) (t[x].len==t[x].v)
#define lint long long
#define N 100010
#define LEN 100010
#define TRIE_SIZE LEN
#define SAM_SIZE TRIE_SIZE*2
#define SIG 26
#define debug(x) cerr<<#x<<"="<<x
#define sp <<" "
#define ln <<endl
using namespace std;
int pos[N],dfsc;
struct edges{
    int to,pre;
}e[SAM_SIZE];int h[SAM_SIZE],etop,val[SAM_SIZE];
inline int add_edge(int u,int v)
{   return e[++etop].to=v,e[etop].pre=h[u],h[u]=etop;   }
struct segment{
    int len,v,ch[2];
    segment(int _len=0,int _v=0,int _ch0=0,int _ch1=0)
    {   len=_len,v=_v,ch[0]=_ch0,ch[1]=_ch1;    }
    inline segment operator=(const segment &s)
    {   return len=s.len,v=s.v,ch[0]=s.ch[0],ch[1]=s.ch[1],*this;   }
};vector<segment> seg;int s[SAM_SIZE],scnt;
inline int push_up(int x)
{   return seg[x].v=seg[seg[x].ch[0]].v+seg[seg[x].ch[1]].v;    }
inline int new_segment(int len)
{   return seg.push_back(segment(len)),(int)seg.size()-1;   }
int build(int l,int r,int s,int t)
{
    int x=new_segment(r-l+1),mid=(l+r)>>1,p;
    if(s<=l&&r<=t) return seg[x].v=r-l+1,x;
    if(s<=mid) p=build(l,mid,s,t),seg[x].ch[0]=p;
    if(mid<t) p=build(mid+1,r,s,t),seg[x].ch[1]=p;
    return push_up(x),x;
}
#define t seg
int merge_seg(int x,int y,int l,int r)
{
    if(!x||!y) return x+y;if(is_full(x)) return x;if(is_full(y)) return y;
    int z=new_segment(t[x].len),a=0,b=0,mid=(l+r)>>1,lc=0,rc=0;
    ((!t[x].ch[0]||!t[y].ch[0])?a=t[x].ch[0]+t[y].ch[0],lc=1:0);
    ((!t[x].ch[1]||!t[y].ch[1])?b=t[x].ch[1]+t[y].ch[1],rc=1:0);
    ((!lc&&is_full(t[x].ch[0]))?a=t[x].ch[0],lc=1:0),((!lc&&is_full(t[y].ch[0]))?a=t[y].ch[0],lc=1:0);
    ((!rc&&is_full(t[x].ch[1]))?b=t[x].ch[1],rc=1:0),((!rc&&is_full(t[y].ch[1]))?b=t[y].ch[1],rc=1:0);
    (lc?0:a=merge_seg(t[x].ch[0],t[y].ch[0],l,mid),lc=1),(rc?0:b=merge_seg(t[x].ch[1],t[y].ch[1],mid+1,r),rc=1);
    return (lc?seg[z].ch[0]=a:0),(rc?seg[z].ch[1]=b:0),push_up(z),z;
}
#undef t
struct SAM{
    int v,ch[SIG],fa;
}t[SAM_SIZE];int sam_rt,sam_cnt;
inline int new_sam_node(int v)
{   return t[++sam_cnt].v=v,sam_cnt;    }
inline int extend(int w,int p,int l,int r)//return np
{
    int np=new_sam_node(t[p].v+1);
    while(p&&!t[p].ch[w]) t[p].ch[w]=np,p=t[p].fa;
    if(!p) t[np].fa=sam_rt;
    else{
        int q=t[p].ch[w];
        if(t[q].v==t[p].v+1) t[np].fa=q;
        else{
            int nq=new_sam_node(t[p].v+1);
            memcpy(t[nq].ch,t[q].ch,sizeof(t[q].ch));
            t[nq].fa=t[q].fa,t[q].fa=t[np].fa=nq;
            while(p&&t[p].ch[w]==q) t[p].ch[w]=nq,p=t[p].fa;
        }
    }
//  debug(np)sp,debug(l)sp,debug(r)ln;
    return s[np]=build(1,dfsc,l,r),np;
}
inline int get_parent()
{
    for(int i=1;i<=sam_cnt;i++)
        if(t[i].fa) add_edge(t[i].fa,i);
    return 0;
}
int merge_all(int x,int k)
{
    for(int i=h[x],y;i;i=e[i].pre)
        merge_all(y=e[i].to,k),s[x]=merge_seg(s[x],s[y],1,dfsc);
    return val[x]=(seg[s[x]].v>=k)*(t[x].v-t[t[x].fa].v);
}
int get_val(int x=sam_rt,int f=0)
{
    val[x]+=val[f];
    for(int i=h[x];i;i=e[i].pre) get_val(e[i].to,x);
    return 0;
}
struct Trie{
    int ch[SIG],wc;
}trie[TRIE_SIZE];lint ans[TRIE_SIZE];
int trie_rt,las[TRIE_SIZE],trie_cnt;
int L[TRIE_SIZE],R[TRIE_SIZE];queue<int> q;
#define t trie
inline int insert_trie(char *s,int n,int x=trie_rt)
{
    for(int i=1;i<=n;i++)
    {
        int c=s[i]-'a';
        if(!t[x].ch[c]) t[x].ch[c]=++trie_cnt;
        x=t[x].ch[c];
    }
    return ++t[x].wc,x;
}
int dfs_trie(int x=trie_rt)
{
    L[x]=dfsc+1,dfsc+=t[x].wc;
    for(int i=0,y;i<SIG;i++)
        if((y=t[x].ch[i])) dfs_trie(y);
    return R[x]=dfsc;
}
inline int get_sam()
{
    for(q.push(trie_rt);!q.empty();q.pop())
        for(int i=0,c,x=q.front();i<SIG;i++)
            if((c=t[x].ch[i])) las[c]=extend(i,las[x],L[c],R[c]),q.push(c);
    return 0;
}
int get_ans(int x=trie_rt,int f=0)
{
    ans[x]=ans[f]+val[las[x]];
    for(int i=0,c;i<SIG;i++)
        if((c=t[x].ch[i])) get_ans(c,x);
    return 0;
}
#undef t
char str[LEN];
int main()
{
//  freopen("data.in","r",stdin);
//  freopen("std.out","w",stdout);
    int n,k;seg.push_back(segment()),las[trie_rt=++trie_cnt]=sam_rt=new_sam_node(0),scanf("%d%d",&n,&k);
    for(int i=1;i<=n;i++) scanf("%s",str+1),pos[i]=insert_trie(str,(int)strlen(str+1));
    dfs_trie(),get_sam(),get_parent(),merge_all(sam_rt,k),get_val(),get_ans();
    for(int i=1;i<=n;i++) printf("%lld ",ans[pos[i]]);return 0;
}


猜你喜欢

转载自blog.csdn.net/mys_c_k/article/details/80114740