洛谷P2664 树上游戏——点分治

原题链接
被点分治虐的心态爆炸了

题解

发现直接统计路径上的颜色数量很难,考虑转化一下统计方式。对于某一种颜色\(c\),它对一个点的贡献为从这个点出发且包含这种颜色的路径条数。
于是我们先点分一下,然后分别统计经过分治中心的路径对根和对其他点的贡献就行了。
推荐一篇比较详细的题解
代码:

#include <bits/stdc++.h>

using namespace std;

#define N 100000
#define pb push_back
#define ll long long

int n, c[N + 5];
vector<int> G[N + 5];
int root, S, sz[N + 5], vis[N + 5], maxsz[N + 5], col[N + 5], w[N + 5];
ll cnt[N + 5], ans[N + 5], sum1, sum2;

void getRoot(int u, int pa) {
    sz[u] = 1; maxsz[u] = 0;
    for (auto v : G[u]) {
        if (v == pa || vis[v]) continue;
        getRoot(v, u);
        sz[u] += sz[v];
        maxsz[u] = max(maxsz[u], sz[v]);
    }
    maxsz[u] = max(maxsz[u], S - sz[u]);
    if (!root || maxsz[u] < maxsz[root]) root = u;
}

void dfs0(int u, int pa) {
    sz[u] = 1;
    for (auto v : G[u]) {
        if (v == pa || vis[v]) continue;
        dfs0(v, u);
        sz[u] += sz[v];
    }
}

void dfs1(int u, int pa) { // 计算w数组
    col[c[u]]++;
    w[u] = 0;
    if (col[c[u]] == 1) w[u] = sz[u];
    sum1 += w[u], cnt[c[u]] += w[u];
    for (auto v : G[u]) {
        if (v == pa || vis[v]) continue;
        dfs1(v, u);
    }
    col[c[u]]--;
}

void dfs2(int u, int pa, int k) {
    cnt[c[u]] += k * w[u];
    sum1 += k * w[u];
    for (auto v : G[u]) {
        if (v == pa || vis[v]) continue;
        dfs2(v, u, k);
    }
}

void dfs3(int u, int pa) {
    col[c[u]]++;
    if (col[c[u]] == 1) sum2 += S - cnt[c[u]];
    ans[u] += sum1 + sum2;
    for (auto v : G[u]) {
        if (v == pa || vis[v]) continue;
        dfs3(v, u);
    }
    col[c[u]]--;
    if (col[c[u]] == 0) sum2 -= S - cnt[c[u]];
}

void clear(int u, int pa) {
    cnt[c[u]] = 0;
    for (auto v : G[u]) {
        if (v == pa || vis[v]) continue;
        clear(v, u);
    }
}

void calc(int u) {
    dfs0(u, 0);
    S = sz[u], sum1 = 0;
    for (auto v : G[u]) {
        if (vis[v]) continue;
        dfs1(v, u);
    }
    ans[u] += S + sum1 - cnt[c[u]];
    for (auto v : G[u]) {
        if (vis[v]) continue;
        dfs2(v, u, -1);
        S -= sz[v];
        col[c[u]]++;
        sum2 = S - cnt[c[u]];
        dfs3(v, u);
        col[c[u]]--;
        S += sz[v];
        dfs2(v, u, +1);
    }
    clear(u, 0);
}

void solve(int u) {
    vis[u] = 1;
    calc(u);
    for (auto v : G[u]) {
        if (vis[v]) continue;
        root = 0, S = sz[v], getRoot(v, u);
        solve(root);
    }
}

int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; ++i)
        scanf("%d", &c[i]);
    for (int i = 1, x, y; i < n; ++i) {
        scanf("%d%d", &x, &y);
        G[x].pb(y), G[y].pb(x);
    }
    root = 0, S = n, getRoot(1, 0);
    solve(root);
    for (int i = 1; i <= n; ++i) printf("%lld\n", ans[i]);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/dummyummy/p/11137061.html
今日推荐