洛谷 P2664 树上游戏 点分治

题目链接

题意:一个有 n n 个节点的树,每个节点有相应的颜色,定义一个 s ( i j ) s(i,j) 表示从 i i j j 的路径上不同的颜色数量, s u m ( i ) = j = 1 n s ( i j ) sum(i) = \sum_{j=1}^{n}s(i,j) ,现在要求所有的 s u m sum

思路:这种求树上点对之间关系的,一般就是点分治了。那么问题就转化为如何求一棵树上一对经过根的点的 s ( i j ) s(i,j)

转化下,就是求一棵树上每种颜色从别的子树上对当前这个子树上的某个点的贡献值。

所以我们先统计下整棵树上所有颜色的贡献值 c o l o r color ,对于某个颜色,它的贡献为每条该树上的链上它第一次出现的节点 v v n u m num 值的和。( n u m [ v ] num[v] 为当前这个有根树以节点 v v 为根提取出来的子树的大小)

显然整个树的根所得的贡献即为该树所有颜色的 c o l o r color 的和。

现在讨论一棵子树上的点,由于是求其他子树上颜色对该点的贡献,所以要先遍历一遍该子树,把该子树产生的贡献在相应的 c o l o r color 里清除,在计算完该子树之后要记得恢复贡献。

我们记录从子树上一点到根的链上出现过的颜色的 c o l o r color 和,易得这些颜色对该点造成的贡献就不是相应的 c o l o r color 了,而是除去该子树外的所有点。

所以该点所得到的贡献应该 = = 总的 c o l o r color - 链上出现的 c o l o r color + + 已出现的颜色数量 * 除去该子树外点的数量

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cmath>
#include<set>
#include<string>
#include<cstring>
#include<vector>
using namespace std;
#define ffor(i,d,u) for(int i=(d);i<=(u);++i)
#define _ffor(i,u,d) for(int i=(u);i>=(d);--i)
#define NUM 100005
#define LL long long
int n;
int head[NUM] = {}, edge_num = 0;
struct edge
{
    int to, next;
} e[NUM << 1];
template <typename T>
void read(T& x)
{
    x=0;
    char c;T t=1;
    while(((c=getchar())<'0'||c>'9')&&c!='-');
    if(c=='-'){t=-1;c=getchar();}
    do(x*=10)+=(c-'0');while((c=getchar())>='0'&&c<='9');
    x*=t;
}
template <typename T>
void write(T x)
{
    int len=0;char c[21];
    if(x<0)putchar('-'),x*=(-1);
    do{++len;c[len]=(x%10)+'0';}while(x/=10);
    _ffor(i, len, 1) putchar(c[i]);
}

int tot, root, min_max_num, num[NUM];
bool vis[NUM] = {};
void get_root(int v, int pre)
{
    int x, max_num = 0;
    num[v] = 1;
    for (int i = head[v]; i; i = e[i].next)
    {
        x = e[i].to;
        if (x == pre || vis[x])
            continue;
        get_root(x, v);
        num[v] += num[x];
        max_num = max(max_num, num[x]);
    }
    max_num = max(max_num, tot - num[v]);
    if(max_num <= min_max_num)
        root = v, min_max_num = max_num;
}

int c[NUM], cnt[NUM] = {}, color_ss[NUM];
LL sum, color[NUM] = {};
void calc_color(int v, int pre)
{
    int x;
    num[v] = 1, ++cnt[c[v]];
    for (int i = head[v]; i; i = e[i].next)
    {
        x = e[i].to;
        if (x == pre || vis[x])
            continue;
        calc_color(x, v);
        num[v] += num[x];
    }
    if (cnt[c[v]] == 1)
        color[c[v]] += num[v], sum += num[v], color_ss[++color_ss[0]] = c[v];
    --cnt[c[v]];
}
void change(int v, int pre, LL flag)
{
    int x;
    ++cnt[c[v]];
    for (int i = head[v]; i; i = e[i].next)
    {
        x = e[i].to;
        if (x == pre || vis[x])
            continue;
        change(x, v, flag);
    }
    flag *= num[v];
    if (cnt[c[v]] == 1)
        sum += flag, color[c[v]] += flag;
    --cnt[c[v]];
}

LL ans[NUM] = {}, tree_size;
void calc(int v, int pre, LL YY)
{
    int x;
    ++cnt[c[v]];
    if (cnt[c[v]] == 1)
        YY = YY - color[c[v]] + tree_size;
    ans[v] += YY;
    for (int i = head[v]; i; i = e[i].next)
    {
        x = e[i].to;
        if (x == pre || vis[x])
            continue;
        calc(x, v, YY);
    }
    --cnt[c[v]];
}
void solve(int v)
{
    int x;
    sum = 0, vis[v] = true, calc_color(v, 0); //计算整棵子树中各种颜色的贡献,包括根的颜色
    ans[v] += sum;                            //所有颜色的贡献和即是对根的贡献
    sum -= color[c[v]], cnt[c[v]] = 1;        //不考虑根
    for (int i = head[v]; i; i = e[i].next)
    {
        x = e[i].to;
        if(vis[x])
            continue;
        tree_size = num[v] - num[x];
        change(x, v, -1);                    //清除当前子树对于颜色的贡献
        calc(x, v, sum + tree_size);      //计算当前子树中每个节点,当前树中所有颜色从其他子树对其的贡献
        change(x, v, 1);                     //恢复清除掉的贡献
    }
    cnt[c[v]] = 0;
    while(color_ss[0])
        color[color_ss[color_ss[0]]] = 0, --color_ss[0];
    for (int i = head[root]; i; i = e[i].next)
    {
        x = e[i].to;
        if (!vis[x])
        {
            min_max_num = tot = num[x], root = x;
            get_root(x, 0), solve(root);
        }
    }
}
inline void add(int x, int y)
{
    e[++edge_num] = edge{y, head[x]};
    head[x] = edge_num;
}
inline void AC()
{
    int x, y;
    read(n);
    ffor(i, 1, n) read(c[i]);
    ffor(i, 2, n)
    {
        read(x), read(y);
        add(x, y), add(y, x);
    }
    tot = min_max_num = n, root = color_ss[0] = 0;
    get_root(1, 0), solve(root);
    ffor(i, 1, n) write(ans[i]), putchar('\n');
}
int main()
{
    AC();
    return 0;
}

猜你喜欢

转载自blog.csdn.net/a302549450/article/details/87926538
今日推荐