[POJ1741]Tree

题意:给一棵带边权的树,统计距离$\leq k$的点对数量

这个这么基础的东西居然鸽了那么久,我还是退役吧...

点分治用于统计满足某些性质的点对或路径,但实际上就是个大暴力

这题要求统计树上距离$\leq k$的点对数量,那么我们这样做:

对于当前节点$x$,遍历子树并算出$x$的所有后代到$x$的距离$dis_u$,将其排序后用双指针统计$dis_i+dis_j\leq k$的对数,记这种操作为$\text{calc}(x)$

显然这样会计算来自$x$的同一子树中的点对,而这些点对的树上路径并不经过$x$,所以我们还需要对$x$的所有儿子$u$,把答案减去$\text{calc}(u)$,这样我们就成功统计了过$x$的所有满足条件的点对

但我们还要统计不经过$x$的点对数量,那么直接对$x$的每个儿子递归做上述操作即可

容易发现,每次的计算量就是$siz_x$,而我们想要让这个总和最小,那么每次先找重心,再以重心为根进行后续操作即可

所以整的过程大概是这样:令$\text{solve}(x)$表示求解以$x$为根的子树内的答案

找到$x$子树内的重心$c$,答案$+=\text{calc}(c)$,把$c$设为在后续递归中不得访问,对$c$的每个儿子$u$,答案$-=\text{calc}(u)$,递归调用$\text{solve}(u)$,然后就没了

点分治统计答案的部分还有另一种写法,这种写法不需要去重,但要统计两个集合之间产生的贡献

同样是找重心$c$,初始时令$S=\varnothing$,对$c$的每个儿子$u$,统计$u$子树内的点和$S$中的点产生的贡献(对每个$u$子树内的点,找$S$中有多少个点可以和它组成符合条件的点对),然后令$S\gets S\cup\{u\}$,其他部分与第一种方法完 全 一 致,这种写法通常需要数据结构辅助查询,各有各的优缺点

总算填了一个坑?(大雾

#include<stdio.h>
#include<string.h>
#include<algorithm>
using namespace std;
const int inf=2147483647;
int h[10010],nex[20010],to[20010],w[20010],t[10010],siz[10010],M,k,ans;
bool v[10010];
void add(int a,int b,int c){
	M++;
	to[M]=b;
	w[M]=c;
	nex[M]=h[a];
	h[a]=M;
}
int n,mn,cn;
void dfs1(int fa,int x){
	n++;
	siz[x]=1;
	for(int i=h[x];i;i=nex[i]){
		if(!v[to[i]]&&to[i]!=fa){
			dfs1(x,to[i]);
			siz[x]+=siz[to[i]];
		}
	}
}
void dfs2(int fa,int x){
	int i,k;
	k=0;
	for(i=h[x];i;i=nex[i]){
		if(!v[to[i]]&&to[i]!=fa){
			dfs2(x,to[i]);
			k=max(k,siz[to[i]]);
		}
	}
	k=max(k,n-siz[x]);
	if(k<mn){
		mn=k;
		cn=x;
	}
}
void dfs3(int fa,int x,int d){
	t[++M]=d;
	for(int i=h[x];i;i=nex[i]){
		if(!v[to[i]]&&to[i]!=fa)dfs3(x,to[i],d+w[i]);
	}
}
int calc(int x,int dt){
	M=0;
	dfs3(0,x,dt);
	sort(t+1,t+M+1);
	int l=1,r=M,res=0;
	while(l<r){
		if(t[l]+t[r]<=k){
			res+=r-l;
			l++;
		}else
			r--;
	}
	return res;
}
void solve(int x){
	int i;
	n=0;
	dfs1(0,x);
	mn=inf;
	dfs2(0,x);
	x=cn;
	ans+=calc(x,0);
	v[x]=1;
	for(i=h[x];i;i=nex[i]){
		if(!v[to[i]]){
			ans-=calc(to[i],w[i]);
			solve(to[i]);
		}
	}
}
int main(){
	int n,i,a,b,c;
	scanf("%d%d",&n,&k);
	while(n|k){
		M=0;
		memset(h,0,sizeof(h));
		memset(v,0,sizeof(v));
		for(i=1;i<n;i++){
			scanf("%d%d%d",&a,&b,&c);
			add(a,b,c);
			add(b,a,c);
		}
		ans=0;
		solve(1);
		printf("%d\n",ans);
		scanf("%d%d",&n,&k);
	}
}

猜你喜欢

转载自www.cnblogs.com/jefflyy/p/8960666.html
今日推荐