「BZOJ3509」「CodeChef」 COUNTARI

Description

给定一个长度为 N 的数组 A ,求有多少对 i,j,k (1i<j<kN) 满足 AkAj=AjAi

Input

第一行一个整数 N
接下来一行 N 个数 Ai

Output

一行一个整数。

Sample Input

10
3 5 3 6 3 4 10 4 5 2

Sample Output

9

HINT

N105,Ai30000

题解

以下记 W=max{A1,,AN}
这道题的 O(NW) 做法非常显然,即维护每个数左边和右边每种数字出现的次数。但这也是这道题的瓶颈所在。因为很显然这个算法是很难(或者不可能)继续优化下去的,所以很可能会卡在这里(如果你之前没做过类似的题目)。
这样,我们就考虑从一个看起来时间复杂度更坏的算法入手。
由于三个数构成等差数列,所以 2Aj=Ai+Ak 。我们可以对于每一个数维护左边和右边每种数字出现的次数,这个可以做到 O(N) 。然后统计方案数可以用卷积来实现,用 FFT 可以做到 O(Wlog2W) ,于是总复杂度为 O(NWlog2W) 。很明显是更差的。但是这个算法就少了很多局限性。
考虑我们卷积的过程,设多项式 f(x) 表示下标在区间 [L1,R1] 的生成函数( xk 的系数表示数字 k 出现的次数); g(x) 表示下标在区间 [L2,R2] (L1R1<L2R2) 的生成函数( xk 的系数表示数字 k 出现的次数);卷积 (fg)(x) x2k 的系数即为首项下标在 [L1,R1] , 末项下标在 [L2,R2] ,中项为 k 的项数为 3 等差数列数目。很明显,如果我们设 L1=1,R2=N ,我们做一次卷积可以求出首项下标在 [1,R1] , 末项下标在 [L2,N] ,中项下标在区间 (R1,L2) 的等差数列个数。
这样就很明显可以分块来做。
我们把整个区间分成 K 块,枚举每一个数为中项,我们讨论下列三种情况:
1.首项和末项都在块内,可以用刚开始的做法,但是如果块内元素比较小,可以枚举首项下标,这样单块复杂度 O((NK)2)
2.首项和末项有一个在块内,我们可以枚举在块内的那一项,同样可以做到单块 O((NK)2)
3.首项和末项都不在块内,那么我们就需要用卷积了。一次卷积即可求出块内所有元素为中项的方案数。单块复杂度 O(Wlog2W)

那么总的复杂度就是 O(N2K+KWlog2W)
由均值不等式, K=NWlogW 时复杂度最低,为 O(NWlog2W)

但是事实上,由于常数等原因,块的大小需要调大大约 10 倍,约 2000 左右时最快。由于此题卡常严重,需要手写复数类。

My Code

/**************************************************************
    Problem: 3509
    User: infinityedge
    Language: C++
    Result: Accepted
    Time:34664 ms
    Memory:5784 kb
****************************************************************/

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <complex>

#define MAXN 65536
#define pi acos(-1)
using namespace std;
typedef long long ll;
struct E{
    long double real, imag;
    E(long double real = 0, long double imag = 0) : real(real), imag(imag) { }
    inline friend E operator + (E &a, E &b)
        { return E(a.real + b.real, a.imag + b.imag); }
    inline friend E operator - (E &a, E &b)
        { return E(a.real - b.real, a.imag - b.imag); }
    inline friend E operator * (E &a, E &b)
        { return E(a.real * b.real - a.imag * b.imag , a.imag * b.real + a.real * b.imag); }
    inline friend void swap(E &a, E &b)
        { E c = a; a = b; b = c; }
};

E a[MAXN + 1], b[MAXN + 1];

void bit_reverse(int n, E* r){
    for(int i = 0, j = 0; i < n; i ++){
        if(i > j) swap(r[i], r[j]);
        for(int l = n >> 1; (j ^= l) < l; l >>= 1);
    }
}

void fft(int n, E* r, int f){
    bit_reverse(n, r);
    for(int i = 2; i <= n; i <<= 1){
        int m = i >> 1;
        for(int j = 0; j < n; j += i){
            E w(1, 0), wn(cos(2 * pi / i), f * sin(2 * pi / i));
            for(int k = 0; k < m; k ++){
                E z = r[j + m + k] * w;
                r[j + m + k] = r[j + k] - z;
                r[j + k] = r[j + k] + z;
                w = w * wn;
            }
        }
    }
    if(f == -1){
        E ww = E(1.0 / n, 0);
        for(int i = 0; i < n; i ++) r[i] = r[i] * ww;
    }
}

int n, k, m;
int d[100005], pos[100005], l[1005], r[1005];
ll ans;
int vis[30005];
int tmpl[MAXN], tmpr[MAXN];
void solve(int x){
    for(int i = l[x]; i <= r[x]; i ++){
        tmpr[d[i]]++;   
    }
    for(int i = l[x]; i <= r[x]; i ++){
        for(int j = i + 1; j <= r[x]; j ++){
            int dk = d[i] + d[i] - d[j];
            ans += tmpl[dk];
        }
        tmpl[d[i]]++;
    }
    for(int i = l[x]; i <= r[x]; i ++){
        tmpr[d[i]] = tmpl[d[i]] = 0;    
    }
}
int N = 1;
void solsub(int x){

    for(int i = l[x]; i <= r[x]; i ++){
        for(int j = i + 1; j <= r[x]; j ++){
            int dk = d[i] + d[i] - d[j];
            if(dk >= 0) ans += tmpl[dk];
            dk = d[j] + d[j] - d[i];
           if(dk >= 0) ans += tmpr[dk];
        }
    }
     if(x == 1 || x == m) return;
    for(int i = 0; i <= N; i ++){
        a[i] = b[i] = E(0, 0);
    }
    for(int i = 0; i <= N; i ++){
        a[i] = E(tmpl[i], 0);
        b[i] = E(tmpr[i], 0);
    }
    fft(N, a, 1); fft(N, b, 1);
    for(int i = 0; i <= N; i ++){
        a[i] = a[i] * b[i];
    }
    fft(N, a, -1);
    for(int i = l[x]; i <= r[x]; i ++){
        ans = ans + (ll)(a[2 * d[i]].real + 0.1);
    }
}
void solve2(){
    int mx = 0;
    for(int i = 1; i <= n; i ++){
        mx = max(d[i], mx);
    }
    mx = mx * 2 + 1;

    while(N < mx) N = N << 1;
    for(int i = 1; i <= n; i ++){
        tmpr[d[i]] ++;
    }
    for(int i = 1; i <= m; i ++){
        for(int j = l[i]; j <= r[i]; j ++){
            tmpr[d[j]] --;
        }
        solsub(i);
        for(int j = l[i]; j <= r[i]; j ++){
            tmpl[d[j]] ++;
        }
    }
}

int main(){
    scanf("%d", &n); k = 1823;
    if(n < 1823) k = 1823;
    for(int i = 1; i <= n; i ++){
        scanf("%d", &d[i]);
    }
    for(int i = 1; i <= n; i ++){
        pos[i] = (n - 1) / k + 1;
    }
    m = pos[n];
    for(int i = 1; i <= m; i ++){
        l[i] = (i - 1) * k + 1;
        r[i] = i * k;   
    }
    r[m] = n;
    for(int i = 1; i <= m; i ++){
        solve(i);
    }
    solve2();
    printf("%lld\n", ans);
    return 0;
}

猜你喜欢

转载自blog.csdn.net/infinity_edge/article/details/78745766