传送门
题目大意
给你一棵
个点的树,每个结点一个权值
。定义一条路径的权值:设路径的起点为
,终点为
(
可能等于
),在行进过程中经过了
个点,依次为
,则权值为:
如果权值在模 ( 为质数)意义下等于 ,那么它将显现出绿色;否则它将显出红色。
定义两条路径的颜色的叠加:如果两条路径都是绿色或者红色,那么它们叠加的颜色不变,否则它们显黄色。
问:有多少个 满足路径 和 叠加的颜色等于路径 的颜色。
思路
我们相当于是要找一个三元组 , , , 都同色。换句话说,那个颜色叠加并没有什么用处,因此只要有不同色就对答案没有贡献。
我们考虑用三元组总数减去存在不同色的三元组个数。不同色的三元组个数怎么算呢?考虑 , 不同色的情况,它一定被包含在了 这个三元组中。同样的还有: , 被包含在 和 中, , 被包含在 和 中。如果 , 不同色,那么 和 中有且仅有一个与 不同色。所以一个不同色的三元组一定对应两条有公共点的不同色的路径。我们求出不同色的路径个数的总数再除以 就是不同色的三元组个数了。(注意,上面也说到了,有些不同色的角对应两个三元组,还要乘以二)。
现在我们不用考虑三元组,只用考虑两条端点重合的路径。而这种路径显然可以拆成两部分来看,那就是一条从根结点出发(down)或者以根结点结尾(up)的链。这种角(两条端点重合的路径)显然有三种形态:down down,down up,up up。上面一段也说了,down up 对应一个三元组,而 down down 和 up up 对应两个三元组,所以还要乘以 。
现在考虑用点分治解决上面的问题。设当前分治中心为 ,先考虑如何计算 的答案。对于 down down 而言,如果我们知道了有多少条路径的权值同余 (设为 ),那么对答案的贡献就是 (前面说了不要忘记乘以 )。
对于 up up 而言,我们需要记录从下往上的路径的权值,可以通过递推实现:
对答案的贡献同样是 。
现在考虑 down up 的情况。设向下同余 有 个,向上同余 有 个,那么对答案的贡献为 。注意,这道题允许重合,因此就不用管去重了。
由于要统计所有点,因此考虑如何利用当前分治中心的数据去计算别的点,根据点分治的理论,别的点的路径必须经过分治中心。那么考虑这条路径的权值,显然为:
注意这个 不计算当前分治中心,否则会算重。
如果上式同余
,那么有:
我们把左右式分别保存在哈希表中。对于分治中心的一棵子树,我们先把它从哈希表中剔除,再计算这棵子树的答案,最后还原哈希表。设当前计算 的答案,计算方法是:以 为终点的答案加上查询右式哈希表有多少个等于代入 后的左式,以 为起点的答案加上查询左式哈希表有多少个等于代入 后的右式。
为了效率,我们只保存这两个哈希表,那么前面查询以分治中心为终点的答案的方法就要变一下。本来应该查询
,但是现在因为哈希表里存的是
(而不是算了分治中心的
),所以查询内容变了。代入本来应该满足条件的式子:
减去那一堆是因为我们没有算分治中心对权值的贡献,加上它后应该同余 。上式显然等于 ,所以查询 (而不是 )即可。当然这里如果你不想推也可以直接 统计。
参考代码
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <cstring>
#include <cassert>
#include <cctype>
#include <climits>
#include <ctime>
#include <iostream>
#include <algorithm>
#include <vector>
#include <string>
#include <stack>
#include <queue>
#include <deque>
#include <map>
#include <set>
#include <bitset>
#include <list>
#include <functional>
using LL = long long;
using ULL = unsigned long long;
using std::cin;
using std::cout;
using std::endl;
using INT_PUT = LL;
INT_PUT readIn()
{
INT_PUT a = 0; bool positive = true;
char ch = getchar();
while (!(ch == '-' || std::isdigit(ch))) ch = getchar();
if (ch == '-') { positive = false; ch = getchar(); }
while (std::isdigit(ch)) { a = a * 10 - (ch - '0'); ch = getchar(); }
return positive ? -a : a;
}
void printOut(INT_PUT x)
{
char buffer[20]; int length = 0;
if (x < 0) putchar('-'); else x = -x;
do buffer[length++] = -(x % 10) + '0'; while (x /= 10);
do putchar(buffer[--length]); while (length);
}
const int maxn = int(1e5) + 5;
int n, mod, k, invk, X;
int pwr[maxn];
int invpwr[maxn];
int v[maxn];
LL power(LL x, int y)
{
LL ret = 1;
while (y)
{
if (y & 1) ret = ret * x % mod;
x = x * x % mod;
y >>= 1;
}
return ret;
}
using Graph = std::vector<std::vector<int>>;
Graph G;
#define RunInstance(x) delete new x
struct work
{
struct unordered_map
{
struct Node
{
int key;
int val = 0;
int next = -1;
Node(int key) : key(key) {}
};
unordered_map()
{
nodes.reserve(maxn);
std::memset(head, -1, sizeof(head));
}
std::vector<Node> nodes;
static const int size = int(1e6) + 7;
int head[size];
int& operator[](int x)
{
int cnt = head[x % size];
if (~cnt)
{
if (nodes[cnt].key == x)
return nodes[cnt].val;
while (~nodes[cnt].next)
{
cnt = nodes[cnt].next;
if (nodes[cnt].key == x)
return nodes[cnt].val;
}
nodes[cnt].next = nodes.size();
nodes.push_back(x);
return nodes.back().val;
}
else
{
head[x % size] = nodes.size();
nodes.push_back(x);
return nodes.back().val;
}
}
int operator[](int x) const
{
int cnt = head[x % size];
while (~cnt)
{
if (nodes[cnt].key == x)
return nodes[cnt].val;
cnt = nodes[cnt].next;
}
return 0;
}
void clear()
{
for (const auto& t : nodes)
head[t.key % size] = -1;
nodes.clear();
}
};
bool vis[maxn]{};
int size[maxn];
void DFS1(int node, int parent)
{
size[node] = 1;
for (int to : G[node]) if (!vis[to] && to != parent)
{
DFS1(to, node);
size[node] += size[to];
}
}
int findRoot(int node, int parent, int s)
{
for (int to : G[node]) if (!vis[to] && to != parent)
{
if (size[to] >= (s >> 1))
return findRoot(to, node, s);
}
return node;
}
LL countDown[maxn]{};
LL countUp[maxn]{};
unordered_map mapDown;
unordered_map mapUp;
int depth[maxn];
int dis1[maxn];
int dis2[maxn];
int dfn[maxn];
int end[maxn];
int seq[maxn];
int stamp;
int f(int u)
{
return (LL)invpwr[depth[u]] * (X - dis2[u] + mod) % mod;
}
void DFS2(int node, int parent)
{
stamp++;
seq[stamp] = node;
dfn[node] = stamp;
dis1[node] = (dis1[parent] + (LL)v[node] * pwr[depth[node]]) % mod;
for (int to : G[node]) if (!vis[to] && to != parent)
{
depth[to] = depth[node] + 1;
dis2[to] = ((LL)k * dis2[node] + v[to]) % mod;
DFS2(to, node);
}
end[node] = stamp;
}
void solve(int node)
{
DFS1(node, 0);
node = findRoot(node, 0, size[node]);
vis[node] = true;
depth[node] = 0;
dis1[0] = 0;
dis2[node] = 0;
stamp = 0;
mapDown.clear();
mapUp.clear();
DFS2(node, 0);
for (int i = 1; i <= stamp; i++)
{
mapDown[dis1[seq[i]]]++;
mapUp[f(seq[i])]++;
}
countDown[node] += mapDown[X];
countUp[node] += mapUp[v[node]];
for (int to : G[node]) if (!vis[to])
{
for (int i = dfn[to]; i <= end[to]; i++)
{
mapDown[dis1[seq[i]]]--;
mapUp[f(seq[i])]--;
}
for (int i = dfn[to]; i <= end[to]; i++)
{
countDown[seq[i]] += mapDown[f(seq[i])];
countUp[seq[i]] += mapUp[dis1[seq[i]]];
}
for (int i = dfn[to]; i <= end[to]; i++)
{
mapDown[dis1[seq[i]]]++;
mapUp[f(seq[i])]++;
}
}
for (int to : G[node]) if (!vis[to])
{
solve(to);
}
}
work()
{
solve(1);
LL ans = 0;
for (int i = 1; i <= n; i++)
{
ans += LL(2) * countDown[i] * (n - countDown[i]);
ans += LL(2) * countUp[i] * (n - countUp[i]);
ans += (LL)countDown[i] * (n - countUp[i]);
ans += (LL)countUp[i] * (n - countDown[i]);
}
ans >>= 1;
ans = (LL)n * n * n - ans;
printOut(ans);
}
};
void run()
{
n = readIn();
mod = readIn();
k = readIn();
X = readIn();
pwr[0] = 1;
for (int i = 1; i <= n; i++)
pwr[i] = (LL)pwr[i - 1] * k % mod;
invk = power(k, mod - 2);
invpwr[0] = 1;
for (int i = 1; i <= n; i++)
invpwr[i] = (LL)invpwr[i - 1] * invk % mod;
for (int i = 1; i <= n; i++)
v[i] = readIn();
G.resize(n + 1);
for (int i = 1; i < n; i++)
{
int from = readIn();
int to = readIn();
G[from].push_back(to);
G[to].push_back(from);
}
RunInstance(work);
}
int main()
{
#ifndef LOCAL
freopen("el.in", "r", stdin);
freopen("el.out", "w", stdout);
#endif
run();
return 0;
}
总结
用两种结点给三角形的顶点染色,那么如果该三角形的三个顶点不同色,就有且仅有两对点是异色的。
可以把路径拆开搞。