虚树DP

参考博客

1、作用

把一些特殊的点以及他们的\(LCA\)节点拿出来建树

2、构建

跑一遍\(dfs\)序:
\(st:dfs\)到达每个时间的时间戳
\(ed:dfs\)离开每个时间的时间戳

  1. 对特殊点的数组按\(st\)排序
  2. 对排序后的数组求相邻节点的\(LCA\),插入数组
  3. 对特殊点的数组按\(st\)排序
  4. 初始化空树边(可以后面跑\(dfs\)边跑边删)
  5. 使用栈利用特殊点数组的\(st, ed\)进行建边

\(dfs\)

void dfs1(int u, int f, int d){		
    dep[u] = d;		//记录根节点
    st[u] = ++cnt;	//记录时间戳
    for(auto v: e1[u]){
        if(v == f) continue;
        dfs1(v, u, d+1);
        fa[v][0] = u;
    }
    ed[u] = cnt;
}

建树

bool cmp(int x, int y){
    return st[x] < st[y];
}
void build(int d){
    stack<int> sta;
    while(!sta.empty()) sta.pop();
    int len = vec.size(), f;
    sort(vec.begin(), vec.end(), cmp);	\\根据st排序
    for(int i = 1; i < len; i++){
        f = LCA(vec[i], vec[i-1]);
        vec.pb(f);
    }
    vec.pb(s);		\\根节点(根据题目判断要不要加)
    sort(vec.begin(), vec.end(), cmp);
    vec.erase(unique(vec.begin(), vec.end()), vec.end());
    int now;
    for(auto i: vec){	\\建边
        if(!sta.empty()){
            now = sta.top();
            while(!sta.empty() && ed[now] < st[i]){
                sta.pop();
                now = sta.top();
            }
            e[now].pb(i);
        }
        sta.push(i);
    }
}

“科大讯飞杯”第18届上海大学程序设计联赛春季赛暨高校网络友谊赛-G血压游戏

#include<bits/stdc++.h>
#define mes(a, b) memset(a, b, sizeof a)
#define pb push_back
using namespace std;
typedef long long ll;
const int mod = 1e9+7;
const int maxn = 1e6+10;
const int pi = acos(-1);
int n, m, s;
int st[maxn], ed[maxn], dep[maxn], num[maxn], fa[maxn][30], cnt = 0;
ll a[maxn], ans;
map<int, ll> siz;
vector<int> e1[maxn],e[maxn], vec;

void dfs1(int u, int f, int d){
    dep[u] = d;
    st[u] = ++cnt;
    for(auto v: e1[u]){
        if(v == f) continue;
        dfs1(v, u, d+1);
        fa[v][0] = u;
    }
    ed[u] = cnt;
}
void init(){
    for(int j = 1; j <= 20; j++){
        for(int i = 1; i <= n; i++){
            fa[i][j] = fa[fa[i][j-1]][j-1];
        }
    }
}
int LCA(int u, int v){
    if(dep[u] < dep[v]) swap(u, v);
    int f = dep[u]-dep[v];
    for(int i = 0; i <= 20; i++){
        if(f & (1<<i))
            u = fa[u][i];
    }
    if(u == v) return u;
    for(int i = 20; i >= 0; i--){
        if(fa[u][i] != fa[v][i]){
            u = fa[u][i];
            v = fa[v][i];
        }
    }
    return fa[u][0];
}
void dfs(int u, int d){

    if(dep[u] == d)
        siz[u] = a[u];
    for(auto v: e[u]){
        dfs(v, d);
        if(siz[v])
            siz[u] += max(siz[v]-(dep[v]-dep[u]), 1ll);
    }
    e[u].clear();
}


bool cmp1(int x, int y){
    return dep[x]<dep[y];
}
bool cmp(int x, int y){
    return st[x] < st[y];
}
void build(int d){
    stack<int> sta;
    while(!sta.empty()) sta.pop();
    int len = vec.size(), f;
    sort(vec.begin(), vec.end(), cmp);
    for(int i = 1; i < len; i++){
        f = LCA(vec[i], vec[i-1]);
        vec.pb(f);
    }
    vec.pb(s);
    sort(vec.begin(), vec.end(), cmp);
    vec.erase(unique(vec.begin(), vec.end()), vec.end());
    int now;
    for(auto i: vec){
        if(!sta.empty()){
            now = sta.top();
            while(!sta.empty() && ed[now] < st[i]){
                sta.pop();
                now = sta.top();
            }
            e[now].pb(i);
        }
        sta.push(i);
    }
    siz.clear();
    dfs(s, d);
    if(siz[s])
        ans += max(siz[s]-1ll, 1ll);
}
int main(){
    int x, y;
    scanf("%d%d", &n, &s);
    for(int i = 1; i <= n; i++) scanf("%lld", a+i), num[i] = i;
    for(int i = 1; i < n; i++){
        scanf("%d%d", &x, &y);
        e1[x].pb(y);
        e1[y].pb(x);
    }
    cnt = ans = 0;
    dfs1(s, s, 1);
    init();
    sort(num+1, num+1+n, cmp1);
    int now = 1;
    for(int i = 1; i <= n&&now <= n; i++){
        vec.clear();
        printf("i = %d\n", i);
        while(now <= n && dep[num[now]] == i)
            vec.pb(num[now++]);
        if(vec.empty()) continue;

        build(i);
    }
    printf("%lld\n", ans);

}

猜你喜欢

转载自www.cnblogs.com/zhuyou/p/12732473.html