POJ - 1741 Tree(点分治模板题)

题目链接:点击查看

题目大意:给出一棵 n 个节点的树,现在定义 dis( x , y ) 为点 x 和点 y 之间的路径长度,现在问 dis ( x , y ) <= k 的点对有多少

题目分析:点分治的模板题目,干货博客:https://www.cnblogs.com/PinkRabbit/p/8593080.html

自己写的时候写了一堆bug。。提示一下,如果是WA的话可能有点无从下手,但如果用的是链式前向星,还仍然 TLE 的话,大概率是重心的地方出现细节问题了,因为如果重心使用不当,会将时间复杂度退化为 n*n*logn 

代码:

#include<iostream>
#include<cstdio>
#include<string>
#include<ctime>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<stack>
#include<climits>
#include<queue>
#include<map>
#include<set>
#include<sstream>
#include<cassert>
#include<bitset>
using namespace std;
 
typedef long long LL;
 
typedef unsigned long long ull;
 
const int inf=0x3f3f3f3f;
 
const int N=1e4+100;

int ans,root,Tsize,wt[N],sz[N],path[N],tot,cnt,n,k;

int nt[N<<1],to[N<<1],w[N<<1],head[N];

bool vis[N];

void addedge(int u,int v,int val)
{
	to[cnt]=v;
	w[cnt]=val;
	nt[cnt]=head[u];
	head[u]=cnt++;
}

void get_root(int u,int fa)
{
	sz[u]=1;
	wt[u]=0;
	for(int i=head[u];i!=-1;i=nt[i])
	{
		if(to[i]==fa||vis[to[i]])
			continue;
		get_root(to[i],u);
		sz[u]+=sz[to[i]];
		wt[u]=max(wt[u],sz[to[i]]);
	}
	wt[u]=max(wt[u],Tsize-sz[u]);
	if(wt[root]>wt[u])
		root=u;
}

void get_path(int u,int fa,int deep)
{
	path[++tot]=deep;
	for(int i=head[u];i!=-1;i=nt[i])
	{
		if(to[i]==fa||vis[to[i]])
			continue;
		get_path(to[i],u,deep+w[i]);
	}
}

int calc(int u,int deep)
{
	tot=0;
	get_path(u,-1,deep);
	sort(path+1,path+1+tot);
	int ans=0,r=tot;
	for(int l=1;l<=tot;l++)
	{
		while(r&&path[l]+path[r]>k)
			r--;
		if(r<l)
			break;
		ans+=r-l;
	}
	return ans;
}

void solve(int u)
{
	ans+=calc(u,0);
	vis[u]=true;
	for(int i=head[u];i!=-1;i=nt[i])
	{
		if(vis[to[i]])
			continue;
		ans-=calc(to[i],w[i]);
		root=0;
		Tsize=sz[to[i]];
		get_root(to[i],-1);
		solve(root);
	}
}

void init(int n)
{
	Tsize=n;
	cnt=root=ans=0;
	wt[root]=inf;
	memset(head,-1,sizeof(int)*(n+5));
	memset(vis,false,n+5);
}

int main()
{
#ifndef ONLINE_JUDGE
//  freopen("data.in.txt","r",stdin);
//  freopen("data.out.txt","w",stdout);
#endif
//  ios::sync_with_stdio(false);
	while(scanf("%d%d",&n,&k)!=EOF&&n+k)
	{
		init(n);
		for(int i=1;i<n;i++)
		{
			int u,v,w;
			scanf("%d%d%d",&u,&v,&w);
			addedge(u,v,w);
			addedge(v,u,w);
		}
		get_root(1,-1);
		solve(root);
		printf("%d\n",ans);
	}











   return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_45458915/article/details/108183144