Tsinsen D455 El

传送门
题目大意

给你一棵 n ( 1 n 10 5 ) 个点的树,每个结点一个权值 v i 。定义一条路径的权值:设路径的起点为 s ,终点为 t t 可能等于 s ),在行进过程中经过了 L 个点,依次为 t 1 , t 2 , , t L ,则权值为:

t 1 k 0 + t 2 k 1 + + t L k L 1

如果权值在模 P P 为质数)意义下等于 X ,那么它将显现出绿色;否则它将显出红色。

定义两条路径的颜色的叠加:如果两条路径都是绿色或者红色,那么它们叠加的颜色不变,否则它们显黄色。

问:有多少个 ( s , m i d , t ) 满足路径 ( s , m i d ) ( m i d , t ) 叠加的颜色等于路径 ( s , t ) 的颜色。

思路

我们相当于是要找一个三元组 ( a , b , c ) a b b c a c 都同色。换句话说,那个颜色叠加并没有什么用处,因此只要有不同色就对答案没有贡献。

我们考虑用三元组总数减去存在不同色的三元组个数。不同色的三元组个数怎么算呢?考虑 a b b c 不同色的情况,它一定被包含在了 ( a , b , c ) 这个三元组中。同样的还有: a b c b 被包含在 ( a , b , c ) ( c , a , b ) 中, b a b c 被包含在 ( b , a , c ) ( b , c , a ) 中。如果 a b b c 不同色,那么 a b b c 中有且仅有一个与 a c 不同色。所以一个不同色的三元组一定对应两条有公共点的不同色的路径。我们求出不同色的路径个数的总数再除以 2 就是不同色的三元组个数了。(注意,上面也说到了,有些不同色的角对应两个三元组,还要乘以二)。

现在我们不用考虑三元组,只用考虑两条端点重合的路径。而这种路径显然可以拆成两部分来看,那就是一条从根结点出发(down)或者以根结点结尾(up)的链。这种角(两条端点重合的路径)显然有三种形态:down down,down up,up up。上面一段也说了,down up 对应一个三元组,而 down down 和 up up 对应两个三元组,所以还要乘以 2

现在考虑用点分治解决上面的问题。设当前分治中心为 r o o t ,先考虑如何计算 r o o t 的答案。对于 down down 而言,如果我们知道了有多少条路径的权值同余 X (设为 t 1 ),那么对答案的贡献就是 2 t 1 ( n t 1 ) (前面说了不要忘记乘以 2 )。

对于 up up 而言,我们需要记录从下往上的路径的权值,可以通过递推实现:

f ( t o ) = k × f ( p a r e n t ) + v t o

对答案的贡献同样是 2 t 2 ( n t 2 )

现在考虑 down up 的情况。设向下同余 X t 3 个,向上同余 X t 4 个,那么对答案的贡献为 t 3 ( n t 4 ) + t 4 ( n t 3 ) 。注意,这道题允许重合,因此就不用管去重了。

由于要统计所有点,因此考虑如何利用当前分治中心的数据去计算别的点,根据点分治的理论,别的点的路径必须经过分治中心。那么考虑这条路径的权值,显然为:

f u p ( f r o m ) + k d e p t h f r o m × f d o w n ( t o )

注意这个 f d o w n 不计算当前分治中心,否则会算重。

如果上式同余 X ,那么有:

f d o w n ( t o ) X f u p ( f r o m ) k d e p t h f r o m

我们把左右式分别保存在哈希表中。对于分治中心的一棵子树,我们先把它从哈希表中剔除,再计算这棵子树的答案,最后还原哈希表。设当前计算 u 的答案,计算方法是:以 u 为终点的答案加上查询右式哈希表有多少个等于代入 u 后的左式,以 u 为起点的答案加上查询左式哈希表有多少个等于代入 u 后的右式。

