JZOJ 5711. 【北大夏令营2018模拟5.13】时间幻象

Description

Description

Input

从文件 return.in 中读入数据。
输入文件的第1行2个整数n,m,表示时间线的长度和转移的个数。
后面m行每行2个空格隔开的字符串A? ,B? ,描述一组转移。

Output

输出到文件 return.out 中。
输出文件仅1行,1个非负整数,表示答案。

Sample Input

2 3
A B
A C
D D

Sample Output

5
解释:一种最优方案:AA → AB → BA → BC → CB。

Data Constraint

Data Constraint

Solution

  • 我们发现其实所谓字符串转换只跟每个字母的个数有关,因为相邻两个字母可以交换。

  • 首先我们要意识到一点:状态数和方案数都不会很多!!!

  • 状态数只有: C 30 + 4 1 4 1 = C 33 3 = 5456 (用挡板问题解决)。

  • 而答案也不会很多, C 30 15 也没多大(字母数总和最多才30),并不会爆 long long 。

  • 对于一个状态(设其A-D的个数分别是 p 1 , p 2 , p 3 , n p 1 p 2 p 3 ),

  • 那么它可以转换成的不同排列个数就是:

    C n p 1 C n p 1 p 2 C n p 1 p 2 p 3 C n p 1 p 2 p 3 n p 1 p 2 p 3

  • 我们将这个值设为这个状态的权值。

  • 接着我们枚举那 m 中转移方式,将每种状态转移出去并连边(反正状态数又不多)。

  • 做一遍 tarjan 缩环,一个强连通分量的权值就是其中所有点权之和。

  • 这就变成一个点带权的DAG了,我们需要求一条最长路径,

  • 那么直接像拓扑排序一样DP一遍就好了。

  • 时间复杂度约为 O ( C 33 3 m )

Code

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long LL;
const int N=32,M=6000;
struct data
{
    int x,y;
}edge[M*N];
int n,m,tot,top,num,cnt;
LL ans;
int first[M],nex[M*N],en[M*N];
int a[N<<1][5],b[N<<1][5],c[5],d[5];
int name[N][N][N],q[M],deg[M];
int dfn[M],low[M],st[M],col[M];
bool bz[M];
LL val[M],val1[M],g[N][N],f[M];
char s[N],t[N];
inline int min(int x,int y)
{
    return x<y?x:y;
}
inline int get(int *pos)
{
    return name[pos[1]][pos[2]][pos[3]];
}
inline LL calc(int *pos)
{
    return g[n][pos[1]]*g[n-pos[1]][pos[2]]*g[n-pos[1]-pos[2]][pos[3]];
}
inline void insert(int x,int y)
{
    nex[++tot]=first[x];
    first[x]=tot;
    en[tot]=y;
}
void ergodic(int x,int y)
{
    if(!y)
    {
        name[c[1]][c[2]][c[3]]=++tot;
        return;
    }
    if(x>4) return;
    for(int i=0;i<=y;i++)
    {
        c[x]=i;
        ergodic(x+1,y-i);
        c[x]=0;
    }
}
void dfs(int x,int y)
{
    if(!y)
    {
        int now=get(c);
        val[now]=calc(c);
        for(int i=1;i<=m;i++)
        {
            for(int j=1;j<=4;j++) d[j]=c[j]-a[i][j];
            bool pd=true;
            for(int j=1;j<=4;j++)
                if(d[j]<0)
                {
                    pd=false;
                    break;
                }
            if(!pd) continue;
            for(int j=1;j<=4;j++) d[j]+=b[i][j];
            int to=get(d);
            if(now==to) continue;
            edge[++cnt]=(data){now,to};
            insert(now,to);
        }
        return;
    }
    if(x>4) return;
    for(int i=0;i<=y;i++)
    {
        c[x]=i;
        dfs(x+1,y-i);
        c[x]=0;
    }
}
void tarjan(int x)
{
    dfn[x]=low[x]=++tot;
    bz[st[++top]=x]=true;
    for(int i=first[x];i;i=nex[i])
        if(!dfn[en[i]])
        {
            tarjan(en[i]);
            low[x]=min(low[x],low[en[i]]);
        }else
            if(bz[en[i]]) low[x]=min(low[x],dfn[en[i]]);
    if(dfn[x]==low[x])
    {
        num++;
        do
        {
            col[st[top]]=num;
            bz[st[top--]]=false;
        }while(st[top+1]^x);
    }
}
void work()
{
    int l=0,r=0;
    for(int i=1;i<=num;i++)
        if(!deg[i])
        {
            q[++r]=i;
            f[i]=val1[i];
            if(f[i]>ans) ans=f[i];
        }
    while(l<r)
    {
        int x=q[++l];
        for(int i=first[x];i;i=nex[i])
            if(f[x]+val1[en[i]]>f[en[i]])
            {
                f[en[i]]=f[x]+val1[en[i]];
                if(f[en[i]]>ans) ans=f[en[i]];
                q[++r]=en[i];
            }
    }
}
int main()
{
    freopen("return.in","r",stdin);
    freopen("return.out","w",stdout);
    scanf("%d%d",&n,&m);
    for(int i=1;i<=m;i++)
    {
        scanf("%s %s",s+1,t+1);
        int len=strlen(s+1);
        for(int j=1;j<=len;j++)
        {
            a[i][s[j]-'A'+1]++;
            b[i][t[j]-'A'+1]++;
        }
    }
    for(int i=0;i<=n;i++) g[i][0]=1;
    for(int i=1;i<=n;i++)
        for(int j=1;j<=i;j++) g[i][j]=g[i-1][j]+g[i-1][j-1];
    ergodic(1,n);
    int node=tot;
    tot=0;
    dfs(1,n);
    tot=0;
    for(int i=1;i<=node;i++)
        if(!dfn[i]) tarjan(i);
    for(int i=1;i<=node;i++) val1[col[i]]+=val[i];
    memset(first,tot=0,sizeof(first));
    for(int i=1;i<=cnt;i++)
        if(col[edge[i].x]^col[edge[i].y])
        {
            insert(col[edge[i].x],col[edge[i].y]);
            deg[col[edge[i].y]]++;
        }
    work();
    printf("%lld",ans);
    return 0;
}

猜你喜欢

转载自blog.csdn.net/liyizhixl/article/details/80316372
今日推荐