好咯,又是经典套路题(我太弱...)
又是差分!!!
我们要求一个点到所有黑点的距离,实际上就是求\[\sum Dis[x]+Dis[y]-2*Dis[Lca(x,y)]\]
主要就是\(Dis[Lca(x,y)]\)比较难求.
然后这里有一个经典套路,那就是对于一个修改,我们直接修改其到根这一条路径上的所有权值,即全部打一个加一标记。
然后询问的时候,也是直接查询到根这一条路径上所有边乘上其标记的和.
意会一下??
发现问题就solve了~
于是树链剖分上啊啊啊
#include <iostream>
#include <cstdio>
#include <cstring>
#define mx(a, b) ((a) = (a) > (b) ? (a) : (b))
#define F(i, a, b) for (int i = a; i <= b; i ++)
#define L x << 1
#define R x << 1 | 1
const int N = 1e5 + 10;
using namespace std;
int n, m, t, x, cnt, tmp;
int f[N], dfn[N], size[N], d[N], son[N], top[N], tov[N], nex[N], las[N], vis[N];
long long dis[N], tot, sum, ans;
struct node { long long ans, V, add; } tr[8 * N];
void link(int x, int y) { tov[++ tot] = y, nex[tot] = las[x], las[x] = tot; }
void DFS(int k) {
size[k] = 1;
for (int x = las[k], nmx = 0; x; x = nex[x])
dis[tov[x]] = dis[k] + d[tov[x]],
DFS(tov[x]),
size[k] += size[tov[x]], son[k] = size[tov[x]] > nmx ? tov[x] : son[k], mx(nmx, size[tov[x]]);
}
void GET(int k) {
dfn[k] = ++ cnt;
if (son[k])
top[son[k]] = top[k], GET(son[k]);
for (int x = las[k]; x; x = nex[x])
if (tov[x] ^ son[k]) top[tov[x]] = tov[x], GET(tov[x]);
}
void Modify(int x, int st, int en, int p, int t) {
if (st == en) { tr[x].V = t; return; } int m = st + en >> 1;
if (m >= p) Modify(L, st, m, p, t); else Modify(R, m + 1, en, p, t);
tr[x].V = tr[L].V + tr[R].V;
}
void Find(int x, int st, int en, int l, int r) {
if (l <= st && en <= r) { if (t == 1) tr[x].add ++, tr[x].ans += tr[x].V; else ans += tr[x].ans; return; } int m = st + en >> 1;
if (tr[x].add) tr[L].add += tr[x].add, tr[L].ans += tr[x].add * tr[L].V, tr[R].add += tr[x].add, tr[R].ans += tr[x].add * tr[R].V, tr[x].add = 0;
if (m >= l) Find(L, st, m, l, r);
if (m < r) Find(R, m + 1, en, l, r);
tr[x].ans = tr[L].ans + tr[R].ans;
}
int main() {
scanf("%d%d", &n, &m);
F(i, 1, n - 1) scanf("%d", &f[i]), link(f[i], i);
F(i, 1, n - 1) scanf("%d", &d[i]);
DFS(0);
GET(0);
tot = sum = 0;
F(i, 1, n - 1) Modify(1, 1, cnt, dfn[i], d[i]);
F(i, 1, m) {
scanf("%d%d", &t, &x), ans = 0, tmp = x;
if (t == 1) tot += vis[x] == 0, sum += dis[x] * (vis[x] == 0), vis[x] ++;
if (t == 1 && vis[x] > 1) continue;
while (x) {
Find(1, 1, cnt, dfn[top[x]], dfn[x]);
x = f[top[x]];
}
if (t == 2) printf("%lld\n", dis[tmp] * tot + sum - 2 * ans);
}
}