gmoj 6807. 【2020.10.29提高组模拟】tree

题目

https://gmoj.net/senior/#main/show/6807

题解

转化题意,可以发现,这道题就是选择一个根,使得它的某个子树内包含所有颜色,求满足条件的子树的最大深度。

比赛时我的思路是删掉以某个儿子为根的子树(或以当前点为根的子树外的部分),结果发现这样子处理不了删掉以孙子为根的子树的情况。最终挂在这道题上了。

其实应该考虑另外的处理方式:删掉某棵子树或删掉某棵子树外的全部点。这里的删掉指的是不选择这些点,若答案不等于1,根就在这些点中。

分别考虑这两种情况:

  1. 当我删掉一个子树时,说明这个子树外包含所有的颜色。但是这个条件比较难判断,于是考虑这个子树不能被删去时满足什么条件,显然是这个子树外缺少某种颜色,即这种颜色全都在这个子树内。对于每一种颜色,把它们的lca求出来,lca到根的路径上的点就是不能删掉子树的点,用倍增lca O ( n log ⁡ 2 n ) O(n\log_2n) O(nlog2n)地处理即可(但是这样子常数巨大,优化后面会讲);
  2. 当我删掉一个子树外的所有点时,说明这个子树内包含所有点。发现树上处理起来很麻烦(可以线段树合并,但是空间和时间都可能爆掉),就把它转化到序列上(按dfn序排序)。双指针 O ( n ) O(n) O(n)地扫描一下就行了(当然你喜欢的话也可以用主席树,但是可能会炸空间)。

理论上这样打就能过了,但是我常数太大TLE了……

发现跑得最慢的部分是求一堆点的lca那里,要不开#pragma GCC optimize("O3")过这题势必要优化这个部分。

这里有一个定理: ∀ d f n a ≤ d f n b ≤ d f n c , 都 有 l c a ( a , c ) = l c a ( a , b , c ) \forall dfn_a\le dfn_b\le dfn_c,都有 lca(a,c)=lca(a,b,c) dfnadfnbdfnc,lca(a,c)=lca(a,b,c)
证明的话就是 l c a ( a , c ) lca(a,c) lca(a,c)的子树中必定包含了dfn在 [ a , c ] [a,c] [a,c]中所有点,因此必定也是b的祖先。

有了这个定理就可以将这个部分优化到 O ( m log ⁡ 2 n ) O(m\log_2n) O(mlog2n)了,足以通过这道题。如果常数太大还是过不了,可以用tarjan lca

CODE

倍增lca版本,常数稍大:

#include<cstdio>
using namespace std;
#define M 2000005
#define N 1000005
#define C 100005
struct array{
    
    int fir[C],nex[N];}a;bool cover[N];
int fir[N],to[M],nex[M],col[N],las[C],b[C],right[N];
int f[N][20],g[N][2],son[N],h[N],dep[N],dfn[N],id[N],siz[N],cnt,s,m;
inline char gc()
{
    
    
	static char buf[100005],*l=buf,*r=buf;
	return l==r&&(r=(l=buf)+fread(buf,1,100005,stdin),l==r)?EOF:*l++;
}
inline void read(int &x)
{
    
    
	char ch;while(ch=gc(),ch<'0'||ch>'9');x=ch-48;
	while(ch=gc(),ch>='0'&&ch<='9') x=x*10+ch-48;
}
inline void inc(int x,int y)
{
    
    
	to[++s]=y,nex[s]=fir[x],fir[x]=s;
	to[++s]=x,nex[s]=fir[y],fir[y]=s;
}
inline void swap(int &x,int &y){
    
    int z=x;x=y,y=z;}
void dfs(int k)
{
    
    
	id[++cnt]=k,dfn[k]=cnt;
	dep[k]=dep[f[k][0]]+1,siz[k]=1;
	for(int i=fir[k];i;i=nex[i]) if(to[i]!=f[k][0])
		f[to[i]][0]=k,dfs(to[i]),siz[k]+=siz[to[i]];
}
inline int mymax(int x,int y){
    
    return x>y?x:y;}
