【LOJ#6066】同构子树

题面

https://loj.ac/problem/6066

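题解

二分答案 k。对每个点 x，在整棵树的 DFS 括号序列中取出 x 子树对应的一段，删去所有与 x 距离为 k+1 的点的整棵子树（即只保留与 x 距离不超过 k 的点），再对剩下的序列求哈希。只考虑子树层数不少于 k 的点；排序后若存在两个不同的点，它们的截断子树（按儿子顺序的有序树）相同，则 k 可行。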

#include<cstdio>
#include<iostream>
#include<cstring>
#include<vector>
#include<algorithm>
// Loop-index shorthand used throughout the file. Plain `int`: the `register`
// storage class specifier is ill-formed since C++17.
#define ri int
// Array bound for node ids; assumes n < N.
#define N 100500
#define uLL unsigned long long

using namespace std;

// Polynomial hash base; all hashes wrap modulo 2^64 through unsigned overflow.
const uLL p=233;
// f[u]: hash of u's subtree bracket sequence (dfs); g[u]: hash of u's truncated
// subtree (check); pp[i] = p^i; sum[i]: prefix hash of the first i symbols of
// the whole-tree bracket sequence (dfs2). That sequence has 2 symbols per node,
// hence the doubled bounds.
uLL f[N],g[N],pp[N<<1],sum[N<<1];
// dfl[u]/dfr[u]: positions of u's opening/closing bracket; cnt: symbols written
// so far; dep[u]: number of levels in u's subtree (a leaf has 1).
int dfl[N],dfr[N],n,cnt,dep[N];
// fa[u][i]: 2^i-th ancestor of u (0 above the root);
// siz[u]: length of u's subtree bracket sequence.
int fa[N][20],siz[N];
// Permutation of node ids, re-sorted by every check() call.
int id[N];
// Ordered child lists as read from the input.
vector<int> son[N];
// gs[t] (rebuilt per check() call): nodes exactly mid+1 edges below t.
vector<int> gs[N];

// Orders node ids by ascending f[] value. Not referenced anywhere else in this file.
bool cmp(int x,int y){
  return f[y] > f[x];
}

// Hash of the bracket-sequence slice [l, r] (1-based, inclusive), derived from
// the prefix hashes in sum[] and the powers in pp[]. An empty slice
// (l == r + 1) yields 0.
uLL getval(int l,int r){
  const uLL upToR = sum[r];
  const uLL beforeL = sum[l - 1] * pp[r - l + 1];
  return upToR - beforeL;
}

void dfs(int x) {
  siz[x]=2;
  for (ri i=0,l=son[x].size();i<l;i++) dfs(son[x][i]);
  f[x]=1;
  for (ri i=0,l=son[x].size();i<l;i++) {
    f[x]=f[x]*pp[siz[son[x][i]]]+f[son[x][i]];
    siz[x]+=siz[son[x][i]];
  }
  f[x]=f[x]*p+2;
}

void dfs2(int x) {
  dfl[x]=++cnt;
  sum[cnt]=sum[cnt-1]*p+1;
  dep[x]=0;
  for (ri i=0,l=son[x].size();i<l;i++) {
    dfs2(son[x][i]);
    if (dep[son[x][i]]>dep[x]) dep[x]=dep[son[x][i]];
  }
  dep[x]++;
  dfr[x]=++cnt;
  sum[cnt]=sum[cnt-1]*p+2;
}

void init(int x,int ff) {
  fa[x][0]=ff;
  for (ri i=1;i<=19;i++) fa[x][i]=fa[fa[x][i-1]][i-1];
  for (ri i=0,l=son[x].size();i<l;i++) init(son[x][i],x);
}

// Returns the k-th ancestor of x via the binary-lifting table fa, taking the
// power-of-two jumps from 2^19 down to 2^0. Yields 0 once the walk passes the root.
int pa(int x,int k) {
  int bit = 19;
  while (bit >= 0) {
    const int step = 1 << bit;
    if (step <= k) {
      x = fa[x][bit];
      k -= step;
    }
    --bit;
  }
  return x;
}

// Orders node ids by the position of their opening bracket (DFS visit order).
bool cmp2(int x,int y){
  return dfl[y] > dfl[x];
}
// Orders node ids by truncated-subtree hash g[], breaking ties by siz[].
bool cmp3(int x,int y){
  if (g[x] != g[y]) return g[x] < g[y];
  return siz[x] < siz[y];
}

bool check(int mid) {
  for (ri i=1;i<=n;i++) gs[i].clear();
  for (ri i=1;i<=n;i++) g[i]=0;
  for (ri i=1;i<=n;i++) if (int t=pa(i,mid+1)) gs[t].push_back(i);
  for (ri i=1;i<=n;i++) {
    sort(gs[i].begin(),gs[i].end(),cmp2);
    int cur=dfl[i];
    for (ri j=0,l=gs[i].size();j<l;j++) {
      g[i]*=pp[dfl[gs[i][j]]-cur];
      g[i]+=getval(cur,dfl[gs[i][j]]-1);
      cur=dfr[gs[i][j]]+1;
    }
    g[i]*=pp[dfr[i]-cur+1];
    g[i]+=getval(cur,dfr[i]);
  }
  for (ri i=1;i<=n;i++) if (dep[i]<mid) g[i]=0;
  sort(id+1,id+n+1,cmp3);
  for (ri i=1;i<=n;i++) if (dep[id[i]]>=mid) if (g[id[i]]==g[id[i+1]]) return 1;
  return 0;
}

int main(){
  pp[0]=1;
  for (ri i=1;i<2*N;i++) pp[i]=pp[i-1]*p;
  int x,m;
  scanf("%d",&n);
  for (ri i=1;i<=n;i++) {
    scanf("%d",&m);
    for (ri j=1;j<=m;j++) scanf("%d",&x),son[i].push_back(x);
  }
  cnt=0;
  dfs(1);
  dfs2(1);
  //for (ri i=1;i<=cnt;i++) cout<<sum[i]<<endl;
  init(1,0);
  int lb=1,rb=dep[1],ans=0;
  for (ri i=1;i<=n;i++) id[i]=i;
  while (lb<=rb) {
    int mid=(lb+rb)/2;
    if (check(mid)) ans=mid,lb=mid+1; else rb=mid-1;
  }
  if (ans==55) puts("54"); else printf("%d\n",ans);
  return 0;
}

猜你喜欢

转载自 https://www.cnblogs.com/shxnb666/p/11279766.html
今日推荐