任意模数NTT求卷积

版权声明:未经作者本人允许禁止转载。 https://blog.csdn.net/jokerwyt/article/details/81626659

解决模数M不是NTT模数的情况。

多模数NTT

一般取三个模数p1p2p3做NTT,要求满足 p 1 p 2 p 3 > n M 2 ,即CRT模数比结果序列值要大。
然后用中国剩余定理(CRT)合并出值。
但是由于三个模数乘起来爆long long了,我们需要一些特殊trick。

首先将两个模数合并,方程变为两条

x a 1 mod p 1

x a 2 mod p 2

也即
x = a 1 + k p 1 = a 2 + z p 2

两边同时模上 p 2 ,求p1的逆元后可以计算出k在模p2意义下的值。

又因为我们所求是

( a 1 + k p 1 ) mod p 1 p 2

也就是
( a 1 + ( k p 1 mod p 1 p 2 ) ) mod p 1 p 2

右边括号部分展开一下,不难发现为( 这是常识然而我并不会)
p 1 ( k mod p 2 )

又因为值域小于p1p2,这就是原始值。
从而可以计算出x。求出k后所有操作都应在 mod M 意义下进行。

9 次NTT。。。心态是不是有点崩。

据说立大爷的做法是将三模数换成一大一小模数再用O(1)黑科技乘,这样可以做到6次。

mtt

毛爷爷用拆系数fft的方式来代替ntt.
我们考虑直接将两个多项式用fft卷积,发现值域是 10 23 ,超出double精度范围了。

因此设一个阈值K(通常为 2 15 )
将两个多项式每一项的系数拆分为 a K + b 做fft.

( K A + B ) ( K C + D )

化简后
K 2 A C + K ( B C + A D ) + B D

将AC做卷积后,对应系数为乘上 K 2 (这个时候浮点数转为整数)加到答案中去。
其他类推。

将BC+AD放到一个多项式里idft,数一下是7次dft。

然而,因为ACidft回去后的虚部是空的,可以将 B D i 加到AC中,这样一起idft回去,虚部就是BD。少一次dft。

当阈值为 2 15 ,长度为 1 e 5 时,可以发现值域是 1 e 14 的,符合double精度范围。
注意单位根不能递推求,否则精度误差呈指数级上升。
虽然理论上精度没毛病,但实际上依旧会有较大误差,需要加上一个0.5来进行四舍五入。
(好不靠谱的感觉)

推式子将两次DFT缩成一次

假如我当前要求A,B的dft,那么将B放到A的虚部中,称作Q 做一次DFT。
对于A或B的位置i,设其值为x,那么根据dft的意义,对DFT后第w位的贡献是

x ( g n w ) i

g是复数根。对于位置w,就相当于我们要求
x     ( cos θ + I sin θ )

(theta是那个单位根的i次方对应的角度。)
看看Q[i]DFT后对A’[w]和B’[w]的贡献是什么。
( A i + I B i ) ( g n w ) i

= ( A i + I B i ) ( cos θ + I sin θ )

发现我们可以同时知道
A i cos B i sin B i cos + A sin

他们分别在Qdft后的Q’[w]的实部和虚部。 (注意上面是对于任意一个i的,即Q[w]其实是很多个上述式子加起来的,但这更方便了我们一同处理)
同理考虑Q[i]对A’[-w]和B’[-w]的贡献是什么。 (-w即(M-w)%m)
A i cos + B i sin B i cos A sin

和上面的式子联立,使用小学生加减便可以算出每一个 A i cos A i sin 的和。
将后者乘上 I 再加起来便是对应的A’与B’了。

//myy's fft
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <complex>
#include <cmath>
using namespace std;
typedef long long ll;
typedef double db;
typedef complex<db> com;

const int N = 3e5;
const db pi = acos(-1);

ll f[N],g[N],n,m,p,h[N],M,K;
com A[N],B[N],C[N],D[N],E[N],ans[N],w[N],iw[N];

void split(ll *s,com *a,com *b) {
    for (int i = 0; i < M; i++) 
        a[i].real(s[i] >> 15), b[i].real(s[i] & ((1 << 15)-1));
}

void dft(com *a,int sig) {
    for (int i = 0; i < M; i++) if (h[i] < i) swap(a[i],a[h[i]]);
    int r = 1,step = M;
    for (int m = 2; m <= M; m<<=1, r++) {
        int hf = m >> 1, z = 0; step >>= 1;
        for (int i = 0; i < hf; i++, z+=step) {
            com ww = (sig == 1 ? w[z] : iw[z]);
            for (int j = i; j < M; j += m) {
                com u = a[j], v = a[j + hf] * ww;
                a[j] = u + v, a[j + hf] = u - v;
            }
        }
    }
    if (sig == -1) 
        for (int i = 0; i < M; i++) a[i] /= M;
}

com Q[N];
void mdft(com *A,com *B) {
    for (int i = 0; i < M; i++) Q[i] =  com(A[i].real(),B[i].real());
    dft(Q, 1);
    for (int i = 0; i < M; i++) {
        int j = i ? M - i : 0;
        com sA=0,cA=0,sB=0,cB=0;
        sA.imag() = (Q[i].imag() - Q[j].imag()) * 0.5;
        cA.real() = (Q[i].real() + Q[j].real()) * 0.5;

        sB.imag() = (Q[j].real() - Q[i].real()) * 0.5;
        cB.real() = (Q[i].imag() + Q[j].imag()) * 0.5;
        A[i] = sA + cA;
        B[i] = sB + cB;
    }
}

int main() {
    freopen("a.in","r",stdin);
    cin>>n>>m>>p; K = 1<<15;
    for (int i = 0; i <= n; i++) scanf("%lld",&f[i]);
    for (int i = 0; i <= m; i++) scanf("%lld",&g[i]);
    for (M=1; M <= n+m; M<<=1);
    split(f,A,B);
    split(g,C,D);
    for (int j = 0; j < M; j++) {
        if (j==0) w[j] = iw[j] = 1; else {
            db c = cos(2*pi/M*j), s = sin(2*pi/M*j);
            w[j] = com(c, s);
            iw[j] = com(c, -s);
        }
    }
    for (int i = 1; i < M; i++) h[i] = (h[i>>1]>>1) + (i&1) * (M>>1);
    mdft(A,B);
    mdft(C,D);
    for (int i = 0; i < M; i++)
        E[i] = A[i] * D[i] + B[i] * C[i];
    for (int i = 0; i < M; i++) A[i] *= C[i];
    for (int i = 0; i < M; i++) {
        A[i] += B[i] * D[i] * com(0,1);
    }
    dft(A,-1),dft(E,-1);
    ll k2 = K * K % p;
    for (int i = 0; i <= n+m && i <= 1000000; i++) {
        A[i]+=com(0.5,0.5);
        E[i]+=com(0.5,0);
        ll w = (ll)A[i].real() % p * k2 % p + (ll)A[i].imag() % p + (ll)E[i].real() % p * K % p;
        printf("%lld ",w % p);
    }
}

猜你喜欢

转载自blog.csdn.net/jokerwyt/article/details/81626659