点分治学习笔记(洛谷3806)

QwQ点分治这个东西 还是很有意思的

点分治主要是用来解决一些树上路径问题

首先,我们要明确点分治的分治标准是重心

什么是重心?

如果以 x 为根,所有子树的最大 s i z e 最小,那么称 x 是这棵树的重心

那么怎么去找重心呢

我们直接 d f s ,然后对于每个点,求一个 m x [ i ] 表示,以 i 为根的子树的最大的 s i z

int getroot(int x,int fa)
{
    siz[x]=1;
    mx[x]=0;
    for (int i=point[x];i;i=nxt[i])
    {
        int p = to[i];
        if (vis[p] || p==fa) continue;
        getroot(p,x);
        siz[x]+=siz[p];
        mx[x]=max(mx[x],siz[p]);
    }
    mx[x]=max(mx[x],n-siz[x]);
    if (mx[x]<mx[root]) root=x;
}

下面,我们来介绍点分治的过程

每次实际上就是重复这样的一个过程

每次找到当前子树的重心,然后求过这个重心的路径的贡献,然后容斥一下, 减去会被重复计算的贡献,然后再分别递归当前重心节点的所有子树,重复这个过程

于是每一次找到重心,递归的子树大小是不超过原树大小的一半的,那么递归层数不会超过 O ( l o g n ) 层,时间复杂度为 O ( n l o g n )

这个时间复杂度的分析,我是用调和剂数来做的,你考虑对于这个总的循环次数 应该 n + n 2 2 + n 4 4 ,如果把n提出来,后面自然就是个 l o g n ,所以总复杂度是 O ( n l o g n )

那么回到这个题目

我们可以统计出每个长度的路径条数,然后针对询问 O ( 1 ) 回答。

很显然的是,我们可以对每个重心开始 d f s ,然后两重循环枚举点,将 s u m [ d i s [ i ] + d i s [ j ] ] + + ,可以用上面同样的方法证明这个复杂度是 O ( n 2 l o g n )

那么这里就会出现不合法的路径,也就是一条边经过两个的那种,我们只需要把其他点的 d i s + l e n [ i ] ,这样再枚举点的时候,就强制默认了重复走了那条边,然后把 s u m [ d i s [ i ] + d i s [ j ] 就行

一些细节还是直接看代码吧

不过有要注意的地方就是,要把已经计算过的重心打上标记,这样不会在 d f s 的时候,走到其他子树里

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>

using namespace std;

inline int read()
{
   int x=0,f=1;char ch=getchar();
   while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();}
   while (isdigit(ch)){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
   return x*f;
}

const int maxn = 1e5+1e2;
const int maxm = maxn*3;

int point[maxn],nxt[maxm],to[maxm];
int dis[maxn],son[maxn],siz[maxn];
int root,n,m,cnt;
int val[maxn];
int num;
int mx[maxn];
int vis[maxn];
int sum[10000000];

void addedge(int x,int y,int w)
{
    nxt[++cnt]=point[x];
    to[cnt]=y;
    val[cnt]=w;
    point[x]=cnt;
}

void insert(int x,int y,int w)
{
    addedge(x,y,w);
    addedge(y,x,w); 
}

int getroot(int x,int fa)
{
    siz[x]=1;
    mx[x]=0;
    for (int i=point[x];i;i=nxt[i])
    {
        int p = to[i];
        if (vis[p] || p==fa) continue;
        getroot(p,x);
        siz[x]+=siz[p];
        mx[x]=max(mx[x],siz[p]);
    }
    mx[x]=max(mx[x],n-siz[x]);
    if (mx[x]<mx[root]) root=x;
}

void getdis(int x,int fa,int len)
{
    dis[++num]=len;
    for (int i=point[x];i;i=nxt[i])
    {
        int p = to[i];
        if (p==fa || vis[p]) continue;
        getdis(p,x,len+val[i]);
    }
}

void solve(int x,int len)
{
    num=0;
    getdis(x,0,len);
    if (len!=0)
    {
        for (int i=1;i<=num;i++)
          for (int j=i+1;j<=num;j++)
            sum[dis[i]+dis[j]]--;
    }
    else
    {
        for (int i=1;i<=num;i++)
          for (int j=i+1;j<=num;j++)
            sum[dis[i]+dis[j]]++;
    }
}
void dfs(int x)
{
    vis[x]=1;
    solve(x,0);
    for (int i=point[x];i;i=nxt[i]){
        int p = to[i];
        if (vis[p]) continue;
        solve(p,val[i]);
        root=0;
        n=siz[p];root=0;    
        getroot(p,0);
        //cout<<root<<" "<<sum[2]<<endl;
        dfs(root);
    }
} 

int main()
{
  n=read(),m=read();
  for (int i=1;i<n;i++) {
    int x,y,w;
    x=read(),y=read(),w=read();
    insert(x,y,w);
  }
  mx[root]=2e9;
  getroot(1,0);
  //cout<<root; 
  dfs(root);
  for (int i=1;i<=m;i++)
  {
    int x = read();
    if (sum[x]) cout<<"AYE"<<endl;
    else cout<<"NAY"<<endl; 
  }
  return 0;
}

猜你喜欢

转载自blog.csdn.net/y752742355/article/details/82154546
今日推荐