【CodeChef COUNTARI】Arithmetic Progressions(分块 + FFT)

题目链接:【CodeChef COUNTARI】Arithmetic Progressions

题目大意:给定一个长度为 n 的数列,求数列中有多少个三元组 ( i , j , k ) ,满足:
+ 1 i < j < k n
+ a i a j = a j a k

n 100000 m = m a x { a i } 30000

式子化为: 2 a j = a i + a k 。考虑枚举 j 的位置,再将数列前半部分的生成函数与后半部分的生成函数做卷积,即可得到有多少对 ( i , k ) ,使得 2 a j = a i + a k 。这样的时间复杂度是 Θ ( n m log m ) 的,会超时。考虑如何减少卷积次数。

考虑将数组分成 b l o c k s 块, ( i , j , k ) 的位置有三种情况,我们分别计算即可。

  • 情况 1 ( i , j , k ) 在同一块内。在块内部枚举其中两个数的位置,就不难得出第三个数的值,用一个数组维护一下即可。时间复杂度 Θ ( b l o c k s ( n b l o c k s ) 2 ) = Θ ( n 2 b l o c k s )
  • 情况 2 ( i , j , k ) 有两个数在同一块内,而另一个数在另一块内。与刚才思路类似,枚举同一块内的两个数即可。时间复杂度 Θ ( n 2 b l o c k s )
  • 情况 3 ( i , j , k ) 所在的块两两不同。这样的话,类似一开始的思路,枚举 j 的位置,再做卷积。不同的是,卷积次数降低到了 Θ ( b l o c k s m log m )

这时,取一个合适的 b l o c k s 即可。

#include <cmath>
#include <cstdio>
#include <iostream>
using namespace std;
typedef long long ll;
const int maxn = 100005;
const int block = 2222;
const double pi = acos(-1.);
struct cd {
    double r, i;
    cd() {}
    cd(double real, double imag) {
        r = real, i = imag;
    }
    double& real() {
        return r;
    }
    cd operator+(const cd &x) const{
        return cd(r + x.r, i + x.i);
    }
    cd operator-(const cd &x) const{
        return cd(r - x.r, i - x.i);
    }
    cd operator*(const cd &x) const{
        return cd(r * x.r - i * x.i, r * x.i + i * x.r);
    }
};
ll ans;
int n, m, bit, lim, a[maxn], r[maxn];
int cnt[maxn], pre[maxn], nxt[maxn];
cd f[maxn], g[maxn];
void fft(cd *a, int dft) {
    for (int i = 0; i < lim; i++) {
        if (i < r[i]) {
            swap(a[i], a[r[i]]);
        }
    }
    for (int k = 1; k < lim; k <<= 1) {
        cd wn0(cos(pi / k), dft * sin(pi / k));
        for (int i = 0; i < lim; i += k << 1) {
            cd wnk(1, 0);
            for (int j = i; j < i + k; j++, wnk = wnk * wn0) {
                cd x = a[j], y = wnk * a[j + k];
                a[j] = x + y, a[j + k] = x - y;
            }
        }
    }
    if (dft == -1) {
        for (int i = 0; i < lim; i++) {
            a[i].real() /= lim;
        }
    }
}
int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        scanf("%d", a + i);
        m = max(m, a[i]);
    }
    for (lim = 1; lim <= m << 1; lim <<= 1) bit++;
    for (int i = 0; i < lim; i++) {
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (bit - 1));
    }
    for (int i = 1; i <= n; i++) {
        nxt[a[i]]++;
    }
    for (int r, l = 1; l <= n; l += block) {
        r = min(n, l + block - 1);
        for (int i = l; i <= r; i++) {
            nxt[a[i]]--;
        }
        // Type I & II
        for (int i = l; i <= r; i++) {
            for (int j = i + 1; j <= r; j++) {
                int k = 2 * a[i] - a[j];
                if (1 <= k && k <= m) {
                    ans += cnt[k] + pre[k];
                }
                k = 2 * a[j] - a[i];
                if (1 <= k && k <= m) {
                    ans += nxt[k];
                }
            } 
            cnt[a[i]]++;
        }
        // Type III
        for (int i = 0; i <= m; i++) {
            f[i] = cd(pre[i], 0), g[i] = cd(nxt[i], 0);
        }
        for (int i = m + 1; i < lim; i++) {
            f[i] = g[i] = cd(0, 0);
        }
        fft(f, 1), fft(g, 1);
        for (int i = 0; i < lim; i++) {
            f[i] = f[i] * g[i];
        }
        fft(f, -1);
        for (int i = l; i <= r; i++) {
            ans += ll(f[2 * a[i]].real() + 0.5);
            pre[a[i]]++, cnt[a[i]]--;
        }
    }
    printf("%lld\n", ans);
    return 0;
}

猜你喜欢

转载自blog.csdn.net/weixin_42068627/article/details/81140164