题解:
树链剖分快速求解任意两点间的路径的权值和;
然后,二分答案;
此题的难点是如何快速求解重合路径?
差分数组可以否???
在此之前先介绍一下相关变量:
1 int fa[maxn]; 2 int siz[maxn];//siz[i]:i子树的节点个数 3 int dep[maxn];//dep[i]:节点i在树中的深度 4 int son[maxn];//son[i]:节点i的重儿子 5 int w[maxn];//w[i]:i节点与其父节点的权值 6 int tid[maxn];//tid[i]:节点i的新编号 7 int top[maxn];//top[i]:节点i所在重链的祖先 8 int s[maxn];//s[i]:新编号中,[1,i]的权值之和,s[i]=w[1]+w[2]+...+w[i];
如何用差分数组求解重合路径呢?
定义diff[ ]为差分数组。
1 int diff[maxn]; 2 void Update(int u,int v) 3 { 4 int x=top[u]; 5 int y=top[v]; 6 if(x == y)//如果 u,v 在同一条重链上 7 { 8 if(dep[u] > dep[v]) 9 swap(u,v); 10 //节点son[u]到节点v的新编号是连续的 11 diff[tid[son[u]]]++; 12 diff[tid[v]+1]--; 13 return ; 14 } 15 else//如果不在 16 { 17 if(dep[x] > dep[y]) 18 { 19 swap(x,y); 20 swap(u,v); 21 } 22 //节点 top[v]与节点v的新编号是连续的 23 diff[tid[y]]++; 24 diff[tid[v]+1]--; 25 Update(u,fa[y]); 26 } 27 }
AC代码:
1 #include<iostream> 2 #include<cstdio> 3 #include<cstring> 4 #include<algorithm> 5 using namespace std; 6 #define ls(x) (x<<1) 7 #define rs(x) (x<<1|1) 8 #define mem(a,b) memset(a,b,sizeof(a)) 9 const int maxn=3e5+50; 10 11 int n,m; 12 int fa[maxn]; 13 int siz[maxn];//siz[i]:i子树的节点个数 14 int dep[maxn]; 15 int son[maxn];//son[i]:节点i的重儿子 16 int w[maxn];//w[i]:i节点与其父节点的权值 17 int tid[maxn];//tid[i]:节点i的新编号 18 int top[maxn];//top[i]:节点i所在重链的祖先 19 int s[maxn];//s[i]:新编号中,[1,i]的权值之和,s[i]=w[1]+w[2]+...+w[i]; 20 int num; 21 int head[maxn]; 22 struct Edge 23 { 24 int to; 25 int w; 26 int next; 27 }G[2*maxn]; 28 void addEdge(int u,int v,int w) 29 { 30 G[num].to=v; 31 G[num].w=w; 32 G[num].next=head[u]; 33 head[u]=num++; 34 } 35 struct Que 36 { 37 int u,v; 38 int w; 39 }que[maxn]; 40 void DFS1(int u,int f,int d) 41 { 42 fa[u]=f; 43 dep[u]=d; 44 siz[u]=1; 45 for(int i=head[u];~i;i=G[i].next) 46 { 47 int v=G[i].to; 48 if(v == f) 49 continue; 50 51 w[v]=G[i].w; 52 DFS1(v,u,d+1); 53 54 siz[u] += siz[v]; 55 if(son[u] == -1 || siz[v] > siz[son[u]]) 56 son[u]=v; 57 } 58 } 59 void DFS2(int u,int a,int &k) 60 { 61 top[u]=a; 62 tid[u]=++k; 63 s[k]=s[k-1]+w[u]; 64 if(son[u] == -1) 65 return ; 66 DFS2(son[u],a,k); 67 for(int i=head[u];~i;i=G[i].next) 68 { 69 int v=G[i].to; 70 if(v != son[u] && v != fa[u]) 71 DFS2(v,v,k); 72 } 73 } 74 int Find(int u,int v)//求解节点u到节点v的路径权值和 75 { 76 int x=top[u]; 77 int y=top[v]; 78 int ans=0; 79 80 while(x != y) 81 { 82 if(dep[x] > dep[y]) 83 { 84 swap(u,v); 85 swap(x,y); 86 } 87 ans += s[tid[v]]-s[tid[y]-1]; 88 v=fa[y]; 89 y=top[v]; 90 } 91 if(u != v) 92 { 93 if(dep[u] > dep[v]) 94 swap(u,v); 95 ans += s[tid[v]]-s[tid[u]]; 96 } 97 98 return ans; 99 } 100 101 int diff[maxn]; 102 void Update(int u,int v) 103 { 104 int x=top[u]; 105 int y=top[v]; 106 if(x == y)//如果 u,v 在同一条重链上 107 { 108 if(dep[u] > dep[v]) 109 swap(u,v); 110 //节点son[u]到节点v的新编号是连续的 111 diff[tid[son[u]]]++; 112 diff[tid[v]+1]--; 113 return ; 114 } 115 else//如果不在 116 { 117 if(dep[x] > dep[y]) 118 { 119 swap(x,y); 120 swap(u,v); 121 } 122 //节点 top[v]与节点v的新编号是连续的 123 diff[tid[y]]++; 124 diff[tid[v]+1]--; 125 Update(u,fa[y]); 126 } 127 } 128 /** 129 cnt:一共有cnt个权值和 > mid 130 ans1:这cnt个权值和最大的比mid大多少 131 ans2:这cnt个路径中权值最大的公共路径 132 */ 133 bool Check(int mid) 134 { 135 int cnt=0; 136 int ans1=0; 137 for(int i=1;i <= m;++i) 138 { 139 int u=que[i].u; 140 int v=que[i].v; 141 if(que[i].w > mid) 142 { 143 cnt++; 144 ans1=max(ans1,que[i].w-mid); 145 Update(u,v); 146 } 147 } 148 int ans2=0; 149 int tot=0; 150 for(int i=1;i <= n;++i) 151 { 152 tot += diff[i]; 153 diff[i]=0; 154 if(tot == cnt) 155 ans2=max(ans2,s[i]-s[i-1]); 156 } 157 //只有ans2 >= ans1 才能够使最大的权值和小于等于mid 158 return ans2 >= ans1; 159 } 160 int Solve() 161 { 162 DFS1(1,1,1); 163 int k=0; 164 DFS2(1,1,k); 165 166 for(int i=1;i <= m;++i) 167 { 168 int u=que[i].u; 169 int v=que[i].v; 170 que[i].w=Find(u,v); 171 } 172 173 int l=-1,r=300000000+50; 174 while(r-l > 1)//二分答案 175 { 176 int mid=l+((r-l)>>1); 177 if(Check(mid)) 178 r=mid; 179 else 180 l=mid; 181 } 182 return r; 183 } 184 void Init() 185 { 186 num=0; 187 mem(head,-1); 188 mem(diff,0); 189 mem(son,-1); 190 mem(s,0); 191 } 192 int main() 193 { 194 // freopen("C:\\Users\\hyacinthLJP\\Desktop\\in&&out\\BZOJ\\4326.in","r",stdin); 195 while(~scanf("%d%d",&n,&m)) 196 { 197 Init(); 198 for(int i=1;i < n;++i) 199 { 200 int u,v,w; 201 scanf("%d%d%d",&u,&v,&w); 202 addEdge(u,v,w); 203 addEdge(v,u,w); 204 } 205 for(int i=1;i <= m;++i) 206 scanf("%d%d",&que[i].u,&que[i].v); 207 208 printf("%d\n",Solve()); 209 } 210 return 0; 211 }