[BZOJ 2061] Country(KMP+记忆化搜索)

[BZOJ 2061] Country(KMP+记忆化搜索)

题面

gaoxin神犇频繁的在发言中表现对伟大,光荣,正确的xx的热爱,我们可以做如下定义: A=伟大,光荣,正确的 B=xx C=引领我们向前 赞美祖国=ABC 拼命赞美祖国=赞美祖国10 gaoxin的发言=拼命赞美祖国100 显然这个定义必须是无环的。 WJMZBMR感到十分的有压力, 他好不容易数出了某个字串的出现次数。。。

某天他打开电视,发现某人的发言有同样的结构。。他很无语。。。想知道某些字串出现的次数。。 你能帮帮他吗?

一些定义: 为了简化期间,在输入中用英文表示.

字符串:由小写字母组成如a

字串名:一定是大写字母如A 那么上面的系统可以写成 A=greatglorycorrect B=xx C=leadusgo D=ABC E=DDDDdjh F=EEEEEgoodbye 同时存在一个母字串名,他就是某人的发言

分析

记忆化搜索,\(dp[i][j]\)表示当前处理到字符串\(i\)(还未与i匹配),上一次模板串匹配到\(j\)位时的匹配次数。

我们遍历i中的每个字符,如果是小写字母,就利用KMP匹配,并记录现在已经匹配到的位置x.如果是大写字母id,就递归dfs(id,x)。为了更新匹配位置,我们要额外维护一个数组\(pos[i][j]\)表示当前处理到字符串\(i\)(还未与i匹配),上一次模板串匹配到\(j\)位,匹配结束后模板串的位置。递归完成后把x赋值为\(pos[id][x]\). 同时更新\(dp[i][j]\). 遍历完每个字符后把\(pos[i][j]\)赋值为x

最终答案为\(dp[s][0]\) (这也是为什么我们定义j表示初始时的匹配位置,要不然我们不知道最后匹配到哪里)

代码

#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#include<algorithm>
#include<cctype>
#include<queue>
#define maxn 26
#define maxl 100000
#define mod 10000

using namespace std;
int n,s;


inline bool isAlpha(char c){
    return c>='A'&&c<='Z';
}
void get_nex(char *s,int n,int *nex) {
    nex[1]=0;
    for(int i=2,j=0; i<=n; i++) {
        while(j&&s[j+1]!=s[i]) j=nex[j];
        if(s[j+1]==s[i]) j++;
        nex[i]=j;
    }
}

int a[maxn+5][maxl+5];
int len[maxn+5];
char tp[maxn+5],lentp;
int nex[maxl+5];

int dp[maxn+5][maxl+5];
//dp[i][j]表示字符串i,当前还未匹配时末尾已经与模板串匹配了j位,
//串i里包含模板串的个数 
int pos[maxn+5][maxl+5];//pos[i][j]表示匹配结束后KMP指针的位置 
int dfs(int i,int j){
    if(dp[i][j]!=-1) return dp[i][j];
    dp[i][j]=0;
    int x=j;
    for(int k=1;k<=len[i];k++){
        if(isAlpha(a[i][k])){
            int id=a[i][k]-'A'+1;
            dp[i][j]+=dfs(id,x);
            dp[i][j]%=mod;
            x=pos[id][x]; 
        }else{
            int p=x;
            while(p&&tp[p+1]!=a[i][k]) p=nex[p];
            if(tp[p+1]==a[i][k]) p++;
            x=p;
            if(x==lentp){
                dp[i][j]++;
                dp[i][j]%=mod;
            } 
        }
    }
    pos[i][j]=x;
    return dp[i][j];
}

char in[maxl+5];
int main() {
    scanf("%d",&n);
    scanf("%s",in+1);
    s=in[1]-'A'+1;
    for(int i=1; i<=n; i++) {
        scanf("%s",in+1);
        int l=strlen(in+1);
        for(int j=3; j<=l; j++) a[in[1]-'A'+1][j-2]=in[j];
        len[i]=l-2;
    }
    scanf("%s",tp+1);
    lentp=strlen(tp+1); 
    get_nex(tp,lentp,nex);
    memset(dp,0xff,sizeof(dp));
    dfs(s,0); 
    printf("%d\n",dp[s][0]);
}

猜你喜欢

转载自www.cnblogs.com/birchtree/p/12172306.html