正解 边带权并查集
具体思路,跟一般的边带权并查集一样,一边加入并查集,维护每个点到其根的dis,还没加入的边默认正确,在加入的边中进行判断是否合法。
可以知道,若一次询问两个士兵 (设为l,r) 在一个集合中,为了满足前面的要求,又要满足这次的要求,必须 dis[l] + d == dis[r] ,dis[l]代表l点到根的距离,dis[r]代表r到根的距离,若不满足则一定不合法。
如果两个士兵不在同一个并查集中,为了满足dis合法,需满足上面的要求,于是可以用下图理解:
于是可以知道dis[fy] = dis[x] + d - dis[y] 合法。
这样统计完了后,直接枚举每个点,分别取此点并查集中的最大值和最小值,枚举每个集合,答案就是最大值与最小值相差最大的那个值。
代码:
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define pt putchar
#define gc getchar
#define ko pt(' ')
#define ex pt('\n')
const int MAXN = 1e5 + 5;
const int INF = 999999999;
int n,m,fa[MAXN];
int val[MAXN],Max[MAXN],Min[MAXN];
void in(int &x)
{
int num = 0,f = 1; char ch = gc();
while(ch < '0' || ch > '9') {if(ch == '-') f = -1; ch = gc();}
while(ch >= '0' && ch <= '9') {num = (num<<3) + (num<<1) + (ch-'0'); ch = gc();}
x = num*f;
}
void out(int x)
{
if(x < 0) x = -x,pt('-');
if(x > 9) out(x/10);
pt(x % 10 + '0');
}
int find(int x)
{
if(fa[x] == x) return x;
else
{
int tmp = find(fa[x]);
val[x] += val[fa[x]];
fa[x] = tmp;
return tmp;
}
}
int main()
{
in(n); in(m);
// if(m == 0) {cout << 0; return 0;}
for(int i = 1;i <= n;i++) fa[i] = i;
for(int i = 1;i <= m;i++)
{
int x,y,d;
in(x),in(y),in(d);
int fx = find(x),disx = val[x];
int fy = find(y),disy = val[y];
if(fx == fy)
{
if(d + disx != disy){
printf("impossible");
return 0;
}
}
else {
fa[fy] = fx;
val[fy] = disx + d - disy;
}
}
int ans = 0;
for(int i = 1;i <= n;i++) Max[i] = -INF,Min[i] = INF;
for(int i = 1;i <= n;i++)
{
int root = find(i);
Max[root] = max(Max[root],val[i]);
Min[root] = min(Min[root],val[i]);
}
for(int i = 1;i <= n;i++)
if(Min[i] != INF) ans = max(ans,Max[i] - Min[i]);
out(ans);
return 0;
}
/*
3 3
1 2 1
2 3 1
1 3 2
*/
最短路+几乎不能算DP的DP
先跑一边最短路,求出st到各个节点的dis,然后用两个数组:
f[i]表示从st到i点的最短路条数,g[i]表示从i点到st与ed的最短路上的最短路条数。
明显不考虑相不相遇,所有走的方案数等于g[ed]2。应该很好理解(从st走有g[ed]种选择,从ed走有g[ed]种选择,根据乘法原理)
考虑不合法情况有哪些:
边相遇
当有两个(u为当前点,v为与之相连点) 在最短路上的点,v到st的距离乘以2大于st到ed的距离, u到st的距离乘以2小于到st到ed的距离,那么他们一定可以在这条边上相遇(可以画个图理解一下),那么方案数就应该减去(f[u]*g[v])2。(同上乘法原理)
点相遇
两人到一点上相遇的情况就简单了,只要st到当前点i == ed到当前点i距离相等即可,即st到点i距离的两倍等于st到ed距离即可。同样减去(f[i]*g[i])2。
减去后即为答案。
上代码:
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define pt putchar
#define gc getchar
#define ko pt(' ')
#define ex pt('\n')
const int MAXN = 2e5 + 5;
const int MOD = 1e9 + 7;
int n,m,st,ed;
ll ans = 0,INF;
struct edge
{
int next,to; ll w;
}e[MAXN<<1];
int head[MAXN<<1],cnt = 0,idx[MAXN];
ll f[MAXN],g[MAXN],dis[MAXN];
bool vis[MAXN];
void add(int u,int v,ll val)
{
e[++cnt].next = head[u]; e[cnt].to = v; e[cnt].w = val; head[u] = cnt;
e[++cnt].next = head[v]; e[cnt].to = u; e[cnt].w = val; head[v] = cnt;
}
void in(int &x)
{
int num = 0,f = 1; char ch = gc();
while(ch < '0' || ch > '9') {if(ch == '-') f = -1; ch = gc();}
while(ch >= '0' && ch <= '9') {num = (num<<3) + (num<<1) + (ch-'0'); ch = gc();}
x = num*f;
}
void lin(ll &x)
{
ll num = 0,f = 1; char ch = gc();
while(ch < '0' || ch > '9') {if(ch == '-') f = -1; ch = gc();}
while(ch >= '0' && ch <= '9') {num = (num<<3) + (num<<1) + (ch-'0'); ch = gc();}
x = num*f;
}
void out(ll x)
{
if(x < 0) x = -x,pt('-');
if(x > 9) out(x/10);
pt(x % 10 + '0');
}
priority_queue<pair<ll,int> > q;
void dijkstra()
{
memset(dis,0x3f,sizeof dis);
memset(vis,0,sizeof vis);
INF = dis[0]; dis[st] = 0;
q.push(make_pair(0,st));
while(!q.empty())
{
int x = q.top().second; q.pop();
if(vis[x]) continue; vis[x] = 1;
for(int i = head[x];i;i = e[i].next)
{
int to = e[i].to,w = e[i].w;
if(dis[to] > dis[x] + w){
dis[to] = dis[x] + w;
q.push(make_pair(-dis[to],to));
}
}
}
}
bool cmp(int a,int b){
return dis[a] < dis[b];
}
int main()
{
in(n); in(m);
in(st); in(ed);
for(int i = 1;i <= m;i++)
{
int u,v; ll val;
in(u),in(v),lin(val);
add(u,v,val);
}
dijkstra();
for(int i = 1;i <= n;i++) idx[i] = i;
sort(idx+1,idx+n+1,cmp);
f[st] = 1,g[ed] = 1;
for(int i = 1;i <= n;i++)
{
int x = idx[i];
for(int j = head[x];j;j = e[j].next)
{
int to = e[j].to; ll w = e[j].w;
if(dis[to] == dis[x] + w)
f[to] = (f[x] + f[to]) % MOD;
}
}
for(int i = n;i >= 1;i--)
{
int x = idx[i];
for(int j = head[x];j;j = e[j].next)
{
int to = e[j].to; ll w = e[j].w;
if(dis[to] == dis[x] - w)
g[to] = (g[x] + g[to]) % MOD;
}
}
ans = f[ed]*f[ed] % MOD;
for(int i = 1;i <= n;i++)
for(int j = head[i];j;j = e[j].next)
{
int to = e[j].to; ll w = e[j].w;
if(dis[to] == dis[i] + w)
if((dis[to] << 1) > dis[ed] && (dis[i] << 1) < dis[ed])
ans = ((ans - f[i] % MOD * g[to] % MOD * f[i] % MOD * g[to] % MOD) + MOD) % MOD;
}
for(int i = 1;i <= n;i++)
if((dis[i] << 1) == dis[ed])
ans = ((ans - f[i] % MOD * g[i] % MOD * f[i] % MOD * g[i] % MOD) % MOD + MOD) % MOD;
out(ans % MOD);
return 0;
}
/*
4 4
1 3
1 2 1
2 3 1
3 4 1
4 1 1
*/
欧拉序 + 树状数组维护链相交
首先欧拉序
概念:
指从根结点出发,按dfs的顺序在绕回原点所经过所有点的顺序。
这里用欧拉序处理LCA,方便进行树状数组维护。
首先要说这道题的一个关键定理:
一条链与另一条链相交,充要条件是:一条链上的深度最浅的点(就是LCA)在另一条链上
于是可以想到这道题树状数组的运用就是在欧拉序上维护这条链上有多少个其他链的LCA。
这样将每一条链上的其他链的LCA数量加起来再减去会重复的每个LCA数量(设为n)的平方与
C(2,n)即为答案。
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define pt putchar
#define gc getchar
#define ko pt(' ')
#define ex pt('\n')
const int MAXN = 2e6 + 5;
int n,m;
ll ans = 0;
struct edge
{
int next,to;
}e[MAXN<<1];
struct question
{
int u,v,up;
}node[MAXN<<1];
int head[MAXN<<1],cnt = 0;
bool vis[MAXN<<1];
int dep[MAXN<<1],fa[MAXN],pa[MAXN],tot[MAXN];
int enter[MAXN],getout[MAXN],sum[MAXN];
vector<int> link[MAXN],idx[MAXN];
void add(int u,int v)
{
e[++cnt].next = head[u]; e[cnt].to = v; head[u] = cnt;
e[++cnt].next = head[v]; e[cnt].to = u; head[v] = cnt;
}
void in(int &x)
{
int num = 0,f = 1; char ch = gc();
while(ch < '0' || ch > '9') {if(ch == '-') f = -1; ch = gc();}
while(ch >= '0' && ch <= '9') {num = (num<<3) + (num<<1) + (ch-'0'); ch = gc();}
x = num*f;
}
void out(ll x)
{
if(x < 0) x = -x,pt('-');
if(x > 9) out(x/10);
pt(x % 10 + '0');
}
int lowbit(int x) {return x & (-x);}
void updata(int x,int val)
{
while(x <= n<<1)
{
sum[x] += val;
x += lowbit(x);
}
}
int ask(int x)
{
int s = 0;
while(x)
{
s += sum[x];
x -= lowbit(x);
}
return s;
}
int find(int x) {return x == fa[x] ? x : fa[x] = find(fa[x]);}
void dfs(int x,int fr)
{
enter[x] = ++cnt; dep[x] = dep[fr] + 1;
pa[x] = fr,vis[x] = 1;
for(int i = 0;i < link[x].size();i++)
if(vis[link[x][i]]) node[idx[x][i]].up = find(link[x][i]);
for(int i = head[x];i;i = e[i].next)
{
int to = e[i].to;
if(to == fr) continue;
dfs(to,x);
}
getout[x] = ++cnt;
fa[x] = fr;
}
ll C(int n)
{
if(n < 2) return 0;
return (1ll*n*(n-1)) >> 1;
}
int main()
{
in(n); in(m);
for(int i = 1;i < n;i++)
{
int x,y; in(x),in(y);
add(x,y);
}
for(int i = 1;i <= m;i++)
{
in(node[i].u),in(node[i].v);
link[node[i].u].push_back(node[i].v),idx[node[i].u].push_back(i);
link[node[i].v].push_back(node[i].u),idx[node[i].v].push_back(i);
}
for(int i = 1;i <= n;i++) fa[i] = i;
cnt = 0; dfs(1,0);
for(int i = 1;i <= m;i++)
{
updata(enter[node[i].up],1);
updata(getout[node[i].up],-1);
tot[node[i].up]++;
}
for(int i = 1;i <= m;i++)
{
ans = ans + 1ll * ask(enter[node[i].u]) + 1ll * ask(enter[node[i].v]);
ans = ans - 1ll * ask(enter[node[i].up]) - 1ll * ask(enter[pa[node[i].up]]);
}
for(int i = 1;i <= n;i++)
ans = ans - 1ll * tot[i] * tot[i] + C(tot[i]);
out(ans);
return 0;
}
/*
6 5
1 2
2 3
2 4
1 5
4 6
1 4
5 6
3 4
1 6
2 3
*/