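// Problem: https://nanti.jisuanke.com/t/38229
// Statement: given a tree with n vertices and weighted edges, answer queries
//   (u, v, k): how many edges on the path between u and v have weight <= k?
// Approach: coordinate-compress the edge weights, then use a persistent
//   (chairman) segment tree in which each vertex's version holds the edge
//   weights on the path from the root down to that vertex. Find the LCA with
//   binary lifting and answer cnt(u) + cnt(v) - 2 * cnt(lca) over the weights <= k.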
#include<bits/stdc++.h>
using namespace std;
// Array capacity for vertices / edges (100000 plus a small safety margin).
const int N = 100000 + 10;

// One input edge as read from stdin: endpoints x, y and the raw weight w.
// Edges are stored at indices 1..n-1.
struct node1
{
    int x;
    int y;
    int w;
} e[N];
// Adjacency-list entry: the neighbouring vertex and the (compressed) weight
// of the edge leading to it.
struct edge
{
    int to, w;

    edge() {}
    edge(int to_, int w_) : to(to_), w(w_) {}
};
struct node
{
int l,r;
int val;
}tree[N*22];
vector<edge> v[N];
int val[N],len;
int n,q,deep[N];
int root[N],cnt;
int dp[N][22];
// Builds a segment tree over the value range [l, r] by copying nodes from the
// version rooted at `pre`, and returns the index of the new root in `tree`.
// An empty range (l > r) happens when n == 1, because then there are no edge
// weights and len == 0. In that case return the never-allocated node 0 (an
// empty node) instead of recursing forever: with l = 1, r = 0 the midpoint
// stays 0 and build(1, 0) would call itself indefinitely.
int build(int l,int r,int pre)
{
    if(l>r) return 0;
    int cur=++cnt;
    tree[cur]=tree[pre];
    if(l==r) return cur;
    int mid=(l+r)>>1;
    tree[cur].l=build(l,mid,tree[pre].l);
    tree[cur].r=build(mid+1,r,tree[pre].r);
    return cur;
}
// Inserts one occurrence of value `pos` into the version rooted at `pre`,
// which covers the range [l, r], and returns the root of the new version.
// Only the nodes on the root-to-leaf path toward `pos` are copied; every other
// subtree is shared with `pre`.
int update(int pos,int l,int r,int pre)
{
    const int fresh=++cnt;
    tree[fresh]=tree[pre];
    ++tree[fresh].val;
    if(l!=r)
    {
        const int mid=(l+r)>>1;
        if(pos>mid)
            tree[fresh].r=update(pos,mid+1,r,tree[pre].r);
        else
            tree[fresh].l=update(pos,l,mid,tree[pre].l);
    }
    return fresh;
}
void dfs(int u,int fa)
{
deep[u]=deep[fa]+1;
dp[u][0]=fa;
for(int i=1;i<=20;i++)
{
if(dp[u][i-1])
dp[u][i]=dp[dp[u][i-1]][i-1];
else
break;
}
int w;
for(int i=0;i<v[u].size();i++)
{
int to=v[u][i].to;
if(to==fa) continue;
w=v[u][i].w;
// cout<<to<<" * "<<w<<endl;
root[to]=update(w,1,len,root[u]);
dfs(to,u);
}
}
// Returns the lowest common ancestor of x and y, using binary lifting over
// deep[] and dp[][].
int get_lca(int x,int y)
{
    if(deep[y]>deep[x]) swap(x,y);

    // Lift the deeper vertex x up to y's depth, one set bit of the gap at a time.
    const int gap=deep[x]-deep[y];
    for(int bit=0;bit<=20;++bit)
        if((gap>>bit)&1)
            x=dp[x][bit];

    if(x==y) return x;

    // Raise both vertices while their ancestors still differ. They end as
    // children of the LCA.
    for(int i=20;i>=0;--i)
        if(dp[x][i]!=dp[y][i])
        {
            x=dp[x][i];
            y=dp[y][i];
        }
    return dp[x][0];
}
// Counts the edge weights with rank in [l, min(r, val)] on the path between
// two vertices. x and y are their version roots and z is the version root of
// their LCA. Each vertex's version counts the weights from the tree root down
// to that vertex, so the path count is cnt(x) + cnt(y) - 2 * cnt(z).
// Note: the parameter `val` is the rank threshold; it hides the global val[].
int query(int l,int r,int x,int y,int z,int val)
{
    if(r<=val)
        return tree[x].val+tree[y].val-2*tree[z].val;

    const int mid=(l+r)>>1;
    const int lowPart=query(l,mid,tree[x].l,tree[y].l,tree[z].l,val);
    const int highPart= val>mid
        ? query(mid+1,r,tree[x].r,tree[y].r,tree[z].r,val)
        : 0;
    return lowPart+highPart;
}
// Resets the per-test-case state of vertices 1..n: clears the adjacency lists
// and zeroes the binary-lifting entries dp[u][0..20].
void init()
{
    for(int u=1;u<=n;++u)
    {
        v[u].clear();
        std::fill(dp[u],dp[u]+21,0);
    }
}
int main()
{
while(~scanf("%d%d",&n,&q))
{
init();
for(int i=1;i<n;i++)
{
scanf("%d%d%d",&e[i].x,&e[i].y,&e[i].w);
val[i]=e[i].w;
}
sort(val+1,val+n);
len=unique(val+1,val+n)-(val+1);
int x,y,w;
for(int i=1;i<n;i++)
{
x=e[i].x,y=e[i].y;
w=lower_bound(val+1,val+1+len,e[i].w)-val;
v[x].push_back(edge(y,w));
v[y].push_back(edge(x,w));
}
cnt=0;
root[0]=build(1,len,0);
dfs(1,0);
int fa,tmp,k;
while(q--)
{
scanf("%d%d%d",&x,&y,&k);
k=upper_bound(val+1,val+1+len,k)-(val);
k--;
if(k==0)
{
printf("0\n");
continue;
}
fa=get_lca(x,y);
printf("%d\n",query(1,len,root[x],root[y],root[fa],k));
}
}
return 0;
}