我们直接把可以的dis存在数组里,注意solve儿子时先剪掉贡献,具体看代码吧
#include<bits/stdc++.h>
#define N 40005
using namespace std;
int first[N],next[N],to[N],w[N],tot;
int n,m,siz,vis[N],ret,K[N];
int size[N],rt,Maxson[N],tmp[N];
int read(){
int cnt=0;char ch=0;
while(!isdigit(ch))ch=getchar();
while(isdigit(ch))cnt=cnt*10+(ch-'0'),ch=getchar();
return cnt;
}
void add(int x,int y,int z){
next[++tot]=first[x],first[x]=tot,to[tot]=y,w[tot]=z;
}
void Getroot(int u,int f){
size[u]=1;
for(int i=first[u];i;i=next[i]){
int t=to[i]; if(t==f||vis[t]) continue;
Getroot(t,u); size[u]+=size[t];
Maxson[u] = max(Maxson[u],size[t]);
}
Maxson[u] = max(Maxson[u] , siz-Maxson[u]);
if(Maxson[rt] > Maxson[u]) rt=u;
}
void Getdis(int u,int f,int dis){
tmp[++ret] = dis;
for(int i=first[u];i;i=next[i]){
int t=to[i]; if(t==f||vis[t]) continue;
Getdis(t,u,dis+w[i]);
}
}
void calc(int x,int pre_dis){
ret=0,Getdis(x,0,0);
for(int i=1;i<=ret;i++)
for(int j=i+1;j<=ret;j++)
if(!pre_dis) K[tmp[i]+tmp[j]]++;
else K[tmp[i]+tmp[j]+pre_dis]--;
}
void Solve(int x){
calc(x,0) , vis[x]=1;
for(int i=first[x];i;i=next[i]){
int t=to[i]; if(vis[t]) continue;
calc(t,w[i]*2) , rt=0 , siz=size[t];
Getroot(t,0) , Solve(rt);
}
}
int main(){
n=read(),m=read();
for(int i=1;i<n;i++){
int x=read(),y=read(),z=read();
add(x,y,z),add(y,x,z);
}
siz = Maxson[0] = n;
Getroot(1,0) , Solve(rt);
for(int i=1;i<=m;i++){
int x=read();
if(K[x]) printf("AYE\n");
else printf("NAY\n");
}
return 0;
}