【Codeforces 736C】 Ostap and Tree【树形DP】

题意:

问对一棵树染色,初始无色,要求距每个点最近的染色点的距离不超过k

题解:

写的比较蠢,dp[i][j]记录第i个点关键色点深度为j的种类数,关键点:这棵树u如果已经满足了,j就是离u最近色点;如果不满足,j是最远色点。枚举子树,考虑容斥,设计叶子记录。。。。最大长和(2k+1)

有更好的大概https://blog.csdn.net/sjtsjt709/article/details/53428615

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<vector>

using namespace std;
long long dp[105][85];


int n,k;

vector<int>e[105];

long long mod=1e9+7;

void dfs(int u,int fa)
{
    int len=e[u].size();
    //cout<<"len"<<len<<endl;
    if(len==1 && fa==e[u][0])
    {
        dp[u][k+1]=1;
        dp[u][0]=1;
        //cout<<"leaf:"<<u<<endl;
    }
    else
    {
        for(int i=0;i<len;i++)
        {
            int v=e[u][i];
            if(v!=fa)
            {
                dfs(v,u);
            }
        }

        for(int i=0;i<=k;i++)
        {
            if(i==0)
            {
                dp[u][0]=1;
                for(int j=0;j<len;j++)
                {
                    int v=e[u][j];
                    if(v==fa)
                        continue;

                    long long sum=0;
                    for(int l=0;l<=2*k;l++)
                    {
                        sum=(sum+dp[v][l])%mod;
                    }
                    dp[u][0]=(dp[u][0]*sum)%mod;
                }
            }
            else
            {
                long long m1=1,m2=1;
                for(int j=0;j<len;j++)
                {

                    int v=e[u][j];
                    if(v==fa)
                        continue;
                    long long sum=0;
                    for(int l=i-1;l<=(2*k-1-(i-1));l++)
                    {
                        sum=(sum+dp[v][l])%mod;
                    }
                    m1=m1*sum%mod;
                    m2=m2*(sum-dp[v][i-1])%mod;
                }
                dp[u][i]=(m1-m2)%mod;
            }
        }

        for(int i=k+1;i<=(2*k);i++)
        {
            long long m1=1,m2=1;
            for(int j=0;j<len;j++)
            {
                int v=e[u][j];
                if(v==fa)
                    continue;

                long long sum=0;
                for(int l=(2*k-(i-1));l<=i-1;l++)
                {
                    sum=(sum+dp[v][l])%mod;
                }
                m1=m1*sum%mod;
                m2=m2*(sum-dp[v][i-1])%mod;
            }
            dp[u][i]=(m1-m2)%mod;
        }

    }

}

int main()
{
    while(~scanf("%d%d",&n,&k))
    {

          //memset(dp,sizeof(dp),0);
          for(int i=0;i<=100;i++)
          for(int j=0;j<=80;j++)
          dp[i][j]=0;

        //vector<int>e[105];

        for(int i=0;i<=100;i++)
           e[i].clear();
        for(int i=0;i<n-1;i++)
        {
            int u,v;
            scanf("%d%d",&u,&v);
            e[u].push_back(v);
            e[v].push_back(u);
        }

        dfs(1,-1);
        long long ans=0;
        for(int i=0;i<=k;i++)
        {
            ans=(ans+dp[1][i])%mod;
            //cout<<i<<" : "<<dp[1][i]<<endl;
        }

        //cout<<dp[2][0]<<"!!"<<dp[2][2]<<endl;
        ans=(ans+mod)%mod;
        printf("%lld\n",ans);
    }
}

猜你喜欢

转载自blog.csdn.net/c_czl/article/details/83451542
今日推荐