hihoCoder #1872 : Pythagorean triple

此题是 2018 年 ICPC Asia Beijing Regional Contest 的 C 题。

题目大意

求斜边长度不超过 $n$ 的勾股数的数量。不计两直角边的顺序,即勾股数 (a, b, c) 和 (b, a, c) 视作同一组。

$n \le 10^9$ 。

分析

这是一道颇为经典的计数问题。

请先阅读维基百科上的 Pythagorean triple 条目。

设斜边为 $n$ 的勾股数组有 $f(n)$ 个,斜边为 $n$ 的本原勾股数组有 $g(n)$ 个。由于每个勾股数组都可以唯一地写成某个本原勾股数组的正整数倍,斜边为 $n$ 的勾股数组与斜边为 $n$ 的某个因数 $d$ 的本原勾股数组一一对应,于是有
$ f(n) = \sum_{d \mid n} g(d)$ 。

令 $F$ 为 $f$ 的前缀和,令 $G$ 为 $g$ 的前缀和。有
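\begin{aligned}
F(n) &= \sum_{i = 1}^{n} f(i) \\
&= \sum_{i = 1}^{n} \sum_{d \mid i} g(d) \\
&= \sum_{d = 1}^{n} g(d) \left\lfloor \frac{n}{d} \right\rfloor \\
&= \sum_{i = 1}^{n} G\left(\left\lfloor \frac{n}{i} \right\rfloor\right)
\end{aligned}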

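由欧几里得公式,每个本原勾股数组(不计两直角边的顺序)恰好对应一对互素、一奇一偶的正整数 $(x, y)$(约定 $x$ 为奇数、$y$ 为偶数),其斜边为 $x^2 + y^2$。又因为互素的 $(x, y)$ 不可能同为偶数,且"$x$ 奇 $y$ 偶"与"$x$ 偶 $y$ 奇"两类个数相同,根据 $G$ 的定义,有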

\begin{aligned}
G(n) &= \sum_{i = 1}^{n} g(i) \\
&=\sum_{\substack{1 \le x \le n \\ x \text{ is odd} } } \sum_{\substack{1 \le y \le n \\ y \text{ is even}}} [x^2 + y^2 \le n] [\gcd(x, y) = 1] \\
&= \frac{1}{2} \left(\sum_{1 \le x \le n } \sum_{1 \le y \le n } - \sum_{\substack{1 \le x \le n \\ x \text{ is odd}} } \sum_{\substack{1 \le y \le n \\ y \text{ is odd}} } \right) [x^2 + y^2 \le n] [\gcd(x, y) = 1]
\end{aligned}

\begin{aligned}
& \left(\sum_{1 \le x \le n } \sum_{1 \le y \le n } - \sum_{\substack{1 \le x \le n \\ x \text{ is odd}} } \sum_{\substack{1 \le y \le n \\ y \text{ is odd}} } \right) [x^2 + y^2 \le n] [\gcd(x, y) = 1] \\
&= \left(\sum_{1 \le x \le n } \sum_{1 \le y \le n } - \sum_{\substack{1 \le x \le n \\ x \text{ is odd}} } \sum_{\substack{1 \le y \le n \\ y \text{ is odd}} } \right) [x^2 + y^2 \le n] \sum_{d \mid \gcd(x, y)} \mu(d) \\
&= \sum_{1\le d \le \sqrt{n/2}} \mu(d) \left(\sum_{1 \le x \le n } \sum_{1 \le y \le n } - \sum_{\substack{1 \le x \le n \\ x \text{ is odd}} } \sum_{\substack{1 \le y \le n \\ y \text{ is odd}} } \right) [x^2 + y^2 \le n] [d \mid x] [d \mid y] \\
&= \sum_{1\le d \le \sqrt{n/2}} \mu(d) \left(\sum_{ 1 \le i \le n/d } \sum_{1 \le y \le n } - \sum_{\substack{1 \le i \le n/d \\ di \text{ is odd}} } \sum_{\substack{1 \le y \le n \\ y \text{ is odd}} } \right) [(id)^2 + y^2 \le n] [d \mid y] \\
&= \sum_{1\le d \le \sqrt{n/2}} \mu(d) \left(\sum_{ 1 \le i \le \sqrt{n}/d } \sum_{1 \le j \le \sqrt{n-(id)^2}/d } - \sum_{\substack{1 \le i \le \sqrt{n}/d \\ di \text{ is odd}} } \sum_{\substack{1 \le j \le \sqrt{n-(id)^2}/d \\ dj \text{ is odd}} } \right) 1 \\
&= \sum_{1\le d \le \sqrt{n/2}} \mu(d) \left(\sum_{ 1 \le i \le \sqrt{n}/d } \left\lfloor \frac{\sqrt{n-(id)^2}}{d} \right\rfloor - [d \text{ is odd}] \sum_{\substack{1 \le i \le \sqrt{n}/d \\ i \text{ is odd}} } \sum_{\substack{1 \le j \le \sqrt{n-(id)^2}/d \\ j \text{ is odd}} } 1 \right) \\
&= \sum_{1\le d \le \sqrt{n/2}} \mu(d) \left(\sum_{ 1 \le i \le \sqrt{n}/d } \left\lfloor \frac{\sqrt{n-(id)^2}}{d} \right\rfloor - [d \text{ is odd}] \sum_{\substack{1 \le i \le \sqrt{n}/d \\ i \text{ is odd}} } \left\lfloor \frac{1}{2}\left( \left\lfloor \frac{\sqrt{n-(id)^2}}{d} \right\rfloor + 1 \right) \right\rfloor \right)
\end{aligned}

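复杂度分析:计算 $F(n)$ 时,$\lfloor n/i \rfloor$ 只有 $O(\sqrt{n})$ 种取值,可以分块求和。若预处理 $G$ 的前 $N$ 项,则只有满足 $\lfloor n/i \rfloor > N$(即 $i < n/N$)的项需要用上式计算;取 $N = 2\times 10^7$、$n \le 10^9$ 时至多 50 项,其余项查表即可。用上式计算一次 $G(m)$ 时,对每个 $d$,内层对 $i$ 按 $\lfloor \sqrt{m-(id)^2}/d \rfloor$ 的取值分块,块数不超过 $i$ 的取值个数 $\sqrt{m}/d$。因此单次计算为 $O\left(\sum_{d \le \sqrt{m/2}} \sqrt{m}/d\right) = O(\sqrt{m}\log m)$。预处理 $G$ 的前 $N$ 项需要枚举 $O(N)$ 对 $(x, y)$,每对求一次 $\gcd$。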

Implementation

预处理 $G$ 的前 $2 \times 10^7$ 项。计算 $F(n)$ 时,$\lfloor n/i \rfloor \le 2 \times 10^7$ 的项直接查表,其余项按上面的公式计算。

注意:代码不完整。

// Self-contained solution (standard C++17 only; no contest template needed).
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <ctime>
#include <fstream>
#include <iostream>
#include <numeric>
#include <vector>

namespace {

// Exact floor(sqrt(x)) for x >= 0. The double estimate from std::sqrt may be
// off by one, so it is corrected in both directions.
long long isqrt(long long x) {
    long long r = static_cast<long long>(std::sqrt(static_cast<double>(x)));
    while (r > 0 && r * r > x) --r;
    while ((r + 1) * (r + 1) <= x) ++r;
    return r;
}

// Moebius function mu(0..limit) via a linear sieve; mu[0] is unused (0).
// Expects limit >= 0.
std::vector<int> mobius(int limit) {
    std::vector<int> mu(static_cast<std::size_t>(limit) + 1, 0);
    if (limit >= 1) mu[1] = 1;
    std::vector<char> composite(static_cast<std::size_t>(limit) + 1, 0);
    std::vector<int> primes;
    for (int i = 2; i <= limit; ++i) {
        if (!composite[i]) {
            primes.push_back(i);
            mu[i] = -1;
        }
        for (const int p : primes) {
            if (static_cast<long long>(i) * p > limit) break;
            composite[i * p] = 1;
            if (i % p == 0) {  // p^2 divides i*p, so mu(i*p) = 0
                mu[i * p] = 0;
                break;
            }
            mu[i * p] = -mu[i];
        }
    }
    return mu;
}

// table[c] = G(c) = number of primitive triples with hypotenuse <= c, for
// 0 <= c <= limit. Each primitive triple corresponds to exactly one coprime
// pair (x odd, y even) with hypotenuse x^2 + y^2 (Euclid's formula).
// G(c) is bounded by the number of such lattice points, which is below c,
// so int is wide enough (half the memory of a long long table).
std::vector<int> primitive_prefix_table(int limit) {
    std::vector<int> table(static_cast<std::size_t>(limit) + 1, 0);
    for (long long x = 1; x * x <= limit; x += 2) {
        for (long long y = 2; x * x + y * y <= limit; y += 2) {
            if (std::gcd(x, y) == 1) ++table[x * x + y * y];
        }
    }
    std::partial_sum(table.begin(), table.end(), table.begin());
    return table;
}

// G(m) via the Moebius formula derived in the text:
//   2 G(m) = sum_{d <= sqrt(m/2)} mu(d) * ( #{(i, j) >= 1 : (id)^2 + (jd)^2 <= m}
//                                          - [d odd] * #{same, with i and j odd} ).
// For a fixed d, consecutive i sharing the same count of valid j are grouped.
// Requires mu.size() > isqrt(m / 2).
long long count_primitive_by_formula(long long m, const std::vector<int>& mu) {
    const long long max_d = isqrt(m / 2);
    const long long root = isqrt(m);
    long long twice_g = 0;
    for (long long d = 1; d <= max_d; ++d) {
        if (mu[d] == 0) continue;  // contributes nothing
        long long sum = 0;
        const long long max_i = root / d;
        for (long long i = 1; i <= max_i;) {
            // Number of j >= 1 with (jd)^2 <= m - (id)^2.
            const long long cnt_j = isqrt(m - (i * d) * (i * d)) / d;
            // Largest i' >= i that still admits exactly cnt_j values of j.
            const long long last_i = isqrt(m - (cnt_j * d) * (cnt_j * d)) / d;
            sum += (last_i - i + 1) * cnt_j;
            if (d & 1) {
                // (odd i in [i, last_i]) x (odd j in [1, cnt_j]).
                const long long odd_i = (last_i + 1) / 2 - i / 2;
                sum -= odd_i * ((cnt_j + 1) / 2);
            }
            i = last_i + 1;
        }
        twice_g += sum * mu[d];
    }
    return twice_g / 2;
}

// G(m): table lookup inside the precomputed range, formula otherwise.
long long count_primitive(long long m, const std::vector<int>& table,
                          const std::vector<int>& mu) {
    if (m < static_cast<long long>(table.size())) return table[m];
    return count_primitive_by_formula(m, mu);
}

// F(n) = sum_{k=1}^{n} G(floor(n/k)) = number of Pythagorean triples with
// hypotenuse <= n (legs unordered). k is grouped by the value of floor(n/k).
long long count_triples(long long n, const std::vector<int>& table,
                        const std::vector<int>& mu) {
    long long total = 0;
    for (long long k = 1; k <= n;) {
        const long long q = n / k;
        const long long last_k = n / q;
        total += (last_k - k + 1) * count_primitive(q, table, mu);
        k = last_k + 1;
    }
    return total;
}

}  // namespace

// Reads T, then T values of n; prints F(n) for each on its own line.
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
#ifdef LOCAL
    std::ifstream in("main.in");
    std::cin.rdbuf(in.rdbuf());
#endif

    int t = 0;
    std::cin >> t;
    std::vector<long long> queries(static_cast<std::size_t>(std::max(t, 0)));
    long long max_n = 0;
    for (long long& n : queries) {
        std::cin >> n;
        max_n = std::max(max_n, n);
    }

    // Precompute G(0..min(2e7, max_n)); larger arguments use the formula,
    // which needs mu up to sqrt(max_n / 2).
    constexpr long long kTableLimit = 20'000'000;
    const std::vector<int> table =
        primitive_prefix_table(static_cast<int>(std::min(kTableLimit, max_n)));
    const std::vector<int> mu = mobius(static_cast<int>(isqrt(max_n / 2)));

    for (const long long n : queries) {
        std::cout << count_triples(n, table, mu) << '\n';
    }

#ifdef LOCAL
    std::cout << "Time elapsed: " << 1.0 * std::clock() / CLOCKS_PER_SEC << " s.\n";
#endif
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/Patt/p/10631987.html
今日推荐