为了效率,我们只保存这两个哈希表,那么前面查询以分治中心为终点的答案的方法就要变一下。本来应该查询 X ,但是现在因为哈希表里存的是 X f u p ( f r o m ) k d e p t h f r o m (而不是算了分治中心的 f u p ),所以查询内容变了。代入本来应该满足条件的式子:

X ( X k d e p t h f r o m × v u p ) k d e p t h f r o m

减去那一堆是因为我们没有算分治中心对权值的贡献,加上它后应该同余 X 。上式显然等于 v u p ,所以查询 v u p (而不是 X )即可。当然这里如果你不想推也可以直接 O ( n ) 统计。

参考代码
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <cstring>
#include <cassert>
#include <cctype>
#include <climits>
#include <ctime>
#include <iostream>
#include <algorithm>
#include <vector>
#include <string>
#include <stack>
#include <queue>
#include <deque>
#include <map>
#include <set>
#include <bitset>
#include <list>
#include <functional>
using LL = long long;
using ULL = unsigned long long;
using std::cin;
using std::cout;
using std::endl;
using INT_PUT = LL;
INT_PUT readIn()
{
    INT_PUT a = 0; bool positive = true;
    char ch = getchar();
    while (!(ch == '-' || std::isdigit(ch))) ch = getchar();
    if (ch == '-') { positive = false; ch = getchar(); }
    while (std::isdigit(ch)) { a = a * 10 - (ch - '0'); ch = getchar(); }
    return positive ? -a : a;
}
void printOut(INT_PUT x)
{
    char buffer[20]; int length = 0;
    if (x < 0) putchar('-'); else x = -x;
    do buffer[length++] = -(x % 10) + '0'; while (x /= 10);
    do putchar(buffer[--length]); while (length);
}

const int maxn = int(1e5) + 5;
int n, mod, k, invk, X;
int pwr[maxn];
int invpwr[maxn];
int v[maxn];
LL power(LL x, int y)
{
    LL ret = 1;
    while (y)
    {
        if (y & 1) ret = ret * x % mod;
        x = x * x % mod;
        y >>= 1;
    }
    return ret;
}

using Graph = std::vector<std::vector<int>>;
Graph G;

#define RunInstance(x) delete new x
struct work
{
    struct unordered_map
    {
        struct Node
        {
            int key;
            int val = 0;
            int next = -1;
            Node(int key) : key(key) {}
        };
        unordered_map()
        {
            nodes.reserve(maxn);
            std::memset(head, -1, sizeof(head));
        }
        std::vector<Node> nodes;
        static const int size = int(1e6) + 7;
        int head[size];
        int& operator[](int x)
        {
            int cnt = head[x % size];
            if (~cnt)
            {
                if (nodes[cnt].key == x)
                    return nodes[cnt].val;
                while (~nodes[cnt].next)
                {
                    cnt = nodes[cnt].next;
                    if (nodes[cnt].key == x)
                        return nodes[cnt].val;
                }
                nodes[cnt].next = nodes.size();
                nodes.push_back(x);
                return nodes.back().val;
            }
            else
            {
                head[x % size] = nodes.size();
                nodes.push_back(x);
                return nodes.back().val;
            }
        }
        int operator[](int x) const
        {
            int cnt = head[x % size];
            while (~cnt)
            {
                if (nodes[cnt].key == x)
                    return nodes[cnt].val;
                cnt = nodes[cnt].next;
            }
            return 0;
        }
        void clear()
        {
            for (const auto& t : nodes)
                head[t.key % size] = -1;
            nodes.clear();
        }
    };
    bool vis[maxn]{};
    int size[maxn];
    void DFS1(int node, int parent)
    {
        size[node] = 1;
        for (int to : G[node]) if (!vis[to] && to != parent)
        {
            DFS1(to, node);
            size[node] += size[to];
        }
    }
    int findRoot(int node, int parent, int s)
    {
        for (int to : G[node]) if (!vis[to] && to != parent)
        {
            if (size[to] >= (s >> 1))
                return findRoot(to, node, s);
        }
        return node;
    }
    LL countDown[maxn]{};
    LL countUp[maxn]{};

