上厕所的时候想通了,果然厕所是一个思考的好地方
普通区间主席树和树上主席树的区别:
这两个其实只有建树方式不同而已,普通主席树在
for循环
里面建树,for循环就相当于线性的区间建树,而树上主席树是在树形结构上建主席树,建出来的主席树有树的性质,所以我们要求任意两点之间最短路径上的第k小,就需要用到lca,这里其实用到了树上点差分的思路,在树上建树其实就是在dfs()遍历树的时候我们将当前版本和上一版本进行复制,修改,而上一版本就是父节点
树上建树主席树代码:
void dfx(int u,int fa)
{
inser(1,n,root[fa],root[u],getid(a[u]));
for(int i=head[u];~i;i=edge[i].nex)
{
int v=edge[i].to;
if(v!=fa)
{
dfx(v,u);
}
}
}
AC代码:
#include<cstdio>
#include<algorithm>
#include<vector>
#include<cstring>
using namespace std;
//树上主席树
const int maxn=2e5+5;
int root[maxn],a[maxn],n,m,cn,cnt,cnx;
int head[maxn],depth[maxn],pre[maxn][32];
//离散化部分
vector<int>vec;
int getid(int x)
{
return lower_bound(vec.begin(),vec.end(),x)-vec.begin()+1;
}
//主席树部分
struct node
{
int l,r,sum;
} tr[maxn*40];
struct yzj
{
int to,nex;
} edge[maxn];
void inser(int l,int r,int ver,int &now,int pos)
{
now=++cnt;
tr[now]=tr[ver];
tr[now].sum++;
if(l==r)
{
return;
}
int mid=l+r>>1;
if(mid>=pos)
{
inser(l,mid,tr[ver].l,tr[now].l,pos);
}
else
{
inser(mid+1,r,tr[ver].r,tr[now].r,pos);
}
}
int query(int l,int r,int ver,int now,int lc,int lcfa,int k)
{
//printf("%d\n",l);
if(l==r)
{
return l;
}
int mid=l+r>>1;
int tem=tr[tr[ver].l].sum+tr[tr[now].l].sum-tr[tr[lc].l].sum-tr[tr[lcfa].l].sum;
if(tem>=k)
{
return query(l,mid,tr[ver].l,tr[now].l,tr[lc].l,tr[lcfa].l,k);
}
else
{
return query(mid+1,r,tr[ver].r,tr[now].r,tr[lc].r,tr[lcfa].r,k-tem);
}
}
//链式前向星
void add(int u,int v)
{
edge[cnx].to=v;
edge[cnx].nex=head[u];
head[u]=cnx++;
}
//lca部分
void dfs(int u,int fa)
{
inser(1,n,root[fa],root[u],getid(a[u]));
//printf("%d\n",getid(a[u]));
depth[u]=depth[fa]+1;
pre[u][0]=fa;
for(int i=1; (1<<i)<=depth[u]; i++)
pre[u][i]=pre[pre[u][i-1]][i-1];
for(int i=head[u]; ~i; i=edge[i].nex)
{
int v=edge[i].to;
if(v!=fa)
{
dfs(v,u);
}
}
}
int lca(int u,int v)
{
if(depth[u]<depth[v])
swap(u,v);
int i=-1,j;
while((1<<(i+1))<=depth[u])
{
i++;
}
for(j=i; j>=0; j--)
{
if(depth[u]-(1<<j)>=depth[v])
u=pre[u][j];
}
if(u==v)
{
return u;
}
for( j=i; j>=0; j--)
{
if(pre[u][j]!=pre[v][j])
{
u=pre[u][j];
v=pre[v][j];
}
}
return pre[u][0];
}
int main()
{
memset(head,-1,sizeof(head));
scanf("%d %d",&n,&m);
for(int i=1; i<=n; i++)
{
scanf("%d",&a[i]);
vec.push_back(a[i]);
}
sort(vec.begin(),vec.end());
vec.erase(unique(vec.begin(),vec.end()),vec.end());
//cn=vec.size();
for(int i=1; i<n; i++)
{
int u,v;
scanf("%d %d",&u,&v);
add(u,v);
add(v,u);
}
dfs(1,0);
// dfx(1,0);
int ans=0;
while(m--)
{
int u,v,k;
scanf("%d %d %d",&u,&v,&k);
u=u^ans;
//printf("%d\n",u);
int lc=lca(u,v);
int lcfa=pre[lc][0];
ans=query(1,n,root[u],root[v],root[lc],root[lcfa],k);
ans=vec[ans-1];
printf("%d\n",ans);
}
}