题意:给定一张含有n个顶点和m条边的图。顶点有两种类型 0和1。第i条边的权值为2i 。求任意顶点0和任意顶点1之间的距离和。
思路:这题跟 这道2018CCPC网络赛 很类似。
由于边有2e5条。所以直接用dijkstra求最短距离是不可能的,复杂度直接炸了。 因为第i条边的权值是2i ,所以第i条的权值是大于前i-1条权值之和。也就是说。如果前i-1条边图已经联通了,那么我们就不需要剩下的边了。因为就算两点间需要经过i-1条边,那这个距离也是小于第i条边的。
所以我们可以先把图转成一颗树(最小生成树),然后再树上想办法。在一颗树上,两点之间的路径是唯一的。所以要求任意顶点0和任意顶点1之间的距离和,其实可以先求出每条边对最后结果的贡献度,最后求和。每条边的贡献度为 边长*经过的次数。
现在问题是如果求每条边被经过的次数?
例如边<a,b >。假设在树中a为b的父节点。那么<a,b >这条边的经过次数=以b为根节点的子树中0类别的节点数*子树以外类别1的节点数 +子树中1类别的节点数*子树以外类别0的节点数
#include<bits/stdc++.h>
#define mp(a,b) make_pair(a,b)
using namespace std;
typedef long long ll;
typedef pair<long long, int> pli;
typedef pair<int, int> pii;
typedef pair<int, double> pid;
const long long INF = 1e15;
const int maxn = 1e5 + 10;
const int mod = 1e9 + 7;
bool isPrime(long long x) {
if (x <= 1)return false;
if (x == 2)return true;
if (x % 2 == 0)return false;
long long m = sqrt(x);
for (long long i = 3; i <= m; i += 2) {
if (x % i == 0)return false;
}
return true;
}
long long gcd(long long m, long long n) {
return m % n == 0 ? n : gcd(n, m % n);
}
long long fib[90];
void initFib() {
fib[0] = fib[1] = 1;
int i;
for (i = 2; i < 90; i++) {
fib[i] = fib[i - 1] + fib[i - 2];
}
}
struct node {
long long w;
int to;
int zero, one;
node(int a, long long b, int c = 0) :to(a), w(b), zero(c), one(c) {
}
};
vector<node> G[maxn];
unordered_map<int, int> mmid;
int n, m;
int zero, one;
ll len[maxn];
ll ans;
int fa[maxn];
vector<int> v(n);
void init() {
int i;
len[0] = 1;
ll cnt = 1;
for (i = 1; i < maxn; i++) {
cnt = cnt * 2 % mod;
len[i] = cnt;
}
}
void clear() {
int i;
ans = 0;
zero = one = 0;
for (i = 0; i <= n; i++) {
fa[i] = i;
G[i].clear();
}
v.clear();
v.resize(n);
}
int find(int x) {
return fa[x] == x ? x : fa[x] = find(fa[x]);
}
pii dfs(int cur, int fa, ll w) {
if (G[cur].size() == 1 && cur != 1) {
if (v[cur - 1]) {
ans = (ans + (zero) * w%mod) % mod;
return pii(0, 1);
}
else {
ans = (ans + (one) * w % mod) % mod;
return pii(1, 0);
}
}
pii p;
if (v[cur - 1]) {
p = mp(0, 1);
}
else {
p = mp(1, 0);
}
for (auto& val : G[cur]) {
if (val.to == fa)continue;
auto tmp = dfs(val.to, cur, val.w);
p.first += tmp.first;
p.second += tmp.second;
}
ll cnt = 0;
cnt = p.second * (zero - p.first) + p.first * (one - p.second);
ans = (ans + cnt * w % mod) % mod;
return p;
}
int main() {
init();
ios::sync_with_stdio(false);
int t;
cin >> t;
while (t--) {
cin >> n >> m;
clear();
for (auto& val : v) {
cin >> val;
if (val)one++;
else zero++;
}
int i;
int a, b;
for (i = 0; i < m; i++) {
cin >> a >> b;
int r1 = find(a);
int r2 = find(b);
if (r1 != r2) {
fa[r1] = r2;
G[a].push_back({
b, len[i + 1] });
G[b].push_back({
a, len[i + 1] });
}
}
dfs(1, 0, 0);
cout << ans << endl;
}
return 0;
}