版权声明:未经作者本人允许禁止转载。 https://blog.csdn.net/jokerwyt/article/details/81626659
解决模数M不是NTT模数的情况。
多模数NTT
一般取三个模数p1p2p3做NTT,要求满足
,即CRT模数比结果序列值要大。
然后用中国剩余定理(CRT)合并出值。
但是由于三个模数乘起来爆long long了,我们需要一些特殊trick。
首先将两个模数合并,方程变为两条
也即
两边同时模上 ,求p1的逆元后可以计算出k在模p2意义下的值。
又因为我们所求是
也就是
右边括号部分展开一下,不难发现为(
又因为值域小于p1p2,这就是原始值。
从而可以计算出x。求出k后所有操作都应在 意义下进行。
共 次NTT。。。心态是不是有点崩。
据说立大爷的做法是将三模数换成一大一小模数再用O(1)黑科技乘,这样可以做到6次。
mtt
毛爷爷用拆系数fft的方式来代替ntt.
我们考虑直接将两个多项式用fft卷积,发现值域是
,超出double精度范围了。
因此设一个阈值K(通常为
)
将两个多项式每一项的系数拆分为
做fft.
化简后
将AC做卷积后,对应系数为乘上 (这个时候浮点数转为整数)加到答案中去。
其他类推。
将BC+AD放到一个多项式里idft,数一下是7次dft。
然而,因为ACidft回去后的虚部是空的,可以将 加到AC中,这样一起idft回去,虚部就是BD。少一次dft。
当阈值为
,长度为
时,可以发现值域是
的,符合double精度范围。
注意单位根不能递推求,否则精度误差呈指数级上升。
虽然理论上精度没毛病,但实际上依旧会有较大误差,需要加上一个0.5来进行四舍五入。
(好不靠谱的感觉)
推式子将两次DFT缩成一次
假如我当前要求A,B的dft,那么将B放到A的虚部中,称作Q 做一次DFT。
对于A或B的位置i,设其值为x,那么根据dft的意义,对DFT后第w位的贡献是
g是复数根。对于位置w,就相当于我们要求
(theta是那个单位根的i次方对应的角度。)
看看Q[i]DFT后对A’[w]和B’[w]的贡献是什么。
发现我们可以同时知道
他们分别在Qdft后的Q’[w]的实部和虚部。 (注意上面是对于任意一个i的,即Q[w]其实是很多个上述式子加起来的,但这更方便了我们一同处理)
同理考虑Q[i]对A’[-w]和B’[-w]的贡献是什么。 (-w即(M-w)%m)
和上面的式子联立,使用小学生加减便可以算出每一个 与 的和。
将后者乘上 再加起来便是对应的A’与B’了。
//myy's fft
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <complex>
#include <cmath>
using namespace std;
typedef long long ll;
typedef double db;
typedef complex<db> com;
const int N = 3e5;
const db pi = acos(-1);
ll f[N],g[N],n,m,p,h[N],M,K;
com A[N],B[N],C[N],D[N],E[N],ans[N],w[N],iw[N];
void split(ll *s,com *a,com *b) {
for (int i = 0; i < M; i++)
a[i].real(s[i] >> 15), b[i].real(s[i] & ((1 << 15)-1));
}
void dft(com *a,int sig) {
for (int i = 0; i < M; i++) if (h[i] < i) swap(a[i],a[h[i]]);
int r = 1,step = M;
for (int m = 2; m <= M; m<<=1, r++) {
int hf = m >> 1, z = 0; step >>= 1;
for (int i = 0; i < hf; i++, z+=step) {
com ww = (sig == 1 ? w[z] : iw[z]);
for (int j = i; j < M; j += m) {
com u = a[j], v = a[j + hf] * ww;
a[j] = u + v, a[j + hf] = u - v;
}
}
}
if (sig == -1)
for (int i = 0; i < M; i++) a[i] /= M;
}
com Q[N];
void mdft(com *A,com *B) {
for (int i = 0; i < M; i++) Q[i] = com(A[i].real(),B[i].real());
dft(Q, 1);
for (int i = 0; i < M; i++) {
int j = i ? M - i : 0;
com sA=0,cA=0,sB=0,cB=0;
sA.imag() = (Q[i].imag() - Q[j].imag()) * 0.5;
cA.real() = (Q[i].real() + Q[j].real()) * 0.5;
sB.imag() = (Q[j].real() - Q[i].real()) * 0.5;
cB.real() = (Q[i].imag() + Q[j].imag()) * 0.5;
A[i] = sA + cA;
B[i] = sB + cB;
}
}
int main() {
freopen("a.in","r",stdin);
cin>>n>>m>>p; K = 1<<15;
for (int i = 0; i <= n; i++) scanf("%lld",&f[i]);
for (int i = 0; i <= m; i++) scanf("%lld",&g[i]);
for (M=1; M <= n+m; M<<=1);
split(f,A,B);
split(g,C,D);
for (int j = 0; j < M; j++) {
if (j==0) w[j] = iw[j] = 1; else {
db c = cos(2*pi/M*j), s = sin(2*pi/M*j);
w[j] = com(c, s);
iw[j] = com(c, -s);
}
}
for (int i = 1; i < M; i++) h[i] = (h[i>>1]>>1) + (i&1) * (M>>1);
mdft(A,B);
mdft(C,D);
for (int i = 0; i < M; i++)
E[i] = A[i] * D[i] + B[i] * C[i];
for (int i = 0; i < M; i++) A[i] *= C[i];
for (int i = 0; i < M; i++) {
A[i] += B[i] * D[i] * com(0,1);
}
dft(A,-1),dft(E,-1);
ll k2 = K * K % p;
for (int i = 0; i <= n+m && i <= 1000000; i++) {
A[i]+=com(0.5,0.5);
E[i]+=com(0.5,0);
ll w = (ll)A[i].real() % p * k2 % p + (ll)A[i].imag() % p + (ll)E[i].real() % p * K % p;
printf("%lld ",w % p);
}
}