FFT多项式乘法

[LuoguP3803]

学了好久才懂了那么一点点哎

Code:

 1 #include <bits/stdc++.h>
 2 #define ll long long
 3 using namespace std;
 4 const int N = 1e7 + 7;
 5 const double Pi = acos(-1.0);
 6 ll read() {
 7     ll re = 0, f = 1;
 8     char ch = getchar();
 9     while (ch < '0' || ch > '9') {if (ch == '-') f = -f; ch = getchar();}
10     while ('0' <= ch && ch <= '9') {re = re * 10 + ch - '0'; ch = getchar();}
11     return re * f;
12 }
13 int n, m;
14 int l, pos[N], limit = 1;
15 struct Complex{
16     double x, y;
17     Complex (double nx = 0, double ny = 0) {x = nx, y = ny;}
18 }a[N], b[N];
19 Complex operator +(Complex a, Complex b) {return Complex(a.x + b.x, a.y + b.y);}
20 Complex operator -(Complex a, Complex b) {return Complex(a.x - b.x, a.y - b.y);}
21 Complex operator *(Complex a, Complex b) {return Complex(a.x*b.x-a.y*b.y, a.x*b.y+b.x*a.y);}
22 void FFT(Complex *A, int f) {
23     for (int i = 0; i < limit; i++) {
24         if (i < pos[i]) swap(A[i], A[pos[i]]);
25     }
26     for (int mid = 1; mid < limit; mid <<= 1) {
27         Complex wn(cos(Pi / mid), f * sin(Pi / mid));
28         for (int r = mid << 1, i = 0; i < limit; i += r) {
29             Complex w(1, 0);
30             for (int k = 0; k < mid; k++, w = w * wn) {
31                 Complex u = A[i + k], v = w * A[i + k + mid];
32                 A[i + k] = u + v;
33                 A[i + k + mid] = u - v;
34             }
35         }
36     }
37     if (f == -1) {//Èç¹ûÊÇÄæ±ä»»Òª/n 
38         for (int i = 0; i <= n + m; i++) {
39             a[i].x /= limit;
40         }
41     }
42 }
43 int main () {
44     n = read(), m = read();
45     for (int i = 0; i <= n; i++) {
46         a[i].x = read();
47     }
48     for (int i = 0; i <= m; i++) {
49         b[i].x = read();
50     }
51     while (limit <= n + m) l++, limit <<= 1;
52     for (int i = 0; i < limit; i++) {
53         pos[i] = (pos[i >> 1] >> 1) | ((i & 1) << (l - 1));
54     }
55     FFT(a, 1), FFT(b, 1);
56     for (int i = 0; i < limit; i++) {
57         a[i] = a[i] * b[i];
58     }
59     FFT(a, -1);
60     for (int i = 0; i <= n + m; i++) {
61         printf("%d%c", (int)(a[i].x + 0.5), i == n + m ? '\n' : ' ');
62     }
63     return 0;
64 }
View Code

猜你喜欢

转载自www.cnblogs.com/Sundial/p/12198653.html