算法笔记--FFT

推荐阅读资料:算法导论第30章

本文不做证明,详细证明请看如上资料。

FFT在算法竞赛中主要用来加速多项式的乘法

普通是多项式乘法时间复杂度的是O(n2),而用FFT求多项式的乘法可以使时间复杂度达到O(nlogn)

FFT求多项式的乘法步骤主要如下图

其中求值是将系数表达转换成点值表达,带入的自变量是wn=1的复数解,称为DFT

插值是将点值表达转换成系数表达,称为DFT-1

DFT 和 DFT-1都可以用FFT加速实现

这是递归版的FFT

还有一种非递归的版本

我们发现叶子节点的下表的二进制为:000   100   010   110    001  101   110    111

与它们的本身所对应的位置的二进制:000   001  010   011    100   101    011   111

相反

所以我们可以确定叶子节点的值,从下往上进行操作

求二进制反转的代码(其中L是二进制位):

for (int i = 0; i < n; i++) {
            R[i] = (R[i>>1]>>1) | ((i&1) << L-1);
        }

假设现在R[i]的二进制是abcd,没有操作之前的R[i>>1]是0abc,操作之后的是cba0,再右移是0cba,再判断原来的d是不是1在最高位放1或0,就刚好是反转的结果

模板:

递归版(以求大数乘法为例):

#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define pi acos(-1.0)
#define LL long long
#define mp make_pair
#define pb push_back
#define ls rt<<1, l, m
#define rs rt<<1|1, m+1, r
#define ULL unsigned LL
#define pll pair<LL, LL>
#define pii pair<int, int>
#define piii pair<int,pii>
#define mem(a, b) memset(a, b, sizeof(a))
#define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
#define fopen freopen("in.txt", "r", stdin);freopen("out.txt", "w", stout);
//head

typedef complex<double> cd;
const int N = 2e5 + 5;
char a[N], b[N];
cd A[N], B[N];
int tmp[N];
void fft(cd *x, int n, int type) {
    if(n == 1) return ;
    cd l[n>>1], r[n>>1];
    for (int i = 0; i < n; i += 2) {
        l[i>>1] = x[i];
        r[i>>1] = x[i+1];
    }
    fft(l, n>>1, type);
    fft(r, n>>1, type);
    cd wn(cos(2*pi/n), sin(type*2*pi/n)), w(1, 0), t;
    for(int i = 0; i < n>>1; i++, w *= wn) {
        t = w*r[i];
        x[i] = l[i] + t;
        x[i+(n>>1)] = l[i] - t;
    }
}
int main() {
    while(~scanf("%s%s", a, b)) {
        int n = strlen(a), m = strlen(b);
        mem(A, 0);
        mem(B, 0);
        mem(tmp, 0);
        for (int i = n - 1; i >= 0; i--) A[n-1-i] = a[i] - '0';
        for (int i = m - 1; i >= 0; i--) B[m-1-i] = b[i] - '0';
        m = m + n;
        for(n = 1; n <= m; n <<= 1);
        fft(A, n, 1);
        fft(B, n, 1);
        for (int i = 0; i < n; i++) A[i] *= B[i];
        fft(A, n, -1);
        for (int i = 0; i < m; i++) {
            int t = (int)(A[i].real()/n + 0.5);
            t += tmp[i];
            tmp[i] = t%10;
            tmp[i+1] += t/10;
        }
        int i;
        for (i = m; i >= 1; i--) if(tmp[i]) break;
        for (i; i >= 0; i--) printf("%d", tmp[i]);
        printf("\n");
    }
    return 0;
}

非递归版(以求大数乘法为例):

#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define pi acos(-1.0)
#define LL long long
#define mp make_pair
#define pb push_back
#define ls rt<<1, l, m
#define rs rt<<1|1, m+1, r
#define ULL unsigned LL
#define pll pair<LL, LL>
#define pii pair<int, int>
#define piii pair<int,pii>
#define mem(a, b) memset(a, b, sizeof(a))
#define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
#define fopen freopen("in.txt", "r", stdin);freopen("out.txt", "w", stout);
//head

typedef complex<double> cd;
const int N = 2e5 + 5;
char a[N], b[N];
cd A[N], B[N];
int tmp[N], R[N];
void fft(cd *x, int n, int type) {
    for (int i = 0; i < n; i++) if(i < R[i]) swap(x[i], x[R[i]]);
    for (int i = 1; i < n; i <<= 1) {
        cd wn(cos(pi/i), type*sin(pi/i));
        for (int j = 0; j < n; j += i<<1) {
            cd w(1, 0);
            for (int k = 0; k < i; k++, w*=wn) {
                cd X = x[j+k], Y = w*x[j+k+i];
                x[j+k] = X+Y;
                x[j+k+i] = X-Y;
            }
        }
    }
}
int main() {
    while(~scanf("%s%s", a, b)) {
        int n = strlen(a), m = strlen(b), L = 0;
        mem(A, 0);
        mem(B, 0);
        mem(tmp, 0);
        mem(R, 0);
        for (int i = n - 1; i >= 0; i--) A[n-1-i] = a[i] - '0';
        for (int i = m - 1; i >= 0; i--) B[m-1-i] = b[i] - '0';
        m = m + n;
        for(n = 1; n <= m; n <<= 1) L++;
        for (int i = 0; i < n; i++) {
            R[i] = (R[i>>1]>>1) | ((i&1) << L-1);
        }
        fft(A, n, 1);
        fft(B, n, 1);
        for (int i = 0; i < n; i++) A[i] *= B[i];
        fft(A, n, -1);
        for (int i = 0; i < m; i++) {
            int t = (int)(A[i].real()/n + 0.5);
            t += tmp[i];
            tmp[i] = t%10;
            tmp[i+1] += t/10;
        }
        int i;
        for (i = m; i >= 1; i--) if(tmp[i]) break;
        for (i; i >= 0; i--) printf("%d", tmp[i]);
        printf("\n");
    }
    return 0;
}

PS:手写complex类+非递归版最快

猜你喜欢

转载自www.cnblogs.com/widsom/p/9152440.html