Spoj 8372 Triple Sums

题意:给你n个数字,对于任意s,s满足\(s=u_i+u_j+u_k,i<j<k\),要求出所有的s和对应满足条件的i,j,k的方案数

Solution:

构造一个函数:\(A(x)=\sum_{i=0}^{n-1}a_ix^i\),这是一个多项式

对于每一个\(u_i\),我们把这个多项式中的\(x^{u_i}\)的系数\(a_{u_i}\)加上一

也就是说,对于任意\(x^i\),它的系数为i在给出序列中出现的次数

多项式的三次方为:
\[ C(x)=A(x)^3\\ C(x)=\sum_{i=0}^{3n}c_ix^i\\ c_i=\sum_{0\le l,j,k\le n,l+j+k=i}a_ja_ka_l \]

在不考虑\(i<j<k\)的限制条件下,对于任意s,构成s的方案数就是\(C(x)\)\(x^s\)的系数\(c_s\)

我们再来考虑容斥去重将不符合要求的方案给去掉

考虑当\(i,j,k\)中有两个数相同时,构建多项式:\(B(x)=\sum_{i=0}^{n-1}b_ix^i\)

其中对于任意\(x^i\),它的系数\(b_i\)\(i/2(i\,mod\,2=0)\)在序列中出现的次数

则对于多项式:\(D(x)=A(x)B(x)\),它的系数就是两数相同的情况的方案数

\(C(x)\)中它被多加了三次,但减去之后,我们显然可以发现我们将\(i=j=k\)的情况多减了一次

加上后,就得到了不考虑\(i<j<k\)时,\(i\ne j\ne k\)的所有方案数,此时再考虑\(i\le j\le k\),只需把方案数除以6就行了

Code:

#include<bits/stdc++.h>
#define ll long long
#define Pi acos(-1.0)
using namespace std;
const int N=1<<17;
int n,len,tim=17,rtt[N],c[N];
struct cp{double x,y;}aa[N],bb[N],cc[N];
cp operator + (cp a,cp b){return (cp){a.x+b.x,a.y+b.y};}
cp operator - (cp a,cp b){return (cp){a.x-b.x,a.y-b.y};}
cp operator * (cp a,cp b){return (cp){a.x*b.x-a.y*b.y,a.y*b.x+a.x*b.y};}
void FFT(cp *a,int flag){
    for(int i=0;i<len;i++)
        if(i<rtt[i]) swap(a[i],a[rtt[i]]);
    for(int l=2;l<=len;l<<=1){
        cp wn=(cp){cos(flag*2*Pi/l),sin(flag*2*Pi/l)};
        for(int st=0;st<len;st+=l){
            cp w=(cp){1,0};
            for(int u=st;u<st+(l>>1);u++,w=w*wn){
                cp x=a[u],y=w*a[u+(l>>1)];
                a[u]=x+y,a[u+(l>>1)]=x-y;
            }
        }
    }
}
int read(){
    int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=-f;ch=getchar();}
    while(isdigit(ch)){x=x*10+ch-48;ch=getchar();}
    return x*f;
}
int main(){
    n=read(),len=N;
    for(int i=1;i<=n;i++){
        int x=read()+20000;
        aa[x].x=aa[x].x+1;
        bb[x<<1].x=bb[x<<1].x+1;
        c[x+x+x]++;
    }
    for(int i=0;i<len;i++)
        rtt[i]=(rtt[i>>1]>>1)|((i&1)<<(tim-1));
    FFT(aa,1);FFT(bb,1);
    for(int i=0;i<len;i++)
        cc[i]=aa[i]*(aa[i]*aa[i]-(cp){3,0}*bb[i]);
    FFT(cc,-1);
    for(int i=0;i<N;i++){
        ll cnt=((ll){cc[i].x/len+0.5}+2*c[i])/6;
        if(cnt) printf("%d : %lld\n",i-60000,cnt);
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/NLDQY/p/10758961.html