inline int getlca(int u,int v)
{
    
    
	if(dep[u]<dep[v]) swap(u,v);
	for(int i=19;i>=0;--i)
		if(dep[f[u][0]]>=dep[v])
			u=f[u][0];
	if(u==v) return u;
	for(int i=19;i>=0;--i)
		if(f[u][i]^f[v][i])
			u=f[u][i],v=f[v][i];
	return f[u][0];
}
int main()
{
    
    
	freopen("tree.in","r",stdin);
	freopen("tree.out","w",stdout);
	int n,x,y,l,r,tot=0,ans=0;
	read(n),read(m);
	for(int i=1;i<=n;++i) read(col[i]),a.nex[i]=a.fir[col[i]],a.fir[col[i]]=i;
	for(int i=1;i<n;++i) read(x),read(y),inc(x,y);
	dep[1]=1,dfs(1);
	for(int i=n,tmp,k;i>1;--i)
	{
    
    
		tmp=g[id[i]][0]+1,k=f[id[i]][0];
		if(tmp>g[k][0]) g[k][1]=g[k][0],g[k][0]=tmp,son[k]=id[i];
		else if(tmp>g[k][1]) g[k][1]=tmp;
	}
	for(int i=2,k,fa;i<=n;++i) k=id[i],fa=f[k][0],h[k]=mymax(h[fa],g[fa][k==son[fa]])+1;
	for(int j=1;j<20;++j)
		for(int i=1;i<=n;++i)
			f[i][j]=f[f[i][j-1]][j-1];
	for(int i=1,max,min;i<=m;++i) if(a.fir[i])
	{
    
    
		max=0,min=N;
		for(int j=a.fir[i];j;j=a.nex[j])
		{
    
    
			if(dfn[j]>max) max=dfn[j];
			if(dfn[j]<min) min=dfn[j];
		}
		cover[getlca(id[max],id[min])]=1;
	}
	for(int i=n;i>1;--i) cover[f[id[i]][0]]|=cover[id[i]];
	for(int i=1;i<=m;++i) las[i]=n+1;
	for(int i=n;i;--i) right[i]=las[col[id[i]]],las[col[id[i]]]=i;
	for(int i=1;i<=n;++i) if(!cover[id[i]]) ans=mymax(ans,g[id[i]][0]+1);
	b[col[id[1]]]=1,tot=r=1;
	while(r<=n&&tot<m)
	{
    
    
		if(!b[col[id[++r]]]) ++tot;
		++b[col[id[r]]];
	}
	for(l=1;l<=n;++l)
	{
    
    
		if(r<=l+siz[id[l]]-1) ans=mymax(ans,h[id[l]]);
		if(!--b[col[id[l]]])
		{
    
    
			if(right[l]>n) break;
			for(int i=r+1;i<=right[l];++i) ++b[col[id[i]]];
			r=right[l];
		}
	}
	printf("%d\n",ans+1);
	return 0;
}

tarjan lca版本,代码稍长:

#include<cstdio>
using namespace std;
#define M 2000005
#define N 1000005
#define C 100005
struct array{
    
    int fir[C],nex[N];}a;bool cover[N];
struct query
{
    
    
	int fir[N],nex[200005],to[200005],s;
	inline void inc(int x,int y)
	{
    
    
		to[++s]=y,nex[s]=fir[x],fir[x]=s;
		to[++s]=x,nex[s]=fir[y],fir[y]=s;
	}
}qry;
int fir[N],to[M],nex[M],col[N],las[C],b[C],right[N];
int f[N],g[N][2],son[N],h[N],dep[N],dfn[N],id[N],siz[N],fa[N],cnt,s,m;
inline char gc()
{
    
    
	static char buf[100005],*l=buf,*r=buf;
	return l==r&&(r=(l=buf)+fread(buf,1,100005,stdin),l==r)?EOF:*l++;
}
inline void read(int &x)
{
    
    
	char ch;while(ch=gc(),ch<'0'||ch>'9');x=ch-48;
	while(ch=gc(),ch>='0'&&ch<='9') x=x*10+ch-48;
}
inline void inc(int x,int y)
{
    
    
	to[++s]=y,nex[s]=fir[x],fir[x]=s;
	to[++s]=x,nex[s]=fir[y],fir[y]=s;
}
inline void swap(int &x,int &y){
    
    int z=x;x=y,y=z;}
