[BZOJ 2588] Count on a tree

[题目链接]

           https://www.lydsy.com/JudgeOnline/problem.php?id=2588

[算法]

         如果我们能知道“u到v这条路径上权值<= k的数的个数” ,  那么就可以通过二分的方式求出答案

         进一步地 , u到v路径上权值<= k的数的个数 = u到根节点路径上权值<= k的数的个数 + v到根节点路径上权值<= k的数的个数 - u和v的最近公共祖先到根节点路径上权值<= k的数的个数 - u和v的最近公共祖先的父节点到根节点路径上权值<= k的数的个数 

         以每个节点到根节点的路径为一个版本建立可持久化线段树(节点 u 的版本由其父节点的版本插入 u 的权值得到 , 根节点的父节点对应空树) , 查询时用四个版本的和差在线段树上二分即可

         时间复杂度 : O((N + Q)logN)

[代码]

        

#include<bits/stdc++.h>
using namespace std;
#define MAXN 100010
#define MAXLOG 20
typedef long long ll;
typedef long double ld;

// One directed half of an undirected tree edge, kept in a linked adjacency list.
struct edge
{
        int to;  // vertex this half-edge leads to
        int nxt; // index in e[] of the next half-edge leaving the same vertex; 0 ends the list
};

// Edge pool; sized for two half-edges per undirected edge.
edge e[MAXN << 1];

int n;   // number of tree vertices (read in main)
int q;   // number of queries left to answer (counted down in main)
int len; // number of distinct weights after coordinate compression
int idx; // id of the most recently allocated segment-tree node
int tot; // id of the most recently stored half-edge in e[]

// Segment-tree node pool, indexed by node id.
int lson[MAXN * 20]; // left child id
int rson[MAXN * 20]; // right child id
int sum[MAXN * 20];  // number of values stored in the node's range

// Per-vertex data, indexed by vertex id.
int depth[MAXN]; // depth, with depth[0] == 0 acting as the parent of the root
int head[MAXN];  // first half-edge of the adjacency list, 0 when empty
int root[MAXN];  // segment-tree version for the vertex; root[0] is the empty tree
int son[MAXN];   // child with the largest subtree, 0 for a leaf
int size[MAXN];  // subtree size
int fa[MAXN];    // parent vertex, 0 for the root
int top[MAXN];   // topmost vertex of the heavy chain containing the vertex

int a[MAXN];   // vertex weights; replaced by their compressed rank in main
int tmp[MAXN]; // sorted distinct weights, used to map a rank back to a weight

int up[MAXN][MAXLOG]; // not referenced by any code visible in this file

