bzoj2152 树分治

还是太菜了,自己写的wa,但是找不到哪里错了,,

感觉现在学树分治早了点。。以后回来再看吧

/*
多少点对之间的路径是3的倍数
*/
#include<iostream>
#include<cstring>
#include<cstdio>
using namespace std;
#include<algorithm>
#define MAXN 20010
int N;
struct E{
    int v,next,w;
}edge[MAXN<<1];
int head[MAXN],tot;
int size[MAXN];
int maxv[MAXN];
int vis[MAXN];
int dis[MAXN];
int num,ans,Max,root;
void init(){
    tot=ans=0;
    memset(head,-1,sizeof head);
    memset(vis,0,sizeof vis);
}
void addedge(int u,int v,int w){
    edge[tot].v=v;
    edge[tot].w=w;
    edge[tot].next=head[u];
    head[u]=tot++;
}
//一次dfs处理子树的大小
void dfssize(int u,int f){
    size[u]=1;
    maxv[u]=0;
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].v;
        if(v==f||vis[v]) continue;
        dfssize(v,u);
        size[u]+=size[v];
        maxv[u]=max(maxv[u],size[v]);
    }
}
//一次dfs找重心
void dfsroot(int r,int u,int f){
    maxv[u]=max(maxv[u],size[r]-maxv[u]);
    if(maxv[u]<Max)
        Max=maxv[u],root=u;
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].v;
        if(v==f||vis[v])
            continue;
        dfsroot(r,v,u);
    }
}
int tmp[3];
//一次dfs求路径长度
void dfsdis(int u,int d,int f){
    dis[u]=d%3;
    tmp[dis[u]]++;
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].v;
        if(v==f||vis[v])
            continue;
        dfsdis(v,d+edge[i].w,u);
    }
}
//计算以u为根的子树中有多少点对的距离%3==0
int calc(int u,int d){
    tmp[0]=tmp[1]=tmp[2]=0;
    dfsdis(u,d,-1);//得到dis数组
    return tmp[0]*tmp[0]+tmp[1]*tmp[2]*2;
}
//分治
void dfs(int u){
    Max=N;
    dfssize(u,-1);
    dfsroot(u,u,-1);
    ans+=calc(root,0);
    vis[root]=1;
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].v;
        if(vis[v])
            continue;
        ans-=calc(v,edge[i].w);
        dfs(v);
    }
}

int gcd(int a,int b){
    if(b==0) return a;
    else return gcd(b,a%b);
}
int main(){
    scanf("%d",&N);
    init();
    for(int i=1;i<=N-1;i++){
        int u,v,w;
        scanf("%d%d%d",&u,&v,&w);
        addedge(u,v,w%3);
        addedge(v,u,w%3);
    }
    dfs(1);
    int g=gcd(ans,N*N);
    printf("%d/%d",ans/g,N*N/g);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/zsben991126/p/9876362.html