模板 - 类欧几里得算法

用来快速求解 $\sum\limits_{i=0}^{n}\lfloor \frac{ai+b}{c} \rfloor,\sum\limits_{i=0}^{n}{\lfloor \frac{ai+b}{c} \rfloor}^2,\sum\limits_{i=0}^{n}i\lfloor \frac{ai+b}{c} \rfloor $

有多快呢?单次询问的复杂度是 O(log) 的(递归过程与欧几里得算法求 gcd 相同)。实测 a、b、c 取到 1e9 时,1e5 组询问约 200ms 即可通过。

#include <bits/stdc++.h>

// 64-bit signed integer used for all arithmetic in this file.
typedef long long ll;

// Modulus under which every sum is computed and reported.
constexpr int mod = 998244353;
// Modular inverses of 2 and 6: since mod + 1 is divisible by both,
// (mod + 1) / k is exactly the value x with k * x == 1 (mod `mod`).
constexpr ll inv2 = (mod + 1) / 2;
constexpr ll inv6 = (mod + 1) / 6;

// Forward declarations only: no definition or call of f, g or h appears in
// this listing (solve() computes all three sums in a single pass instead).
// NOTE(review): these look like leftovers — confirm nothing else links
// against them before removing.
ll f(ll a, ll b, ll c, ll n);
ll g(ll a, ll b, ll c, ll n);
ll h(ll a, ll b, ll c, ll n);

// Bundle of the three sums returned by solve(); each is a residue modulo `mod`.
// All sums run over i = 0 .. n.
struct Query {
    ll f;  // sum of floor((a*i + b) / c)
    ll g;  // sum of i * floor((a*i + b) / c)
    ll h;  // sum of floor((a*i + b) / c)^2
};

Query solve(ll a, ll b, ll c, ll n) {
    Query ans, tmp;
    if(a == 0) {
        ans.f = (n + 1) * (b / c) % mod;
        ans.g = (b / c) * n % mod * (n + 1) % mod * inv2 % mod;
        ans.h = (n + 1) * (b / c) % mod * (b / c) % mod;
        return ans;
    }
    if(a >= c || b >= c) {
        tmp = solve(a % c, b % c, c, n);
        ans.f = (tmp.f + (a / c) * n % mod * (n + 1) % mod * inv2 % mod + (b / c) * (n + 1) % mod) % mod;
        ans.g = (tmp.g + (a / c) * n % mod * (n + 1) % mod * (2 * n + 1) % mod * inv6 % mod + (b / c) * n % mod * (n + 1) % mod * inv2 % mod) % mod;
        ans.h = ((a / c) * (a / c) % mod * n % mod * (n + 1) % mod * (2 * n + 1) % mod * inv6 % mod +
                 (b / c) * (b / c) % mod * (n + 1) % mod + (a / c) * (b / c) % mod * n % mod * (n + 1) % mod +
                 tmp.h + 2 * (a / c) % mod * tmp.g % mod + 2 * (b / c) % mod * tmp.f % mod) % mod;
        return ans;
    }
    ll m = (a * n + b) / c;
    tmp = solve(c, c - b - 1, a, m - 1);
    ans.f = (n * (m % mod) % mod - tmp.f) % mod;
    ans.g = (n * (n + 1) % mod * (m % mod) % mod - tmp.f - tmp.h) % mod * inv2 % mod;
    ans.h = (n * (m % mod) % mod * ((m + 1) % mod) % mod - 2 * tmp.g - 2 * tmp.f - ans.f) % mod;
    return ans;
}

/**
 * Buffered byte reader over stdin.
 *
 * Returns the next input byte as a non-negative int, or EOF once fread()
 * delivers no more data. The return type is int rather than char so that EOF
 * stays distinguishable from every possible byte value.
 */
inline int nc() {
    static char buf[1000000], *p1 = buf, *p2 = buf;
    if(p1 == p2) {
        // Buffer exhausted: refill it from stdin.
        p1 = buf;
        p2 = buf + fread(buf, 1, sizeof buf, stdin);
        if(p1 == p2)
            return EOF;
    }
    return (unsigned char)*p1++;
}

/**
 * Reads the next run of decimal digits from stdin as a non-negative integer.
 *
 * Every non-digit byte before the number is skipped, so a leading '-' is
 * ignored rather than treated as a sign. No overflow check is performed.
 * Returns 0 if end of input is reached before any digit is seen (instead of
 * looping forever).
 */
inline ll read() {
    int ch = nc();
    // Skip separators, but stop at end of input.
    while(ch != EOF && (ch < '0' || ch > '9'))
        ch = nc();

    ll res = 0;
    while(ch >= '0' && ch <= '9') {
        res = res * 10 + (ch - '0');
        ch = nc();
    }
    return res;
}


int main() {
    ll t = read();
    ll n, a, b, c;
    while(t--) {
        n = read(), a = read(), b = read(), c = read();
        Query ans = solve(a, b, c, n);
        printf("%lld %lld %lld\n", (ans.f + mod) % mod, (ans.h + mod) % mod, (ans.g + mod) % mod);
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/Yinku/p/10776831.html