[模板]高精度乘法

用 FFT（快速傅里叶变换）优化多项式乘法：把两个大整数的各位数字看作多项式系数，先用 FFT 求出乘积多项式，再逐位进位，时间复杂度 O(n log n)。

#include <cctype>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <cstring>
#include <iostream>
#include <utility>
using namespace std;
// `register` was deprecated in C++11 and removed in C++17 (clang rejects it by
// default under -std=c++17). Keep `reg` as an empty macro so existing
// `reg int ...` declarations still compile unchanged.
#define reg
// Reads the next decimal integer from `in` (stdin by default).
// Characters before the first digit are skipped; any '-' among them makes the
// result negative. Returns 0 if end of input is reached before a digit appears
// (the previous version looped forever at EOF).
inline int read(FILE *in = stdin) {
    int res = 0;
    bool fu = false;
    int ch = getc(in);  // int, not char: must hold EOF, and isdigit() needs an unsigned-char value or EOF
    while (ch != EOF && !isdigit(ch)) {
        fu |= (ch == '-');
        ch = getc(in);
    }
    while (ch != EOF && isdigit(ch)) {
        res = res * 10 + (ch - '0');
        ch = getc(in);
    }
    return fu ? -res : res;
}

namespace BriMon
{
// Capacity of the digit buffer `s`: each operand may have at most N - 1 digits.
#define N 300005
const double Pi = 3.14159265358979323846264338327950;
// n: FFT length (a power of two) once main() has sized it;
// m: highest coefficient index of the product;
// rev[i]: i with its low L bits reversed (input permutation for the butterflies).
int n, m, rev[N<<2];

// Minimal complex number used by the FFT. Global arrays of it are
// zero-initialized, so the empty default constructor is sufficient for b/c.
struct cmplx {
    double x, y;
    cmplx() { }
    cmplx(double aa, double bb) {x = aa, y = bb;}
    friend cmplx operator + (cmplx a, cmplx b) { return cmplx(a.x + b.x, a.y + b.y); }
    friend cmplx operator - (cmplx a, cmplx b) { return cmplx(a.x - b.x, a.y - b.y); }
    friend cmplx operator * (cmplx a, cmplx b) { return cmplx(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);}
} b[N<<2], c[N<<2];

// In-place iterative radix-2 FFT of f[0..n-1], where n is the global length.
// op = 1: forward transform; op = -1: inverse transform WITHOUT the 1/n
// scaling (the caller divides by n). Requires rev[] built for the same n.
void fft(cmplx *f, int op)
{
    for (int i = 0 ; i < n ; i ++) if (i < rev[i]) swap(f[i], f[rev[i]]);
    for (int p = 2 ; p <= n ; p <<= 1)          // p: size of the blocks being merged
    {
        int len = p >> 1;
        // Primitive p-th root of unity (conjugated for the inverse transform).
        cmplx wp = cmplx(cos(Pi / len), op * sin(Pi / len));
        for (int k = 0 ; k < n ; k += p)
        {
            cmplx w = cmplx(1, 0);
            for (int l = k ; l < k + len ; l ++)
            {
                cmplx t = w * f[l + len];       // was `__tmp`: names with "__" are reserved
                f[l + len] = f[l] - t;
                f[l] = f[l] + t;
                w = w * wp;
            }
        }
    }
}
char s[N];
int ans[N<<1];
int L;

// Reads one decimal operand into f[] with the least-significant digit at f[0]
// (only the real part is written) and returns its digit count, or 0 if no
// token could be read.
int load(cmplx *f)
{
    if (scanf("%300004s", s) != 1) return 0;   // width N - 1 keeps s[] in bounds; keep in sync with N
    int len = (int) strlen(s);
    for (int i = 0 ; i < len ; i ++) f[i].x = s[len - 1 - i] - '0';
    return len;
}

// Reads "n a b" from stdin and prints the decimal product a * b (no trailing
// newline). The declared digit count n is consumed, but each operand's actual
// length is measured, so operands of different lengths are handled correctly.
int main()
{
    read();                                     // declared digit count; real lengths come from load()
    int la = load(b);
    int lb = load(c);
    if (la == 0 || lb == 0) { printf("0"); return 0; }
    m = la + lb - 2;                            // highest coefficient index of the product
    // Smallest power of two n > m, but at least 2 so that L >= 1 and the
    // shift by (L - 1) below is never negative (single-digit inputs gave L = 0).
    for (n = 2, L = 1 ; n <= m ; n <<= 1) ++L;
    for (int i = 0 ; i < n ; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (L - 1));
    fft(b, 1), fft(c, 1);
    for (int i = 0 ; i < n ; i ++) b[i] = b[i] * c[i];
    fft(b, -1);
    // Coefficients are non-negative integers: +0.5 rounds to nearest. The old
    // +0.1 truncated to the wrong digit whenever the FFT error was below -0.1.
    for (int i = 0 ; i <= m ; i ++) ans[i] = (int) (b[i].x / n + 0.5);
    for (int i = 0 ; i <= m ; i ++)
        if (ans[i] >= 10) {
            ans[i + 1] += ans[i] / 10;
            ans[i] %= 10;
            if (i == m) ++m;                    // carry out of the top digit lengthens the number
        }
    while (m > 0 && ans[m] == 0) --m;           // strip leading zeros, keep at least one digit
    for (int i = m ; i >= 0 ; i --) putchar('0' + ans[i]);
    return 0;
}
}

// Program entry point. The solver used to run inside the dynamic initializer
// of a dummy global (`int zZh = BriMon::main();`), which discarded its return
// value and relied on static-initialization order; calling it from main()
// runs it after all globals are initialized and propagates its exit status.
int main() {return BriMon :: main();}

猜你喜欢

转载自 https://www.cnblogs.com/BriMon/p/10127672.html