HDU - 6035,点分治,组合计数

Colorful Tree

https://vjudge.net/problem/938230/origin
There is a tree with n nodes, each of which has a type of color represented by an integer, where the color of node i is ci.

The path between each two different nodes is unique, of which we define the value as the number of different colors appearing in it.

Calculate the sum of values of all paths on the tree that has n(n−1)2 paths in total.

题意:给你一棵树有n个点,每个点有一种颜色(1 ~ n),一条路径的权值为这条路径上不同颜色的种数,问所有路径的权值和为多少
思路:首先点分治时我们要求经过一个重心root的所有路径的贡献,这时是一颗以重心为根节点的树,对于其中一颗子树,如果id1,id2两点颜色相同且id1是id2的祖先,那么对于经过id2的所有路径(且经过root),必定经过id1,那么只有id1的颜色会产生贡献,我们只要计算在子树中最先出现的每种颜色的贡献即可(即经过该点的路径数),如果id是在一颗子树最先出现某种颜色的节点,以它为根节点的子树的节点数为sz[id],点v是整个子树的根节点,val[id]为节点id的颜色,num[val[id]]初始值为重心root所在整个树的节点数,那么节点v的贡献为(num[val[id]] - sz[v])*sz[id](因为要经过重心,且要经过节点id),之后num[val[id]]的值要减去sz[id],因为当计算其他子树某种颜色的贡献时路径如果还经过id的子节点便会产生重复计算,注意如果两个节点颜色相同但不存在祖先关系两者贡献都得算(在这里wa了一发)

#include<bits/stdc++.h>
#define MAXN 200010
#define INF 0x3f3f3f3f
#define ll long long
using namespace std;
int head[MAXN],tot;
struct edge
{
    int v,nxt;
}edg[MAXN << 1];
inline void addedg(int u,int v)
{
    edg[tot].v = v;
    edg[tot].nxt = head[u];
    head[u] = tot++;
}
int n,mx,root,Size,sz[MAXN];
ll ans;
int num[MAXN],val[MAXN];//num记录每种颜色在正在处理的子树中还应有的次数
bool vis[MAXN];
inline void getroot(int u,int f)
{
    int v,mson = 0;
    sz[u] = 1;
    num[val[u]] = Size;
    for(int i = head[u];i != -1;i = edg[i].nxt)
    {
        v = edg[i].v;
        if(v == f || vis[v]) continue;
        getroot(v,u);
        sz[u] += sz[v];
        mson = max(mson,sz[v]);
    }
    mson = max(Size-sz[u],mson);
    if(mson < mx)
        mx = mson,root = u;
}
int color[MAXN],id[MAXN],cnt,viscolor[MAXN],num1[MAXN];//num1记录各子树除第一次出现与根节点相同颜色的节点到根节点的路径长度
inline void getdis(int u,int f)
{
    sz[u] = 1;
    bool flag = false;
    if(!viscolor[val[u]])
        viscolor[val[u]] = 1,flag = true,color[++cnt] = val[u],id[cnt] = u;
    int v;
    for(int i = head[u];i != -1;i = edg[i].nxt)
    {
        v = edg[i].v;
        if(v == f || vis[v]) continue;
        getdis(v,u);
        sz[u] += sz[v];
    }
    if(flag)
        viscolor[val[u]] = 0;
}
inline void solve(int u,int ssize)
{
    vis[u] = 1;
    int v;
    for(int i = head[u];i != -1;i = edg[i].nxt)
    {
        v = edg[i].v;
        if(vis[v]) continue;
        cnt = 0;
        getdis(v,v);
        num1[v] = sz[v];
        for(int i = 1;i <= cnt;++i)
        {
            int nn = id[i];
            ans += 1ll * sz[nn] * (num[color[i]]-sz[v]);
            if(color[i] == val[u])
                num1[v] -= sz[nn];
        }
        for(int i = 1;i <= cnt;++i)
            num[color[i]] -= sz[id[i]];
    }
    for(int i = head[u];i != -1;i = edg[i].nxt)
    {
        v = edg[i].v;
        if(vis[v]) continue;
        ans += 1ll*num1[v] * (num[val[u]] - num1[v]);
        num[val[u]] -= num1[v];
    }
    for(int i = head[u];i != -1;i = edg[i].nxt)
    {
        v = edg[i].v;
        if(vis[v]) continue;
        Size = sz[v];
        mx = INF;
        getroot(v,v);
        solve(root,Size);
    }
}
inline void init()
{
    tot = ans = 0,Size = n,mx = INF;
    memset(head,-1,sizeof(int)*(n+1));
    memset(vis,false,sizeof(bool)*(n+1));
}
int main()
{
    int t = 0;
    while(~scanf("%d",&n))
    {
        ++t;
        init();
        for(int i = 1;i <= n;++i)
            scanf("%d",&val[i]);
        int u,v;
        for(int i = 1;i < n;++i)
        {
            scanf("%d%d",&u,&v);
            addedg(u,v),addedg(v,u);
        }
        getroot(1,1);
        solve(root,Size);
        printf("Case #%d: %lld\n",t,ans);
    }
    return 0;
}
/*
7
6 3 3 1 1 1 2
2 1
3 1
4 1
5 4
6 5
7 5

 */
发布了50 篇原创文章 · 获赞 3 · 访问量 3096

猜你喜欢

转载自blog.csdn.net/xing_mo/article/details/104032521
今日推荐