    unordered_map mapDown;
    unordered_map mapUp;

    int depth[maxn];
    int dis1[maxn];
    int dis2[maxn];
    int dfn[maxn];
    int end[maxn];
    int seq[maxn];
    int stamp;
    int f(int u)
    {
        return (LL)invpwr[depth[u]] * (X - dis2[u] + mod) % mod;
    }
    void DFS2(int node, int parent)
    {
        stamp++;
        seq[stamp] = node;
        dfn[node] = stamp;
        dis1[node] = (dis1[parent] + (LL)v[node] * pwr[depth[node]]) % mod;
        for (int to : G[node]) if (!vis[to] && to != parent)
        {
            depth[to] = depth[node] + 1;
            dis2[to] = ((LL)k * dis2[node] + v[to]) % mod;
            DFS2(to, node);
        }
        end[node] = stamp;
    }
    void solve(int node)
    {
        DFS1(node, 0);
        node = findRoot(node, 0, size[node]);
        vis[node] = true;

        depth[node] = 0;
        dis1[0] = 0;
        dis2[node] = 0;
        stamp = 0;
        mapDown.clear();
        mapUp.clear();
        DFS2(node, 0);

        for (int i = 1; i <= stamp; i++)
        {
            mapDown[dis1[seq[i]]]++;
            mapUp[f(seq[i])]++;
        }
        countDown[node] += mapDown[X];
        countUp[node] += mapUp[v[node]];
        for (int to : G[node]) if (!vis[to])
        {
            for (int i = dfn[to]; i <= end[to]; i++)
            {
                mapDown[dis1[seq[i]]]--;
                mapUp[f(seq[i])]--;
            }
            for (int i = dfn[to]; i <= end[to]; i++)
            {
                countDown[seq[i]] += mapDown[f(seq[i])];
                countUp[seq[i]] += mapUp[dis1[seq[i]]];
            }
            for (int i = dfn[to]; i <= end[to]; i++)
            {
                mapDown[dis1[seq[i]]]++;
                mapUp[f(seq[i])]++;
            }
        }

        for (int to : G[node]) if (!vis[to])
        {
            solve(to);
        }
    }


    work()
    {
        solve(1);
        LL ans = 0;
        for (int i = 1; i <= n; i++)
        {
            ans += LL(2) * countDown[i] * (n - countDown[i]);
            ans += LL(2) * countUp[i] * (n - countUp[i]);
            ans += (LL)countDown[i] * (n - countUp[i]);
            ans += (LL)countUp[i] * (n - countDown[i]);
        }
        ans >>= 1;
        ans = (LL)n * n * n - ans;
        printOut(ans);
    }
};

void run()
{
    n = readIn();
    mod = readIn();
    k = readIn();
    X = readIn();
    pwr[0] = 1;
    for (int i = 1; i <= n; i++)
        pwr[i] = (LL)pwr[i - 1] * k % mod;
    invk = power(k, mod - 2);
    invpwr[0] = 1;
    for (int i = 1; i <= n; i++)
        invpwr[i] = (LL)invpwr[i - 1] * invk % mod;

    for (int i = 1; i <= n; i++)
        v[i] = readIn();
    G.resize(n + 1);
    for (int i = 1; i < n; i++)
    {
        int from = readIn();
        int to = readIn();
        G[from].push_back(to);
        G[to].push_back(from);
    }

    RunInstance(work);
}

int main()
{
#ifndef LOCAL
    freopen("el.in", "r", stdin);
    freopen("el.out", "w", stdout);
#endif
    run();
    return 0;
}
总结

用两种结点给三角形的顶点染色,那么如果该三角形的三个顶点不同色,就有且仅有两对点是异色的。

可以把路径拆开搞。

猜你喜欢

转载自blog.csdn.net/lycheng1215/article/details/80780643