题目链接:【CodeChef COUNTARI】Arithmetic Progressions
题目大意:给定一个长度为
的数列,求数列中有多少个三元组
,满足:
+
+
, 。
式子化为: 。考虑枚举 的位置,再将数列前半部分的生成函数与后半部分的生成函数做卷积,即可得到有多少对 ,使得 。这样的时间复杂度是 的,会超时。考虑如何减少卷积次数。
考虑将数组分成 块, 的位置有三种情况,我们分别计算即可。
- 情况 : 在同一块内。在块内部枚举其中两个数的位置,就不难得出第三个数的值,用一个数组维护一下即可。时间复杂度 。
- 情况 : 有两个数在同一块内,而另一个数在另一块内。与刚才思路类似,枚举同一块内的两个数即可。时间复杂度 。
- 情况 : 所在的块两两不同。这样的话,类似一开始的思路,枚举 的位置,再做卷积。不同的是,卷积次数降低到了 。
这时,取一个合适的 即可。
#include <cmath>
#include <cstdio>
#include <iostream>
using namespace std;
typedef long long ll;
const int maxn = 100005;
const int block = 2222;
const double pi = acos(-1.);
struct cd {
double r, i;
cd() {}
cd(double real, double imag) {
r = real, i = imag;
}
double& real() {
return r;
}
cd operator+(const cd &x) const{
return cd(r + x.r, i + x.i);
}
cd operator-(const cd &x) const{
return cd(r - x.r, i - x.i);
}
cd operator*(const cd &x) const{
return cd(r * x.r - i * x.i, r * x.i + i * x.r);
}
};
ll ans;
int n, m, bit, lim, a[maxn], r[maxn];
int cnt[maxn], pre[maxn], nxt[maxn];
cd f[maxn], g[maxn];
void fft(cd *a, int dft) {
for (int i = 0; i < lim; i++) {
if (i < r[i]) {
swap(a[i], a[r[i]]);
}
}
for (int k = 1; k < lim; k <<= 1) {
cd wn0(cos(pi / k), dft * sin(pi / k));
for (int i = 0; i < lim; i += k << 1) {
cd wnk(1, 0);
for (int j = i; j < i + k; j++, wnk = wnk * wn0) {
cd x = a[j], y = wnk * a[j + k];
a[j] = x + y, a[j + k] = x - y;
}
}
}
if (dft == -1) {
for (int i = 0; i < lim; i++) {
a[i].real() /= lim;
}
}
}
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
scanf("%d", a + i);
m = max(m, a[i]);
}
for (lim = 1; lim <= m << 1; lim <<= 1) bit++;
for (int i = 0; i < lim; i++) {
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}
for (int i = 1; i <= n; i++) {
nxt[a[i]]++;
}
for (int r, l = 1; l <= n; l += block) {
r = min(n, l + block - 1);
for (int i = l; i <= r; i++) {
nxt[a[i]]--;
}
// Type I & II
for (int i = l; i <= r; i++) {
for (int j = i + 1; j <= r; j++) {
int k = 2 * a[i] - a[j];
if (1 <= k && k <= m) {
ans += cnt[k] + pre[k];
}
k = 2 * a[j] - a[i];
if (1 <= k && k <= m) {
ans += nxt[k];
}
}
cnt[a[i]]++;
}
// Type III
for (int i = 0; i <= m; i++) {
f[i] = cd(pre[i], 0), g[i] = cd(nxt[i], 0);
}
for (int i = m + 1; i < lim; i++) {
f[i] = g[i] = cd(0, 0);
}
fft(f, 1), fft(g, 1);
for (int i = 0; i < lim; i++) {
f[i] = f[i] * g[i];
}
fft(f, -1);
for (int i = l; i <= r; i++) {
ans += ll(f[2 * a[i]].real() + 0.5);
pre[a[i]]++, cnt[a[i]]--;
}
}
printf("%lld\n", ans);
return 0;
}