// Raises x to y when y is strictly greater; otherwise x keeps its value.
template <typename T> inline void chkmax(T &x,T y) { if (x < y) x = y; }
// Lowers x to y when y is strictly smaller; otherwise x keeps its value.
template <typename T> inline void chkmin(T &x,T y) { if (y < x) x = y; }
// Reads one decimal integer from stdin into x.
//
// Every non-digit character before the number is skipped; each '-' met while
// skipping flips the sign. Parsing stops at the first non-digit after the
// digits, or at end of input. If the input ends before any digit is seen,
// x is left as 0 instead of looping forever.
template <typename T> inline void read(T &x)
{
    T f = 1; x = 0;
    // Keep the character in an int: EOF must stay distinguishable from real
    // bytes, and isdigit() is only defined for unsigned-char values and EOF.
    int c = getchar();
    for (; c != EOF && !isdigit(c); c = getchar()) if (c == '-') f = -f;
    for (; c != EOF && isdigit(c); c = getchar()) x = x * 10 + (c - '0');
    x *= f;
}
// Prepends the directed half-edge u -> v to u's adjacency list.
// Slot ids start at 1, so a head[] value of 0 means "no edge".
inline void addedge(int u , int v)
{
        ++tot;
        // Plain member assignment: the C99 compound literal `(edge){...}` is
        // only a compiler extension in C++.
        e[tot].to = v;
        e[tot].nxt = head[u];
        head[u] = tot;
}
// Allocates a segment-tree skeleton over the value range [l, r].
// The id of the subtree's root is written to k; sum[] is left untouched.
inline void build(int &k , int l , int r)
{
    const int node = ++idx;
    k = node;
    if (l != r)
    {
        const int mid = (l + r) >> 1;
        build(lson[node] , l , mid);
        build(rson[node] , mid + 1 , r);
    }
}
// Creates a new tree version from version `old` with `value` added at `pos`.
//
// k     receives the id of the freshly allocated node for range [l, r]
// old   node of the previous version covering the same range
// pos   position (inside [l, r]) whose count changes
// value amount added to every node on the path down to pos
//
// Only nodes on the root-to-pos path are copied; the rest is shared with `old`.
inline void modify(int &k , int old , int l , int r , int pos , int value)
{
    const int cur = ++idx;
    k = cur;
    lson[cur] = lson[old];
    rson[cur] = rson[old];
    sum[cur] = sum[old] + value;
    if (l == r) return;

    const int mid = (l + r) >> 1;
    if (pos <= mid)
    {
        modify(lson[cur] , lson[old] , l , mid , pos , value);
    }
    else
    {
        modify(rson[cur] , rson[old] , mid + 1 , r , pos , value);
    }
}
// Finds the position of the k-th smallest value in the multiset described by
// version(rt1) + version(rt2) - version(rt3) - version(rt4), walking all four
// trees in lockstep over the range [l, r]. Returns that position.
inline int query(int rt1 , int rt2 , int rt3 , int rt4 , int l , int r , int k)
{
        while (l != r)
        {
                // Number of values that fall into the left half of the current range.
                const int leftCount = sum[lson[rt1]] + sum[lson[rt2]] - sum[lson[rt3]] - sum[lson[rt4]];
                const int mid = (l + r) >> 1;
                if (leftCount >= k)
                {
                        rt1 = lson[rt1];
                        rt2 = lson[rt2];
                        rt3 = lson[rt3];
                        rt4 = lson[rt4];
                        r = mid;
                }
                else
                {
                        rt1 = rson[rt1];
                        rt2 = rson[rt2];
                        rt3 = rson[rt3];
                        rt4 = rson[rt4];
                        l = mid + 1;
                        k -= leftCount; // skip everything counted on the left
                }
        }
        return l;
}
inline void dfs1(int u , int father)
{
        size[u] = 1;
        depth[u] = depth[father] + 1;
        modify(root[u] , root[father] , 1 , len , a[u] , 1);
        for (int i = head[u]; i; i = e[i].nxt)
        {
                int v = e[i].to;
                if (v == father) continue;
                dfs1(v , u);
                fa[v] = u;
                size[u] += size[v];
                if (son[u] == 0 || size[v] > size[son[u]]) son[u] = v; 
        }
}
// Second DFS pass: assigns top[] for the subtree of u, where tp is the top of
// the chain u belongs to. The heavy child continues u's chain; every other
// child starts a chain of its own.
inline void dfs2(int u , int tp)
{
        top[u] = tp;
        const int heavy = son[u];
        if (heavy != 0) dfs2(heavy , tp);
        for (int i = head[u]; i != 0; i = e[i].nxt)
        {
                const int v = e[i].to;
                if (v == heavy || v == fa[u]) continue;
                dfs2(v , v);
        }
}
// Returns the lowest common ancestor of x and y using the chain tops in top[].
inline int lca(int x , int y)
{
        while (top[x] != top[y])
        {
                // Lift whichever vertex sits on the chain with the deeper top.
                if (depth[top[x]] > depth[top[y]]) x = fa[top[x]];
                else y = fa[top[y]];
        }
        // Both are on one chain now; the shallower vertex is the ancestor.
        return depth[x] < depth[y] ? x : y;
}
inline int getans(int u , int v , int k)
{
        int t1 = lca(u , v) , t2 = fa[t1];
        return query(root[u] , root[v] , root[t1] , root[t2] , 1 , len , k);
}

// Reads the tree and the queries from stdin and prints one answer per query.
// Each query's u is XOR-ed with the previous answer before being used.
int main()
{
        read(n);
        read(q);

        // Vertex weights, with a copy kept for coordinate compression.
        for (int i = 1; i <= n; ++i)
        {
                read(a[i]);
                tmp[i] = a[i];
        }

        // n - 1 undirected edges, stored as two half-edges each.
        for (int i = 1; i < n; ++i)
        {
                int x , y;
                read(x);
                read(y);
                addedge(x , y);
                addedge(y , x);
        }

        // Compress weights to ranks 1..len; tmp[rank] maps a rank back to the weight.
        int *const first = tmp + 1;
        sort(first , first + n);
        len = unique(first , first + n) - first;
        for (int i = 1; i <= n; ++i)
        {
                a[i] = lower_bound(first , first + len , a[i]) - first + 1;
        }

        // Empty version for vertex 0, then one version per vertex and the chain tops.
        build(root[0] , 1 , len);
        dfs1(1 , 0);
        dfs2(1 , 1);

        int lastans = 0;
        while (q-- != 0)
        {
                int u , v , k;
                read(u);
                read(v);
                read(k);
                u ^= lastans;
                lastans = tmp[getans(u , v , k)];
                printf("%d\n" , lastans);
        }

        return 0;
}

猜你喜欢

转载自www.cnblogs.com/evenbao/p/10046757.html
今日推荐