Link
题意:
求树上长度不超过 \(k\) 的路径条数
思路:
点分治
对于根节点,树上路径的情况可分为经过根节点和不经过根节点两种情况,而经过根节点的路径又分为 \(x\) 到根节点和根节点到 \(y\) 两段
记 \(dist\) 数组记录的是节点到根节点的距离,并且我们将所有值全部记录在数组 \(a\),并按值从小到大排序
用两个指针 \(l,r\) 分别从前和从后开始扫描,我们可以发现当l向后扫描时,\(r\)一定是向前移动的,因此当 \(a[l]+a[r]<=k\) 时,\(ans+=r-l\)
注意:单条路径也是合法路径,因此开始时先将 \(0\) 记录在数组 \(a\) 里
扫描完成后我们将不合法的路径(路径的两个端点都属于根节点的同一子树,这两段路径有重复部分)也加在 \(ans\) 里了
我们可以利用容斥思想,对于每个子树,计算出不合法的路径条数:\(ans-=calc(x,d[x][p])(x \in son(p))\)
总而言之,整个点分治的过程为:
\(1.\) 以点 \(p\) 为根节点跟更新子树中的节点到点 \(p\) 的距离(\(p\)为重心)
\(2.\) 计算长度小于等于 \(k\) 的路径条数
\(3.\) 删除点 \(p\)
\(4.\) 对点 \(p\) 的所有子树执行 \(1 \sim 3\)
选取重心作为根节点,可使整个点分治的递归层数至多为 \(logn\) 层
代码:
#include<iostream>
#include<algorithm>
#include<cstring>
using namespace std;
const int N=10010;
const int M=20010;
int n,k;
int cnt,to[M],val[M],nxt[M],head[N];
bool st[N];
int dist[N],sze[N],all_node;
int a[N];
int cnt_dist;
int pos;
int res;
int ans;
void addedge(int u,int v,int w) {
cnt++;
to[cnt]=v;
val[cnt]=w;
nxt[cnt]=head[u];
head[u]=cnt;
}
void get_root(int u,int pre) {
sze[u]=1;
int max_part=1;
for(int i=head[u];i;i=nxt[i]) {
int v=to[i];
if(v==pre||st[v]) continue;
get_root(v,u);
sze[u]+=sze[v];
max_part=max(max_part,sze[v]);
}
max_part=max(max_part,all_node-sze[u]);
if(max_part<res) {
res=max_part;
pos=u;
}
}
void get_dist(int u,int pre) {
a[++cnt_dist]=dist[u];
for(int i=head[u];i;i=nxt[i]) {
int v=to[i];
if(v==pre||st[v]) continue;
dist[v]=dist[u]+val[i];
get_dist(v,u);
}
}
int calc(int u,int t) {
int sum=0;
dist[u]=t;
cnt_dist=0;
get_dist(u,0);
sort(a+1,a+1+cnt_dist);
int l=1,r=cnt_dist;
while(l<r) {
if(a[l]+a[r]<=k) sum+=r-l,l++;
else r--;
}
return sum;
}
void solve(int u) {
st[u]=true;
ans+=calc(u,0);
for(int i=head[u];i;i=nxt[i]) {
int v=to[i];
if(st[v]) continue;
ans-=calc(v,val[i]);
res=n;
all_node=sze[v];
get_root(v,0);
solve(pos);
}
}
int main() {
//freopen("in.txt","r",stdin);
ios::sync_with_stdio(false);
cin.tie(0);
while(cin>>n>>k&&n) {
cnt=0;
memset(head,0,sizeof head);
memset(st,false,sizeof st);
ans=0;
for(int i=1;i<n;i++) {
int u,v,w;
cin>>u>>v>>w;
addedge(u,v,w);
addedge(v,u,w);
}
res=n;
all_node=n;
get_root(1,0);
solve(pos);
cout<<ans<<endl;
}
return 0;
}