void dfs(int k)
{
    
    
	id[++cnt]=k,dfn[k]=cnt;
	dep[k]=dep[fa[k]]+1,siz[k]=1;
	for(int i=fir[k];i;i=nex[i]) if(to[i]!=fa[k])
		fa[to[i]]=k,dfs(to[i]),siz[k]+=siz[to[i]];
}
inline int mymax(int x,int y){
    
    return x>y?x:y;}
int getf(int k){
    
    return f[k]==k?k:f[k]=getf(f[k]);}
void getlca(int k)
{
    
    
	for(int i=fir[k];i;i=nex[i]) if(to[i]!=fa[k])
		getlca(to[i]),f[to[i]]=k;
	for(int i=qry.fir[k];i;i=qry.nex[i]) if(f[qry.to[i]]!=qry.to[i])
		cover[getf(qry.to[i])]=1;
}
int main()
{
    
    
	freopen("tree.in","r",stdin);
	freopen("tree.out","w",stdout);
	int n,x,y,l,r,tot=0,ans=0;
	read(n),read(m);
	for(int i=1;i<=n;++i) read(col[i]),a.nex[i]=a.fir[col[i]],a.fir[col[i]]=i;
	for(int i=1;i<n;++i) read(x),read(y),inc(x,y);
	dep[1]=1,dfs(1);
	for(int i=n,tmp,k;i>1;--i)
	{
    
    
		tmp=g[id[i]][0]+1,k=fa[id[i]];
		if(tmp>g[k][0]) g[k][1]=g[k][0],g[k][0]=tmp,son[k]=id[i];
		else if(tmp>g[k][1]) g[k][1]=tmp;
	}
	for(int i=2,k;i<=n;++i) k=id[i],h[k]=mymax(h[fa[k]],g[fa[k]][k==son[fa[k]]])+1;
	for(int i=1,max,min;i<=m;++i) if(a.fir[i])
	{
    
    
		max=0,min=N;
		for(int j=a.fir[i];j;j=a.nex[j])
		{
    
    
			if(dfn[j]>max) max=dfn[j];
			if(dfn[j]<min) min=dfn[j];
		}
		if(id[max]^id[min]) qry.inc(id[max],id[min]);
		else cover[id[max]]=1;
	}
	for(int i=1;i<=n;++i) f[i]=i;getlca(1);
	for(int i=n;i>1;--i) cover[fa[id[i]]]|=cover[id[i]];
	for(int i=1;i<=m;++i) las[i]=n+1;
	for(int i=n;i;--i) right[i]=las[col[id[i]]],las[col[id[i]]]=i;
	for(int i=1;i<=n;++i) if(!cover[id[i]]) ans=mymax(ans,g[id[i]][0]+1);
	b[col[id[1]]]=1,tot=r=1;
	while(r<=n&&tot<m)
	{
    
    
		if(!b[col[id[++r]]]) ++tot;
		++b[col[id[r]]];
	}
	for(l=1;l<=n;++l)
	{
    
    
		if(r<=l+siz[id[l]]-1) ans=mymax(ans,h[id[l]]);
		if(!--b[col[id[l]]])
		{
    
    
			if(right[l]>n) break;
			for(int i=r+1;i<=right[l];++i) ++b[col[id[i]]];
			r=right[l];
		}
	}
	printf("%d\n",ans+1);
	return 0;
}

猜你喜欢

转载自blog.csdn.net/huangzihaoal/article/details/109411716