Luogu-3250 [BJOI2017]魔法咒语(AC自动机,矩阵快速幂)

Luogu-3250 [BJOI2017]魔法咒语(AC自动机,矩阵快速幂)

题目链接

题解:

多串匹配问题,很容易想到是AC自动机
先构建忌讳词语的AC自动机,构建时顺便记录一下这个点以及它的所有后缀有没有忌讳词语,即对于每个AC自动机上的结点\(x\),\(p[x].p|=p[p[x].f].p\)
然后前半部分分和后半是两道完全不同的题目(滑稽

前60分:

这些部分分的特征是\(L\le 100\)
直接AC自动机上\(dp\)就好了,枚举匹配长度\(i\),当前匹配到的点\(x\),以及后面要匹配的基本词汇\(s[j]\),找到匹配了这个串后到达的点\(y\),如果匹配过程中没有经过忌讳词语结点,就进行转移\(f[i+len[s[j]]][y]+=f[i][x]\)
最后统计下终点在每个点的情况就好了

inline void work(){
    memset(f,0,sizeof(f));
    f[0][0]=1;
    for(int i=0;i<L;i++)
        for(int x=0;x<=tot;x++){
            for(int j=1;j<=n;j++){
                int len=strlen(a[j]+1);
                if(i+len>L) continue;
                int y=run(x,a[j]);
                if(y==-1) continue;
                (f[i+len][y]+=f[i][x])%=P;
            }
        }
    int Ans=0;
    for(int x=0;x<=tot;x++) (Ans+=f[L][x])%=P;
    printf("%d\n",Ans);
}

后40分:

这些测试点的特点就是基本词汇长度小于等于2
\(\sum s[i]\)这么小,\(L\)这么大不禁让人往矩阵快速幂上想
利用矩阵乘法进行转移,如果设被乘的矩阵\(S\)\(S[0][i]\)代表每个点是否可以在串长为0时匹配(很明显只有\(S[0][0]=1\))。那么如果我们构建乘它的矩阵为\(G\),根据矩阵乘法的运算法则:
\[ T[0][j]=\sum_{k=0}^n{S[0][k]*G[k][j]} \]
可以发现,如果\(G[x][y]\)代表的是走一步从\(x\)转移到\(y\)的方案数,得到的矩阵\(T\)中的元素\(T[0][x]\)就是走一步转移到\(x\)的方案数,也就是相当于模拟走了一步。如果我们走了\(L\)步让\(S*G^L\)就好了,矩阵乘法有结合律,我们就可以先用快速幂算出\(G^L\),最后再乘\(S\)

但是这样做只能处理串长为\(1\)的情况,因为没有保留前一步的方案数,串长为\(2\)的转移是无法处理的。也就是说,设当前是第\(k\)步,转移应该包括两部分:一部分是从\(S_k\)走一步到\(S_{k+1}\),一部分是从\(S_{k-1}\)走两步到\(S_{k+1}\)

这种需要同时考虑这一步和上一步的矩阵要如何转移呢?

举一个非常简单的例子吧:求斐波那契数列的第\(n\)项。
由于\(Fib[i+1]=Fib[i-1]+Fib[i]\),我们就用了一个大小为\(1*2\)的矩阵\(S\)\(S[0][0]\)记录\(Fib[i-1]\)\(S[0][1]\)记录\(Fib[i]\)
转移矩阵\(G\)是这样的(矩阵下标从零开始):
\[ G= \left[ \begin{matrix} 0&1\\ 1&1 \end{matrix} \right] \]
联系\(S[0][0]与S[0][1]\)的含义和矩阵乘法的含义:
\(G[1][0]=1\)代表下一步的\(Fib[i-1]\)就是这一步的\(Fib[i]\)
\(G[1][1]=1\)代表下一步的\(Fib[i]\)能由这一步的\(Fib[i]\)转移
\(G[0][1]=1\)代表下一步的\(Fib[i]\)能由这一步的\(Fib[i-1]\)转移
\(G[0][0]=0\)是因为下一步的\(Fib[i-1]\)已经钦定为这一步的\(Fib[i]\),无需其他转移
总的来说,转移矩阵包括四部分:这一个到下一个(走一步),前一个到下一个(走两步),这一个到“下一个的前一个”(单位矩阵),前一个到下一个的前一个(一般为全\(0\)

了解了矩阵乘法求\(Fib[n]\)的原理之后,这道题就很简单了

构造大小为\(2*n\)\(S\)\(S[0][0\sim n]\)记前一步的方案数,\(S[0][n+1\sim n+1+n]\)记这一步的方案数
构造大小为\(2n*2n\)\(G\),右下是长为\(1\)串的转移矩阵,右上是长为\(2\)串的转移矩阵,左下是单位矩阵,左上是全\(0\)矩阵
最后得到的矩阵\(T=S*G^L\)中,\(T[0][n+1\sim n+1+n]\)的和即为答案。
代码:

  • 处理长度分别为\(1,2\)的串的转移
for(int x=0;x<=tot;x++){
        for(int i=1;i<=n;i++){
            if(strlen(a[i]+1)>1) continue;
            int y=run(x,a[i]);
            if(y==-1) continue;
            g.a[tot+1+x][tot+1+y]++;
        }
    }
    for(int x=0;x<=tot;x++){
        for(int i=1;i<=n;i++){
            if(strlen(a[i]+1)<2) continue;
            int y=run(x,a[i]);
            if(y==-1) continue;
            g.a[x][tot+1+y]++;
        }
    }
  • 左下角单位矩阵
    for(int x=0;x<=tot;x++)
        g.a[tot+1+x][x]++;
  • 矩阵快速幂
    g=poww(g,L);
    int Ans=0;
    for(int x=0;x<=tot;x++)
        (Ans+=g.a[tot+1][tot+1+x])%=P;
    printf("%d\n",(Ans+P)%P);

代码:

#include<map>
#include<set>
#include<cmath>
#include<ctime>
#include<queue>
#include<stack>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define qmax(x,y) (x=max(x,y))
#define qmin(x,y) (x=min(x,y))
#define mp(x,y) make_pair(x,y)
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
inline int read(){
    int ans=0,fh=1;
    char ch=getchar();
    while(ch<'0'||ch>'9'){
        if(ch=='-') fh=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9')
        ans=ans*10+ch-'0',ch=getchar();
    return ans*fh;
}
const int maxn=150,P=1e9+7;
struct node{
    int son[26],p,f;
}p[maxn];
struct matrix{
    int a[210][210];
}g,base,tmp;
int tot,n,m,L,f[200][maxn];
char a[60][maxn],b[60][maxn];
matrix operator * (matrix x,matrix y){
    for(int i=0;i<=tot+1+tot;i++)
        for(int j=0;j<=tot+1+tot;j++){
            tmp.a[i][j]=0;
            for(int k=0;k<=tot+1+tot;k++)
                (tmp.a[i][j]+=1ll*x.a[i][k]*y.a[k][j]%P)%=P;
        }
    return tmp;
}
inline void insert(char *s){
    int len=strlen(s+1),x=0;
    for(int i=1;i<=len;i++){
        int z=s[i]-'a';
        if(!p[x].son[z])
            p[x].son[z]=++tot;
        x=p[x].son[z];
    }
    p[x].p=1;
}
queue<int>q;
inline void build(){
    for(int i=0;i<26;i++)
        if(p[0].son[i]) q.push(p[0].son[i]);
    while(!q.empty()){
        int x=q.front();q.pop();
        for(int i=0;i<26;i++)
            if(p[x].son[i]){
                p[p[x].son[i]].f=p[p[x].f].son[i];
                q.push(p[x].son[i]);
            }
            else p[x].son[i]=p[p[x].f].son[i];
        p[x].p|=p[p[x].f].p;
    }
}
inline int run(int x,char *s){
    int len=strlen(s+1);
    for(int i=1;i<=len;i++){
        int z=s[i]-'a';
        x=p[x].son[z];
        if(p[x].p) return -1;
    }
    return x;
}
inline void work(){
    memset(f,0,sizeof(f));
    f[0][0]=1;
    for(int i=0;i<L;i++)
        for(int x=0;x<=tot;x++){
            for(int j=1;j<=n;j++){
                int len=strlen(a[j]+1);
                if(i+len>L) continue;
                int y=run(x,a[j]);
                if(y==-1) continue;
                (f[i+len][y]+=f[i][x])%=P;
            }
        }
    int Ans=0;
    for(int x=0;x<=tot;x++) (Ans+=f[L][x])%=P;
    printf("%d\n",Ans);
}
inline matrix poww(matrix x,int y){
    for(int i=0;i<=tot+1+tot;i++)
        base.a[i][i]=1;
    while(y){
        if(y&1) base=base*x;
        x=x*x,y>>=1;
    }
    return base;
}
inline void work2(){
    for(int x=0;x<=tot;x++){
        for(int i=1;i<=n;i++){
            if(strlen(a[i]+1)>1) continue;
            int y=run(x,a[i]);
            if(y==-1) continue;
            g.a[tot+1+x][tot+1+y]++;
        }
    }
    for(int x=0;x<=tot;x++){
        for(int i=1;i<=n;i++){
            if(strlen(a[i]+1)<2) continue;
            int y=run(x,a[i]);
            if(y==-1) continue;
            g.a[x][tot+1+y]++;
        }
    }
    for(int x=0;x<=tot;x++)
        g.a[tot+1+x][x]++;
    g=poww(g,L);
    int Ans=0;
    for(int x=0;x<=tot;x++)
        (Ans+=g.a[tot+1][tot+1+x])%=P;
    printf("%d\n",(Ans+P)%P);
}
int main(){
//  freopen("nh.in","r",stdin);
//  freopen("zhy.out","w",stdout);
    n=read(),m=read(),L=read();
    for(int i=1;i<=n;i++) scanf("%s",a[i]+1);
    for(int i=1;i<=m;i++) scanf("%s",b[i]+1);
    for(int i=1;i<=m;i++) insert(b[i]);
    build();
    if(L<=100) work();
    else work2();
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/nianheng/p/10504158.html
今日推荐