hdu5405(dfs序+树链剖分+线段树)

题解:因为我要查询有多少\sum wiwj,i到j的路径上有与u到v路径上的公共点,那么我们可以先求没有经过u到v路径的上点的平方的总和然后再用所有点权值的总和的平方减去他,那么就是答案。然后我们怎么操作能?因为用树链剖分的话不会经过的点的对应的轻儿子的结点,所以我们记录一些对应轻儿子的权值平方的总和,然后我们如果某链往另外一条链上跳的话说明我这个点是那个点的轻链,我们再补回去就好了,然后还有一种情况是我重儿子没有经过那么我们要再减去这个重儿子子树权值总和的平方,择优我们查询就结束了

接着是更新,按到 上面查询我们需要维护的是子树权值总和,还有该点轻链对应的权值的平方的总和,那么我们子树权值总和用dfs序维护一下,然后对应该点轻链的为,我们需要计算一下更新这点轻链链头的子树权值总和的平方变化了多少然后更新一下,这样我们最多更新O(logn)次,

#include<iostream>
#include<cstring>
#include<algorithm>
#include<queue>
#include<vector>
#include<cstdio>
#include<cmath>
#include<set>
#include<map>
#include<cstdlib>
#include<ctime>
#include<stack>
#include<bitset>
using namespace std;
#define mes(a,b) memset(a,b,sizeof(a))
#define rep(i,a,b) for(int i = a; i <= b; i++)
#define dec(i,a,b) for(int i = b; i >= a; i--)
#define pb push_back
#define fi first
#define se second
#define ls rt<<1
#define rs rt<<1|1
#define lson ls,L,mid
#define rson rs,mid+1,R
#define lowbit(x) x&(-x)
typedef double db;
typedef long long int ll;
typedef pair<int,int> pii;
typedef unsigned long long ull;
const ll inf = 0x3f3f3f3f;
const int mx = 1e5+5;
const int mod = 1e9+7;
const int x_move[] = {1,-1,0,0,1,1,-1,-1};
const int y_move[] = {0,0,1,-1,1,-1,1,-1};
int n,m;
ll w[mx];
ll a[mx];
ll s[mx];
int pos[mx];
int sz[mx];
int son[mx];
int fa[mx];
int dep[mx];
int top[mx];
int id[mx];
int p[mx];
int l[mx],r[mx];
ll sum[mx<<2],SUM[mx<<2];
int dfn;
vector<int>g[mx];
void dfs(int u,int pre){
	dep[u] = dep[pre]+1;
	fa[u] = pre;
	son[u] = 0;
	sz[u] = 1;
	a[u] = 0;
	l[u] = ++dfn;
	p[dfn] = u;
	for(auto v: g[u]){
		if(v==pre)
			continue;
		dfs(v,u);
		sz[u] += sz[v];
		w[u] += w[v];
		if(sz[son[u]]<sz[v])
			son[u] = v;
	}
	w[u] %= mod;
	r[u] = dfn;
}
void DFS(int u,int fa){
	top[u] = fa;
	id[u] = ++dfn;
	pos[dfn] = u;
	if(son[u])
		DFS(son[u],fa);
	for(auto v: g[u]){
		if(top[v])
			continue;
		a[u] = a[u]+w[v]*w[v]%mod;
		DFS(v,v);
	}
	a[u] %= mod;
}
void built(int rt,int L,int R){
	if(L==R){
		sum[rt] = s[p[L]];
		SUM[rt] = a[pos[L]];
		return;
	}
	int mid = (L+R)/2;
	built(lson);
	built(rson);
	sum[rt] = sum[ls]+sum[rs];
	SUM[rt] = SUM[ls]+SUM[rs];
}
ll update1(int rt,int L,int R,int p,ll v){
	if(L==R){
		v = v-sum[rt];
		sum[rt] += v;
		return v;
	}
	int mid = (L+R)/2;
	ll x;
	if(p>mid) x = update1(rson,p,v);
	else x = update1(lson,p,v);
	sum[rt] = sum[ls]+sum[rs];
	return x;
}
void update2(int rt,int L,int R,int p,ll v){
	if(L==R){
		SUM[rt] = (SUM[rt]+v)%mod;
		return;
	}
	int mid = (L+R)/2;
	if(p>mid) update2(rson,p,v);
	else update2(lson,p,v);
	SUM[rt] = SUM[ls]+SUM[rs];
}
ll query1(int rt,int L,int R,int l,int r){
	if(L==l&&R==r)
		return sum[rt];
	int mid = (L+R)/2;
	if(l>mid) return query1(rson,l,r);
	else if(r<=mid) return query1(lson,l,r);
	else return query1(lson,l,mid)+query1(rson,mid+1,r);
}

void change(int a,ll b){
	ll x = update1(1,1,n,l[a],b);
	while(fa[top[a]]){
		a = top[a];
		ll sum = query1(1,1,n,l[a],r[a]);
		sum = x*(2*sum%mod-x)%mod;
		if(sum<0)
			sum = sum+mod;
		update2(1,1,n,id[fa[top[a]]],sum);
		a = fa[a];
	}	
}
ll query2(int rt,int L,int R,int l,int r){
	if(L==l&&R==r)
		return SUM[rt];
	int mid = (L+R)/2;
	if(l>mid)return query2(rson,l,r);
	else if(r<=mid) return query2(lson,l,r);
	else return query2(lson,l,mid)+query2(rson,mid+1,r);
}
ll get_ans(int a,int b){
	ll ans = sum[1];
	ll tmp;
	if(ans>mod)
		ans = ans%mod;
	ans = ans*ans%mod;
	while(top[a]!=top[b]){
		if(dep[top[a]]<dep[top[b]])
			swap(a,b);
		if(son[a]){
			tmp = query1(1,1,n,l[son[a]],r[son[a]]);
			tmp %= mod;
			ans = (ans-tmp*tmp%mod)%mod;
		}
		ans = (ans-query2(1,1,n,id[top[a]],id[a])%mod)%mod;
		a = top[a];
		tmp = query1(1,1,n,l[a],r[a]);
		tmp %= mod;
		ans = (ans+tmp*tmp%mod)%mod;
		a = fa[a];
	}
	if(id[a]>id[b])
		swap(a,b);
	if(son[b]){
		tmp = query1(1,1,n,l[son[b]],r[son[b]]);
		tmp %= mod;
		ans = (ans-tmp*tmp%mod)%mod;
	}
	ans = (ans-query2(1,1,n,id[a],id[b])%mod)%mod;
	tmp = sum[1] - query1(1,1,n,l[a],r[a]);
	tmp %= mod;
	ans = (ans-tmp*tmp%mod)%mod;
	if(ans<0)
		ans += mod;
	return ans;
}
int main(){
	//freopen("test.in","r",stdin);
	//freopen("test.out","w",stdout);
	int t,q,ca = 1;
	while(~scanf("%d%d",&n,&q)){
		dfn = 0;
		for(int i = 1; i <= n; i++)
			scanf("%lld",&w[i]),g[i].clear(),s[i] = w[i],top[i] = 0;
		for(int i = 2; i <= n; i++){
			int u,v;
			scanf("%d%d",&u,&v);
			g[u].pb(v);
			g[v].pb(u);
		}
		dfs(1,0);
		dfn = 0;
		DFS(1,1);
		built(1,1,n);
		while(q--){
			int a,b,c;
			scanf("%d%d%d",&a,&b,&c);
			if(a==1) change(b,c);
			else	printf("%lld\n",get_ans(b,c));
		}
	}
	return 0;
}

猜你喜欢

转载自blog.csdn.net/a1325136367/article/details/81094176
今日推荐