原题传送门
注意
可以在倍增的同时，暴力记录每个点往上跳 $2^i$ 步范围内（含起点、不含终点）前 10 小的编号集合；
倍增跳的时候再把这些集合暴力合并。
Code:
#include <bits/stdc++.h>
#define maxn 100010
using namespace std;
// One directed edge of the adjacency list built by addedge():
//   to   - vertex this edge points at,
//   next - index of the next edge leaving the same source (0 ends the list).
// Every tree edge is inserted in both directions, hence 2 * maxn slots.
struct Edge {
    int to;
    int next;
} edge[maxn * 2];
// Ascending list of ids stored 1-indexed in num[1..tot]; merge() keeps tot
// well below 15, so the fixed array is large enough.
//   ans        - accumulator for the query currently being answered,
//   node[u][k] - smallest ids attached to u and the 2^k - 1 vertices above it
//                (filled by build()).
// NOTE(review): with `using namespace std;` under C++17, std::data is also
// visible, so unqualified uses of the name `data` can be ambiguous; users
// may need to write `::data` (or the type could be renamed everywhere).
struct data {
    int num[15];
    int tot;
} ans, node[maxn][25];
int num;            // number of edges stored so far (last used index of edge[])
int head[maxn];     // head[x]: index of x's first edge, 0 if x has none
int n, m, Q;        // vertex count, number of ids to place, query count
int d[maxn];        // d[u]: depth of u (root is 1, set by build())
int fa[maxn][25];   // fa[u][k]: 2^k-th ancestor of u, 0 once past the root
// Reads one integer from stdin.
// Non-digit characters before the number are skipped; if any of them is '-',
// the result is negated (same leniency as before).
// Returns 0 if the input ends before a digit is seen, instead of looping
// forever on EOF.
inline int read(){
    int s = 0, w = 1;
    // int, not char: it must hold EOF, and isdigit() is only defined for
    // EOF and values representable as unsigned char.
    int c = getchar();
    for (; c != EOF && !isdigit(c); c = getchar()) if (c == '-') w = -1;
    for (; c != EOF && isdigit(c); c = getchar()) s = s * 10 + (c - '0');
    return s * w;
}
// Pushes the directed edge x -> y onto the front of x's adjacency list.
void addedge(int x, int y){
    ++num;
    edge[num].to = y;
    edge[num].next = head[x];
    head[x] = num;
}
void merge(data &a, data b){
data c;
int i = 1, j = 1;
c.tot = 0;
for (; c.tot <= 10; ++c.tot){
if (i <= a.tot && (j > b.tot || a.num[i] < b.num[j])) c.num[c.tot + 1] = a.num[i++];
else if (j <= b.tot) c.num[c.tot + 1] = b.num[j++]; else break;
}
a.tot = c.tot;
for (int i = 1; i <= c.tot; ++i) a.num[i] = c.num[i];
}
// Pre-order DFS from u (whose parent is pre) that fills, for each vertex:
//   d[u]       - depth (d[pre] + 1, so the root called with pre = 0 gets 1),
//   fa[u][k]   - 2^k-th ancestor, 0 once the jump would pass the root,
//   node[u][k] - smallest ids on u and the 2^k - 1 vertices above it
//                (cut short at the root).
// Ancestors are visited first, so their tables are complete when read here.
void build(int u, int pre){
    d[u] = d[pre] + 1;
    fa[u][0] = pre;
    int k = 0;
    while (fa[u][k] != 0){
        int mid = fa[u][k];  // 2^k-th ancestor: start of the upper half-range
        fa[u][k + 1] = fa[mid][k];
        merge(node[u][k + 1], node[u][k]);
        merge(node[u][k + 1], node[mid][k]);
        ++k;
    }
    for (int e = head[u]; e != 0; e = edge[e].next){
        int child = edge[e].to;
        if (child == pre) continue;
        build(child, u);
    }
}
int main(){
n = read(), m = read(), Q = read();
for (int i = 1; i < n; ++i){
int x = read(), y = read();
addedge(x, y), addedge(y, x);
}
for (int i = 1; i <= m; ++i){
int x = read();
data tmp;
tmp.tot = 1, tmp.num[1] = i;
merge(node[x][0], tmp);
}
/* for (int i = 1; i <= n; ++i){
printf("%d\n", node[i][0].tot);
for (int j = 1; j <= node[i][0].tot; ++j) printf("%d ", node[i][0].num[j]);
puts("m");
}*/
build(1, 0);
while (Q--){
int u = read(), v = read(), a = read();
if (d[u] < d[v]) swap(u, v);
ans.tot = 0;
for (int i = 20; i >= 0; --i)
if (d[u] - (1 << i) >= d[v]) merge(ans, node[u][i]), u = fa[u][i];
if (u != v){
for (int i = 20; i >= 0; --i)
if (fa[u][i] != fa[v][i])
merge(ans, node[u][i]), merge(ans, node[v][i]),
u = fa[u][i], v = fa[v][i];
merge(ans, node[u][0]), merge(ans, node[v][0]), u = fa[u][0], v = fa[v][0];
}
merge(ans, node[u][0]);
ans.tot = min(ans.tot, a);
printf("%d ", ans.tot);
for (int i = 1; i <= ans.tot; ++i) printf("%d ", ans.num[i]);
puts("");
}
return 0;
}