POJ 1741. Tree

Link
题意:
求树上长度不超过 \(k\) 的路径条数
思路:
点分治
对于根节点,树上路径的情况可分为经过根节点和不经过根节点两种情况,而经过根节点的路径又分为 \(x\) 到根节点和根节点到 \(y\) 两段
\(dist\) 数组记录的是节点到根节点的距离,并且我们将所有值全部记录在数组 \(a\),并按值从小到大排序
用两个指针 \(l,r\) 分别从前和从后开始扫描,我们可以发现当l向后扫描时,\(r\)一定是向前移动的,因此当 \(a[l]+a[r]<=k\) 时,\(ans+=r-l\)
注意:单条路径也是合法路径,因此开始时先将 \(0\) 记录在数组 \(a\)
扫描完成后我们将不合法的路径(路径的两个端点都属于根节点的同一子树,这两段路径有重复部分)也加在 \(ans\) 里了
我们可以利用容斥思想,对于每个子树,计算出不合法的路径条数:\(ans-=calc(x,d[x][p])(x \in son(p))\)
总而言之,整个点分治的过程为:
\(1.\) 以点 \(p\) 为根节点跟更新子树中的节点到点 \(p\) 的距离(\(p\)为重心)
\(2.\) 计算长度小于等于 \(k\) 的路径条数
\(3.\) 删除点 \(p\)
\(4.\) 对点 \(p\) 的所有子树执行 \(1 \sim 3\)
选取重心作为根节点,可使整个点分治的递归层数至多为 \(logn\)
代码:

#include<iostream>
#include<algorithm>
#include<cstring>

using namespace std;

const int N=10010;
const int M=20010;

int n,k;
int cnt,to[M],val[M],nxt[M],head[N];
bool st[N];
int dist[N],sze[N],all_node;
int a[N];
int cnt_dist;
int pos;
int res;
int ans;

void addedge(int u,int v,int w) {
    cnt++;
    to[cnt]=v;
    val[cnt]=w;
    nxt[cnt]=head[u];
    head[u]=cnt;
}
void get_root(int u,int pre) {
    sze[u]=1;
    int max_part=1;
    for(int i=head[u];i;i=nxt[i]) {
        int v=to[i];
        if(v==pre||st[v]) continue;
        get_root(v,u);  
        sze[u]+=sze[v];
        max_part=max(max_part,sze[v]);
    }
    max_part=max(max_part,all_node-sze[u]);
    if(max_part<res) {
        res=max_part;
        pos=u;
    }
}
void get_dist(int u,int pre) {
    a[++cnt_dist]=dist[u];
    for(int i=head[u];i;i=nxt[i]) {
        int v=to[i];
        if(v==pre||st[v]) continue;
        dist[v]=dist[u]+val[i];
        get_dist(v,u);
    }
}
int calc(int u,int t) {
    int sum=0;
    dist[u]=t;
    cnt_dist=0;
    get_dist(u,0);
    sort(a+1,a+1+cnt_dist);
    int l=1,r=cnt_dist;
    while(l<r) {
        if(a[l]+a[r]<=k) sum+=r-l,l++;
        else r--;
    }
    return sum;
}
void solve(int u) {
    st[u]=true;
    ans+=calc(u,0);
    for(int i=head[u];i;i=nxt[i]) {
        int v=to[i];
        if(st[v]) continue;
        ans-=calc(v,val[i]);
        res=n;
        all_node=sze[v];
        get_root(v,0);
        solve(pos);
    }
}
int main() {
    //freopen("in.txt","r",stdin);
    ios::sync_with_stdio(false);
    cin.tie(0);
    while(cin>>n>>k&&n) {
        cnt=0;
        memset(head,0,sizeof head);
        memset(st,false,sizeof st);
        ans=0;
        for(int i=1;i<n;i++) {
            int u,v,w;
            cin>>u>>v>>w;
            addedge(u,v,w);
            addedge(v,u,w);
        }
        res=n;
        all_node=n;
        get_root(1,0);
        solve(pos);
        cout<<ans<<endl;
    } 
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/c4Lnn/p/12501487.html