SP10707 COT2 - Count on a tree II 树上莫队

https://www.luogu.com.cn/problem/SP10707

欧拉序就是进去的时候标记一次出来的时候标记一次,然后我们把这个序列拿出来,变成一个2*n的新序列,每个点有前后两次访问ll[i]和rr[i],那么一段树上的路径就可以用区间表示

假设询问是u,v的路径,ll[u]<ll[v],如果u是v的祖先,那么直接对应到ll[u]-ll[v],否则对应到rr[u]-ll[v]再加上ll[lca(u,v)],因为u和v的最近公共祖先出现一定在ll[u]左边,第二次出现在rr[v]右边,所以是不包含在当前区间里的,然后注意如果这段欧拉序序列中某个位置出现了两次,那么他也是不在u,v路径上的

#include<bits/stdc++.h>
using namespace std;

const int maxl=1e5+10;

int n,m,tot,len;
int a[maxl],b[maxl],idx[maxl],bel[maxl],ans[maxl],dep[maxl];
int ll[maxl],rr[maxl],num[maxl],vis[maxl];
int f[21][maxl];
vector<int> e[maxl];
struct qry
{
	int l,r,lca,id;
}q[maxl];

inline void dfs(int u,int fa)
{
	idx[++n]=u;ll[u]=n;
	for(int v:e[u])
	if(v!=fa)
	{
		f[0][v]=u;dep[v]=dep[u]+1;
		dfs(v,u);
	}
	idx[++n]=u;rr[u]=n;
}

inline int getlca(int u,int v)
{
	if(dep[u]<dep[v])
		swap(u,v);
	for(int i=20;i>=0;i--)
	if((dep[u]-dep[v])>>i&1)
		u=f[i][u];
	if(u==v)
		return u;
	for(int i=20;i>=0;i--)
	if(f[i][u]!=f[i][v])
		u=f[i][u],v=f[i][v];
	return f[0][u];
}

inline bool cmp(const qry&a,const qry&b)
{
	return (bel[a.l]^bel[b.l])?bel[a.l]<bel[b.l]:a.r<b.r;
}

inline void prework()
{
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++)
		scanf("%d",&a[i]),b[i]=a[i];
	sort(b+1,b+1+n);
	tot=unique(b+1,b+1+n)-b-1;
	for(int i=1;i<=n;i++)
		a[i]=lower_bound(b+1,b+1+tot,a[i])-b;
	int u,v;
	for(int i=1;i<=n-1;i++)
	{
		scanf("%d%d",&u,&v);
		e[u].push_back(v);
		e[v].push_back(u);
	}
	n=0;
	dfs(1,0);
	len=sqrt(n);
	for(int i=1;i<=n;i++)
		bel[i]=(i-1)/len+1;
	for(int k=1;k<=20;k++)
		for(int i=1;i<=n;i++)
			f[k][i]=f[k-1][f[k-1][i]];
	for(int i=1;i<=m;i++)
	{
		scanf("%d%d",&q[i].l,&q[i].r),q[i].id=i;
		if(ll[q[i].l]>ll[q[i].r])
			swap(q[i].l,q[i].r);
		q[i].lca=getlca(q[i].l,q[i].r);
		if(q[i].lca==q[i].l)
		{
			q[i].l=ll[q[i].l],q[i].r=ll[q[i].r];
			q[i].lca=0;
		}
		else
		{
			q[i].l=rr[q[i].l],q[i].r=ll[q[i].r];
			q[i].lca=ll[q[i].lca];
		}
	}
	sort(q+1,q+1+m,cmp);
}

inline int solv(int i)
{
	vis[idx[i]]^=1;
	if(vis[idx[i]])
		return !num[a[idx[i]]]++;
	else
		return -(!--num[a[idx[i]]]);
}

inline void mainwork()
{
	int l=1,r=0,now=0;
	for(int i=1;i<=m;i++)
	{
		while(r<q[i].r) 
			++r,now+=solv(r);
		while(l>q[i].l)
			--l,now+=solv(l);
		while(r>q[i].r)
			now+=solv(r),r--;
		while(l<q[i].l)
			now+=solv(l),l++;
		if(q[i].lca)
			now+=solv(q[i].lca);
		ans[q[i].id]=now;
		if(q[i].lca)
			now+=solv(q[i].lca);
	}
}

inline void print()
{
	for(int i=1;i<=m;i++)
		printf("%d\n",ans[i]);
}

int main()
{
	prework();
	mainwork();
	print();
	return 0;
}

猜你喜欢

转载自blog.csdn.net/liufengwei1/article/details/109173023
今日推荐