2020牛客暑期多校训练营第四场Operating on the Tree

Operating on the Tree

原题请看这里

题目描述:

此问题是由问题G(Operating on a Graph )启发的。 因此,您需要阅读它的声明才能解决此问题。Operating on a Graph题目描述+题解
您将得到一棵具有 n n 个顶点的树。 假设 p p 是从 0 0 n 1 n-1 的排列。 我们定义函数 f p f(p) 如下:假设给定的树是问题G的输入图,而 p p 是输入运算符序列。 f p f(p) 是满足条件的操作数:执行第 i i 个操作时,至少有一个顶点属于 O i O_i 组。令 S S 为从 0 0 n 1 n-1 的所有可能排列的集合。 请计算( \sum p \in S f p f(p)   m o d \ mod 998244353 998244353

输入描述:

第一行包含一个整数 t t 1 1 \le t \le 500 500 表示测试用例的数量,每个测试包含两行。 第一行包含一个整数n,代表给定树中的顶点数 1 1 \le n n \le 2000 2000 。 第二行包含 n 1 n-1 个非负整数 a 1 a_1 a 2 a_2 \ldots a n 1 a_ {n-1} 。 它表示树的第 i i 个边缘连接顶点 i i a i a_i ( ( a i a_i < < i i ) ) 所有测试用例的 n n 之和不超过 2000 2000

输出描述:

对于每个测试,输出一行,其中包含一个表示答案的整数,范围是 [ 0 , 998244352 ] [0,998244352]

样例输入:

3
4
0 1 2
4
0 1 1
2
0

样例输出:

48
60
2

思路:

树形DP
根据题意,我们可以知道
1.没有两个好点是相邻的
2.每个坏点都至少与一个比他大的好点相邻
看到这里,聪明的你一定已经想到用什么方法解着道题了吧
那就是树形DP
dp数组开三维:dp[MAXN][3][MAXN]:
第一维表示当前节点,第二维表示成功/失败/尚未失败,第三维表示子儿子中有几个成功数
这样我们就可以分三类讨论:
树根是好点,坏点但尚未有比它大的好点相邻,坏点已有比它大的好点相邻
在此基础上我们又可以分三类讨论:
树根是好点,坏点但尚未有比它大的好点相邻,坏点已有比它大的好点相邻
这样就要分九种情况分类讨论 啊啊啊啊好烦!
大体就是这样的,具体细节详见代码注释

AC Code:

#include<bits/stdc++.h>
using namespace std;
const int MAXN=2e3+5;
const int mod=998244353;
vector<int> e[MAXN];
int comb[MAXN][MAXN],sz[MAXN],dp[MAXN][3][MAXN],dp1[MAXN][3][MAXN],tmp[3][MAXN],tmp1[3][MAXN];
void add(int &u,int v)
{
	u+=v;
	u-=(u>=mod?mod:0);
}
void dfs(int u)
{
    sz[u]=dp[u][0][0]=dp[u][2][0]=1;
    for(int vec=0,v;vec<e[u].size();vec++)
    {
		v=e[u][vec];
		dfs(v);
        for(int i=0;i<3;i++)
            for(int j=1;j<sz[v];j++)
            {
                add(dp[v][i][j],dp[v][i][j-1]);
                add(dp1[v][i][j],dp1[v][i][j-1]);
            }//更新dp的值为前缀和,便于后续计算。注:此时dp含义已经变化,dp[i][sta][j]变成了最多j个节点比i大时的情况 
        for(int i=0;i<sz[u];i++)//枚举v之前的子树中,比x大的方案数 
            for(int j=0;j<=sz[v];j++)
            {//枚举v子树中,比x大的方案数 
                int coe=1ll*comb[i+j][i]*comb[sz[u]-1-i+sz[v]-j][sz[v]-j]%mod;
                //恰好共有i+j个节点比u大,且其中j个节点属于v子树的方案数 
                //=比u大的i+j个点,有i个点是v之前的子树中的方案数 * 比x小的sz[u]-1-i+sz[v]-j个点中,有sz[v]-j个点在v子树中的方案。 
				for(int type1=0;type1<3;type1++)
                    for(int type2=0;type2<3;type2++)
                    {
                    	// v节点在u之前的情况,即比u大的i+j个节点中,最多有j-1个节点属于v节点的情况 
                        int cnt=j?dp[v][type2][j-1]:0;
                        //v子树中最多有j-1个点比v大的方案数 
                        int cnt1=j?dp1[v][type2][j-1]:0;
                        //v子树中,最多有j-1个点比v大时的v子树的贡献 
                        int coe1=1ll*coe*dp[u][type1][i]%mod*cnt%mod;
                        //v节点比u节大的方案数 
                        int base=coe*(1ll*dp[u][type1][i]*cnt1%mod+1ll*dp1[u][type1][i]*cnt%mod)%mod;
                        //v节点比u节点大时,u节点和v子树的贡献 
						if(!type1)
						{
                            if(type2==1)
                            {
                                add(tmp[0][i+j],coe1);
                                add(tmp1[0][i+j],base);
                            }//u好和v坏的状态更新到u好 
                        }
                        else if(type1==1)
                        {
                        	if(!type2||type2==1)
                        	{
                                add(tmp[1][i+j],coe1);
                                add(tmp1[1][i+j],base);
                            }//u坏和v好/坏的状态,更新到u坏 
                    	}
						else if(type1==2)
						{
                            if(!type2)
                            {
                                add(tmp[1][i+j],coe1);
                                add(tmp1[1][i+j],base);//u半坏和v好的状态,更新到u坏 
                            }
                            else if(type2==1)
                            {
                                add(tmp[2][i+j],coe1);
                                add(tmp1[2][i+j],base);
                            }//u半坏和v坏的状态,更新到u半坏 
                        }//v节点在u之后的情况,即比u大的i+j个节点中,至少有j个节点属于v节点的情况,与上一种情况对立 
						cnt=dp[v][type2][sz[v]-1]-cnt;
                        cnt+=cnt<0? mod:0;
                        //v子树中至少有j个节点比v大的方案数 
                        cnt1=dp1[v][type2][sz[v]-1]-cnt1;
                        cnt1+=cnt1<0? mod:0;
                        //v子树中,至少有j个点比v大时的v子树的贡献 
                        coe1=1ll*coe*dp[u][type1][i]%mod*cnt%mod;
                        //v节点比u节小的方案数 
                        base=coe*(1ll*dp[u][type1][i]*cnt1%mod+1ll*dp1[u][type1][i]*cnt%mod)%mod;
                        //v节点比u节点小时,u节点和v子树的贡献 
						if(!type1)
						{
                        	if(type2==1||type2==2)
                        	{
                                add(tmp[0][i+j],coe1);
                                add(tmp1[0][i+j],base);
                            }//u好和v坏/半坏,更新到u好 
						}
						else if(type1==1)
						{
                        	if(!type2||type2==1)
                        	{
                                add(tmp[1][i+j],coe1);
                                add(tmp1[1][i+j],base);
                            }//u坏和v好/坏,更新到u坏 
						}
						else if(type1==2)
						{
                        	if(!type2||type2==1)
                        	{
                        		//u半坏和v好/坏的状态,更新到u半坏。
                        		//因为v是在u之后,所以在u半坏是因为u的父亲在u之前导致的,u之后是允许v好的
                                add(tmp[2][i+j],coe1);
                                add(tmp1[2][i+j],base);
                            }
						}
                    }
            }
           sz[u]+=sz[v];
        for(int i=0;i<sz[u];i++)
            for(int j=0;j<3;j++)
            {
                dp[u][j][i]=tmp[j][i];
				dp1[u][j][i]=tmp1[j][i];
                tmp[j][i]=tmp1[j][i]=0;
            }//拷贝到dp上
    }
    for(int i=0;i<sz[u];i++)
        add(dp1[u][0][i],dp[u][0][i]);//加上u自己的贡献 
}
int n,t,ans;
int main()
{
    for(int i=0;i<MAXN;i++)
    {
        comb[i][0]=1;
        for(int j=1;j<=i;j++)
            comb[i][j]=(comb[i-1][j-1]+comb[i-1][j])%mod;//杨辉三角算组合数 
    }
    scanf("%d",&t);
    while(t--)
    {
		scanf("%d",&n);
    	for(int i=0;i<=n;i++)
    	{
        	e[i].clear();
        	memset(dp[i],0,sizeof(dp[i]));
        	memset(dp1[i],0,sizeof(dp1[i]));
    	}
    	for(int i=2,fa;i<=n;i++)
    	{
        	scanf("%d",&fa);
        	fa++;e[fa].push_back(i);
    	}
    	dfs(1);
    	ans=0;
   	 	for(int i=0;i<n;i++)
   	 	{
        	add(ans,dp1[1][0][i]);
        	add(ans,dp1[1][1][i]);
        	//取模加法,由于用减法替代除法取模,因此会算得更快 
    	}
    	printf("%d\n",ans);
	}
}

去注释的代码在这里!

猜你喜欢

转载自blog.csdn.net/s260127ljy/article/details/107546599
今日推荐