https://ac.nowcoder.com/acm/contest/5669/A
这题主要是要想到K为定值时,如何找出最短距离
于是我们可以二分最短距离x,然后再这棵树中,每次找到一个深度最深的点,然后从这个点向上走x到祖先anc,然后把anc标记为key点,并把anc所在的子树删除,重复这个过程直到把1拿掉,那么这就是最小的数量。删除子树,树上dfs序建线段树经典题,子树在dfs序上是连续的编号,于是删除就是区间覆盖
想到这个最小的数量就很好求了
我们直接枚举最短距离从1->mx,0的时候肯定是n个点都要放满
枚举完最短距离i后重复上述贪心过程,不停地删子树,最后得到如果要让最短距离为i的最少的点的个数cnt,令ans[cnt]=min(ans[cnt],i),因为有可能不同i算出来的cnt一样
注意区间覆盖,我们用一个tag来标记这个区间已经被删除了,并且如果一个点左儿子和右儿子都被删除了,他自己也被删除了,然而最后由于要还原线段树,我们记录一个原始的ini[k]=tree[k].id,然后重复一遍删除的操作,把每个访问的点变成Ini[k],而且删除标记也为0,就相当于还原了
由于对距离i来说,最多标记的点数是n/i,所以枚举所有的距离,操作次数最多是nlnn,所以总复杂度是n*lnn*logn
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxl=1e6+2e5+10;
int n,m,cas,k,cnt,tot,ind,mx,up;
int ans[maxl],dep[maxl],l[maxl],r[maxl],dy[maxl];
int tmp[maxl],ini[maxl*4];
int f[23][maxl];
vector<int> e[maxl];
bool in[maxl];
struct node
{
int l,r,id,tag;
}tree[maxl*4];
inline void dfs(int u)
{
l[u]=++ind;dy[ind]=u;
for(int v:e[u])
{
f[0][v]=u;
dfs(v);
}
r[u]=ind;
}
inline void pushup(int k)
{
if(tree[k].l==tree[k].r || tree[k].tag)
return;
int id;
if(tree[k<<1].tag)
{
if(tree[k<<1|1].tag)
{
tree[k].tag=true;
return;
}
id=tree[k<<1|1].id;
}
else
{
id=tree[k<<1].id;
if(!tree[k<<1|1].tag && dep[dy[id]]<dep[dy[tree[k<<1|1].id]])
id=tree[k<<1|1].id;
}
tree[k].id=id;
}
inline void build(int k,int l,int r)
{
tree[k].l=l;tree[k].r=r;tree[k].tag=0;
if(l==r)
{
tree[k].id=l;ini[k]=l;
return;
}
int mid=(l+r)>>1;
build(k<<1,l,mid);
build(k<<1|1,mid+1,r);
pushup(k);ini[k]=tree[k].id;
}
inline void prework()
{
for(int i=1;i<=n;i++)
dep[i]=0,e[i].clear();
int fa;dep[1]=0;mx=0;
for(int i=2;i<=n;i++)
{
scanf("%d",&fa);
dep[i]=dep[fa]+1;mx=max(dep[i],mx);
e[fa].push_back(i);
}
ind=0;
dfs(1);up=log2(mx);
for(int k=1;k<=up;k++)
for(int i=1;i<=n;i++)
f[k][i]=f[k-1][f[k-1][i]];
build(1,1,n);
}
inline int findf(int x,int l)
{
for(int k=up;k>=0;k--)
if(l&(1<<k))
x=f[k][x];
if(!x) x=1;
return x;
}
inline void upd(int k,int l,int r,bool x)
{
if(!x) tree[k].tag=x,tree[k].id=ini[k];
if(tree[k].l==l && tree[k].r==r)
{
tree[k].tag=x;
return;
}
int mid=(tree[k].l+tree[k].r)>>1;
if(r<=mid)
upd(k<<1,l,r,x);
else if(l>mid)
upd(k<<1|1,l,r,x);
else
{
upd(k<<1,l,mid,x);
upd(k<<1|1,mid+1,r,x);
}
pushup(k);
}
inline void mainwork()
{
for(int i=1;i<=n;i++)
ans[i]=n+1;
ans[n]=0;int anc,id;
for(int i=1;i<=mx;i++)
{
cnt=0;
while(1)
{
id=dy[tree[1].id];
anc=findf(id,i);
if(anc==1)
break;
upd(1,l[anc],r[anc],1);
tmp[++cnt]=anc;
}
ans[cnt+1]=min(ans[cnt+1],i);
for(int i=cnt;i>=1;i--)
upd(1,l[tmp[i]],r[tmp[i]],0);
}
}
inline void print()
{
ll sum=ans[1];
for(int i=2;i<=n;i++)
{
ans[i]=min(ans[i-1],ans[i]);
sum+=ans[i];
}
printf("%lld\n",sum);
}
int main()
{
//freopen("in.in","r",stdin);
int t=1;
while(~scanf("%d",&n))
{
prework();
mainwork();
print();
}
return 0;
}