算法学习FFT系列(4):任意模数的快速傅里叶变换(MTT)
毛神好强,毛神好强
任意模数的快速傅里叶变换的解法
这里假设序列是
级别的,模数是
级别的。
首先找到问题的瓶颈。
由于模数的任意性,所以NTT失效了。
而由于如果没有模数,最后的结果是在
级别,会太大,所以FFT也不适用。
既然找到了瓶颈,所以我们解决的方针就有两种:(1)把NTT推广到任意模数的形式。(2)通过某种方法分步计算FFT使得精度符合要求。
三模数NTT
根据第一种思路,找到三个符合要求的模数在这三个模数意义下分别FFT然后利用中国剩余定理合并一下。
以后有空再补吧,今天主要介绍另外一种。
拆系数FFT
讲每个数拆成
的形式,其中
是常数
考虑卷积的过程。
这个东西可以FFT出来之后乘
这样子的话,大概估计一下范围,考虑最大的情况。
不难发现,当
的时候,这些东西都是
级别的。
所以FFT出来的结果是
级别的。
这样子的话,我们将整个序列拆成4个序列。做4次DFT,3次IDFT即可。
但是这样有两个坏处,一个是精度不行,还有一个就是7的常数,如果加上FFT本身的常数可能快两个log了。
DFT合并和IDFT合并
这是一个神奇的优化常数的技巧。
最早是Codeforces上的神仙提出的,毛神的论文里有
具体的方法是构造共轭式,我们定义
为了方便,令
考虑DFT后
的关系
秀了一波推到,我们发现只要DFT出P,我们就能得到Q的DFT
然后
这样子的话,两次可以优化到一次。
那我们考虑IDFT,其实只要逆回去就行了。
令
由于IDFT前后都是实数,直接把实部和虚部掏出来即可。
这样子的话,我们可以把两个IDFT合并在一起搞。
这样子的话,MTT就变成了4次DFT了。
代码
//luoguP4245 【模板】MTT
#include<cstdio>
#include<cmath>
#include<algorithm>
const int N = 262144 + 10, M = 32767;
const double pi = acos(-1.0);
typedef long long LL;
int read() {
char ch = getchar(); int f = 1, x = 0;
for(;ch < '0' || ch > '9'; ch = getchar()) if(ch == '-') f = -1;
for(;ch >= '0' && ch <= '9'; ch = getchar()) x = (x << 1) + (x << 3) - '0' + ch;
return x * f;
}
struct cp {
double r, i;
cp(double _r = 0, double _i = 0) : r(_r), i(_i) {}
cp operator * (const cp &a) {return cp(r * a.r - i * a.i, r * a.i + i * a.r);}
cp operator + (const cp &a) {return cp(r + a.r, i + a.i);}
cp operator - (const cp &a) {return cp(r - a.r, i - a.i);}
}w[N], nw[N], da[N], db[N];
cp conj(cp a) {return cp(a.r, -a.i);}
int L, n, m, a[N], b[N], c[N], R[N], P;
void Pre() {
int x = 0; for(L = 1; (L <<= 1) <= n + m; ++x) ;
for(int i = 1;i < L; ++i) R[i] = (R[i >> 1] >> 1) | (i & 1) << x;
for(int i = 0;i < L; ++i) w[i] = cp(cos(2 * pi * i / L), sin(2 * pi * i / L));
}
void FFT(cp *F) {
for(int i = 0;i < L; ++i) if(i < R[i]) std::swap(F[i], F[R[i]]);
for(int i = 2, d = L >> 1;i <= L; i <<= 1, d >>= 1)
for(int j = 0;j < L; j += i) {
cp *l = F + j, *r = F + j + (i >> 1), *p = w, tp;
for(int k = 0;k < (i >> 1); ++k, ++l, ++r, p += d)
tp = *r * *p, *r = *l - tp, *l = *l + tp;
}
}
void Mul(int *A, int *B, int *C) {
for(int i = 0;i < L; ++i) (A[i] += P) %= P, (B[i] += P) %= P;
static cp a[N], b[N], Da[N], Db[N], Dc[N], Dd[N];
for(int i = 0;i < L; ++i) a[i] = cp(A[i] & M, A[i] >> 15);
for(int i = 0;i < L; ++i) b[i] = cp(B[i] & M, B[i] >> 15);
FFT(a); FFT(b);
for(int i = 0;i < L; ++i) {
int j = (L - i) & (L - 1); static cp da, db, dc, dd;
da = (a[i] + conj(a[j])) * cp(0.5, 0);
db = (a[i] - conj(a[j])) * cp(0, -0.5);
dc = (b[i] + conj(b[j])) * cp(0.5, 0);
dd = (b[i] - conj(b[j])) * cp(0, -0.5);
Da[j] = da * dc; Db[j] = da * dd; Dc[j] = db * dc; Dd[j] = db * dd; //顺便区间反转,方便等会直接用DFT代替IDFT
}
for(int i = 0;i < L; ++i) a[i] = Da[i] + Db[i] * cp(0, 1);
for(int i = 0;i < L; ++i) b[i] = Dc[i] + Dd[i] * cp(0, 1);
FFT(a); FFT(b);
for(int i = 0;i < L; ++i) {
int da = (LL) (a[i].r / L + 0.5) % P; //直接取实部和虚部
int db = (LL) (a[i].i / L + 0.5) % P;
int dc = (LL) (b[i].r / L + 0.5) % P;
int dd = (LL) (b[i].i / L + 0.5) % P;
C[i] = (da + ((LL)(db + dc) << 15) + ((LL)dd << 30)) % P;
}
}
int main() {
n = read(); m = read(); P = read();
for(int i = 0;i <= n; ++i) a[i] = read();
for(int j = 0;j <= m; ++j) b[j] = read();
Pre(); Mul(a, b, c);
for(int i = 0;i <= n + m; ++i) printf("%d ", (c[i] + P) % P); puts("");
return 0;
}