【题目链接】
http://acm.hdu.edu.cn/showproblem.php?pid=6162
【算法】
离线树剖
我们知道,u到v路径上权值为[A,B]的数的和 = u到v路径上权值小于等于B的数的和 - u到v路径上权值小于等于(A-1)的数的和
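不妨将每个询问拆成两个前缀询问(权值 ≤ B 与权值 ≤ A-1),把所有前缀询问按阈值从小到大排序,把所有点按权值从小到大排序;处理每个前缀询问前,先把权值不超过其阈值的点依次插入线段树(位置为该点的 dfs 序),再用树链剖分求出 u 到 v 的路径和,这样离线计算答案即可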
【代码】
// Offline solution: heavy-light decomposition + point-update / range-sum
// segment tree.
//
// For every query (u, v, l, r) the answer is
//     sum of node values <= r on path(u, v)  -  sum of node values <= l-1 on path(u, v).
// Each query is split into those two "prefix" queries. Prefix queries are
// sorted by threshold and nodes by value; nodes are inserted into the segment
// tree (at their DFS index) as the threshold grows, so that each prefix query
// is answered by a plain path-sum.
#include <algorithm>
#include <cstdio>
#include <cstring>

const int MAXN = 100000 + 10;

int n, m;        // number of nodes / number of queries in the current test case
int timer;       // DFS-order counter used by dfs2
int tot;         // number of directed edges stored in e[]
int cnt;         // number of prefix queries stored in q[]

// Adjacency heads and heavy-light decomposition arrays, all 1-indexed by node.
// Index 0 is the "no node" sentinel: sz[0] stays 0 and son[u] == 0 means leaf.
int head[MAXN], sz[MAXN], fa[MAXN], top[MAXN], dep[MAXN], son[MAXN], dfn[MAXN];

long long ans[MAXN];   // ans[id] accumulates the answer of original query id

// Forward-star adjacency list entry.
struct Edge {
    int to, nxt;
} e[MAXN << 1];

// A node value together with the node it belongs to (kept while sorting).
// val is long long so that it matches the "%lld" used to read it.
struct info {
    long long val;
    int pos;
} a[MAXN];

// One prefix query: path (u, v), threshold m, flag 1 = subtract from the
// answer, flag 2 = add to the answer, id = index of the original query.
struct Query {
    int u, v;
    long long m;
    int flag, id;
} q[MAXN << 1];

// Segment tree over DFS indices supporting point add and range sum.
struct SegmentTree {
    struct Node {
        int l, r;
        long long sum;
    } Tree[MAXN << 2];

    // Initialise the subtree rooted at `index` to cover [l, r] with all sums 0.
    void build(int index, int l, int r) {
        Tree[index].l = l;
        Tree[index].r = r;
        Tree[index].sum = 0;
        if (l == r) return;
        const int mid = (l + r) >> 1;
        build(index << 1, l, mid);
        build(index << 1 | 1, mid + 1, r);
    }

    // Recompute a node's sum from its two children.
    void update(int index) {
        Tree[index].sum = Tree[index << 1].sum + Tree[index << 1 | 1].sum;
    }

    // Add `val` to position `pos`.
    void add(int index, int pos, long long val) {
        if (Tree[index].l == Tree[index].r) {
            Tree[index].sum += val;
            return;
        }
        const int mid = (Tree[index].l + Tree[index].r) >> 1;
        if (pos <= mid) add(index << 1, pos, val);
        else add(index << 1 | 1, pos, val);
        update(index);
    }

    // Sum over [l, r]; the caller guarantees [l, r] lies inside this node's range.
    long long query(int index, int l, int r) const {
        if (Tree[index].l == l && Tree[index].r == r) return Tree[index].sum;
        const int mid = (Tree[index].l + Tree[index].r) >> 1;
        if (r <= mid) return query(index << 1, l, r);
        if (l > mid) return query(index << 1 | 1, l, r);
        return query(index << 1, l, mid) + query(index << 1 | 1, mid + 1, r);
    }
} T;

// Order nodes by ascending value.
bool cmp1(const info &x, const info &y) { return x.val < y.val; }

// Order prefix queries by ascending threshold.
bool cmp2(const Query &x, const Query &y) { return x.m < y.m; }

// Append the directed edge u -> v to the adjacency list.
void add(int u, int v) {
    tot++;
    e[tot].to = v;
    e[tot].nxt = head[u];
    head[u] = tot;
}

// Append one prefix query to q[].
void push_query(int u, int v, long long bound, int flag, int id) {
    Query &cur = q[++cnt];
    cur.u = u;
    cur.v = v;
    cur.m = bound;
    cur.flag = flag;
    cur.id = id;
}

// First pass: parent, depth, subtree size and heavy child of every node.
// NOTE: recursion depth equals the tree height (up to n on a chain).
void dfs1(int u) {
    sz[u] = 1;
    for (int k = head[u]; k; k = e[k].nxt) {
        const int v = e[k].to;
        if (v == fa[u]) continue;
        fa[v] = u;
        dep[v] = dep[u] + 1;
        dfs1(v);
        sz[u] += sz[v];
        if (sz[v] > sz[son[u]]) son[u] = v;
    }
}

// Second pass: assign DFS indices so that every heavy chain is contiguous,
// and record the top node of the chain each node belongs to.
void dfs2(int u, int tp) {
    top[u] = tp;
    dfn[u] = ++timer;
    if (son[u]) dfs2(son[u], tp);   // heavy child first keeps the chain contiguous
    for (int k = head[u]; k; k = e[k].nxt) {
        const int v = e[k].to;
        if (v != fa[u] && v != son[u]) dfs2(v, v);   // light child starts a new chain
    }
}

// Sum of the values currently stored in the segment tree along path u..v.
long long query(int u, int v) {
    long long ret = 0;
    while (top[u] != top[v]) {
        // Always lift the endpoint whose chain top is deeper.
        if (dep[top[u]] > dep[top[v]]) std::swap(u, v);
        ret += T.query(1, dfn[top[v]], dfn[v]);
        v = fa[top[v]];
    }
    // Both endpoints are on the same chain now: one contiguous DFS range.
    if (dfn[u] > dfn[v]) std::swap(u, v);
    ret += T.query(1, dfn[u], dfn[v]);
    return ret;
}

int main() {
    // Stop on EOF *and* on malformed input (fewer than two items read).
    while (scanf("%d%d", &n, &m) == 2) {
        timer = 0;
        tot = 0;
        cnt = 0;
        for (int i = 1; i <= n; i++) {
            head[i] = 0;
            son[i] = 0;
        }
        memset(ans, 0, sizeof(ans));

        for (int i = 1; i <= n; i++) {
            scanf("%lld", &a[i].val);
            a[i].pos = i;
        }
        for (int i = 1; i < n; i++) {
            int u, v;
            scanf("%d%d", &u, &v);
            add(u, v);
            add(v, u);
        }

        // Root the tree at node 1 (fa[1] and dep[1] are never written, so they stay 0).
        dfs1(1);
        dfs2(1, 1);
        T.build(1, 1, timer);

        for (int i = 1; i <= m; i++) {
            int u, v;
            long long l, r;
            scanf("%d%d%lld%lld", &u, &v, &l, &r);
            // The "<= l-1" part is only pushed when l > 1.
            if (l > 1) push_query(u, v, l - 1, 1, i);
            push_query(u, v, r, 2, i);
        }

        std::sort(a + 1, a + n + 1, cmp1);
        std::sort(q + 1, q + cnt + 1, cmp2);

        int now = 1;   // next not-yet-inserted node in value order
        for (int i = 1; i <= cnt; i++) {
            while (now <= n && a[now].val <= q[i].m) {
                T.add(1, dfn[a[now].pos], a[now].val);
                now++;
            }
            const long long part = query(q[i].u, q[i].v);
            if (q[i].flag == 1) ans[q[i].id] -= part;
            else ans[q[i].id] += part;
        }

        for (int i = 1; i < m; i++) printf("%lld ", ans[i]);
        printf("%lld\n", ans[m]);
    }
    return 0;
}