ABC196F-FFT/NTT加速卷积运算

参考博客

前言:

这题跟之前一道湖南对抗赛的题很像。都是FFT加速优化.这也是我第一道FFT的题目.

FFT/NTT

该算法有四个精髓点:

1.将系数表示法转化为点值表示法。
2.利用函数奇偶对称性+分治+引入复数域单位根实现快速求 n n n个点值.
3.利用矩阵的逆将点值表示法转化成系数表示法.
4.根据逆矩阵的特征继续跑一遍FFT.(IFFT)

NTT:在模意义下的快速变化.将单位根换成原根进行运算.其他一致.由于避免了浮点运算.速度快于FFT.

题目大意:

给你两个二进制字符串 S , T S,T S,T.问你最少修改 T T T中多少个字符使得 T T T S S S中出现.
∣ S ∣ , ∣ T ∣ ≤ 1 e 6 |S|,|T| \leq 1e6 S,T1e6

题目思路:

暴力 O ( n m ) O(nm) O(nm),显然不行,引入多项式算法.

f ( i ) f(i) f(i)代表 S [ i , i + m − 1 ] S[i,i+m-1] S[i,i+m1] T T T不同位个数.题目求 min ⁡ m n − m + 1 f ( i ) \min_{m}^{n-m+1}f(i) minmnm+1f(i).

现在的问题是,如何快速求 f f f函数?

f ( i ) = ∑ j = 0 m − 1 ( S i + j − T j ) 2 = ∑ j = 0 m − 1 S i + j 2 + ∑ j = 0 m − 1 T j 2 + ∑ j = 0 m − 1 S i + j T j f(i)=\sum_{j=0}^{m-1}(S_{i+j}-T_{j})^2=\sum_{j=0}^{m-1}S_{i+j}^2+\sum_{j=0}^{m-1}T_j^2+\sum_{j=0}^{m-1}S_{i+j}T_{j} f(i)=j=0m1(Si+jTj)2=j=0m1Si+j2+j=0m1Tj2+j=0m1Si+jTj.

前两项前缀和优化,后一项用多项式优化:不妨反转 T T T.第三项变为:
∑ j = 0 m − 1 S i + j T m − 1 − j \sum_{j=0}^{m-1}S_{i+j}T_{m-1-j} j=0m1Si+jTm1j

这就是一个经典的卷积形式. N T T NTT NTT跑多项式乘法即可.

PS:这里为啥可以直接算?因为两个多项式相乘,得到的 x k , ∣ T ∣ ≤ k x_{k}, |T| \leq k xk,Tk,它的系数就是一个卷积形式 ∑ j = 0 m − 1 S k − j T j \sum_{j=0}^{m-1}S_{k-j}T_{j} j=0m1SkjTj --这里需要忽略超出 T T T下标的部分.所以对累和上界取 m i n min min.

AC代码:

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define pii pair<int,int>
#define pb push_back
#define mp make_pair
#define vi vector<int>
#define vll vector<ll>
#define fi first
#define se second
const int maxn = 1e6 + 5;
const int mod = 1e9 + 7;
namespace NTT
{
    
    
    const int P=998244353,g=3;
    const int W=22,S=1<<W;
    const int J=86583718;

    inline int add(int a,int b) {
    
    int r=a+b; return r<P?r:r-P;}
    inline int sub(int a,int b) {
    
    int r=a-b; return r<0?r+P:r;}
    inline int mul(long long a,long long b) {
    
    return (a*b)%P;}
    inline int inv(int a) {
    
    return a==1?a:mul(inv(P%a),P-P/a);}
    inline int qpow(int a,long long k)
    {
    
    
        int r=1;
        while (k)
        {
    
    
            if (k&1) r=mul(r,a);
            k>>=1; a=mul(a,a);
        }
        return r;
    }

    const int i2=inv(2),ij=inv(J);

    int r[S],w[2][S];
    void init(int lim)
    {
    
    
        int w0=qpow(g,(P-1)/lim);
        w[0][0]=w[1][0]=1;
        for (int i=1;i<lim;i++) w[0][i]=w[1][lim-i]=mul(w[0][i-1],w0);
        for (int i=0;i<lim;i++) r[i]=(r[i>>1]>>1)|((i&1)*(lim>>1));
    }

    void ntt(int *a,int lim,int o)
    {
    
    
        for (int i=0;i<lim;i++) if (i<r[i]) swap(a[i],a[r[i]]);
        for (int i=1;i<lim;i<<=1)
        {
    
    
            for (int j=0,t=lim/(i<<1);j<lim;j+=i<<1)
            {
    
    
                for (int k=j,l=0;k<j+i;k++,l+=t)
                {
    
    
                    int x=a[k],y=mul(w[o][l],a[k+i]);
                    a[k]=add(x,y);
                    a[k+i]=sub(x,y);
                }
            }
        }
        if (o)
        {
    
    
            int tmp=NTT::inv(lim);
            for (int i=0;i<lim;i++) a[i]=mul(a[i],tmp);
        }
    }

    int p1[S],p2[S];
    vector<int>poly_mul(const vector<int>&a,const vector<int>&b)
    {
    
    
        int n=a.size(),m=b.size();
        int lim=1;
        while (lim<(n<<1)) lim<<=1;
        while (lim<(m<<1)) lim<<=1;
        init(lim);
        copy_n(a.begin(),n,p1); fill(p1+n,p1+lim,0);
        copy_n(b.begin(),m,p2); fill(p2+m,p2+lim,0);
        ntt(p1,lim,0);
        ntt(p2,lim,0);
        for (int i=0;i<lim;i++) p1[i]=mul(p1[i],p2[i]);
        ntt(p1,lim,1);
        return vector<int>(p1,p1+n+m-1);
    }
}
int st[maxn] , ss[maxn];
int main()
{
    
    
    ios::sync_with_stdio(false);
    string a , b;
    cin >> a >> b;
    reverse(b.begin(),b.end());
    int n = a.size() , m = b.size();
    for (int i = 1 ; i <= n ; i++)
        ss[i] = ss[i - 1] + (a[i - 1] - '0');
    for (int i = 1 ; i <= m ; i++)
        st[i] = st[i - 1] + (b[i - 1] - '0');
    vector<int> fs , ft;
    for (auto g : a) fs.pb(g - '0');
    for (auto g : b) ft.pb(g - '0');
    auto res = NTT::poly_mul(fs , ft);
    int ans = 1e9;
    //for (auto g : res) cout << g << " ";
    //cout << endl;
    for (int i = m ; i <= n ; i++){
    
    
        int sum = ss[i] - ss[i - m];
        ans = min(ans , sum + st[m] - 2 * res[i - 1]);
    }
    cout << ans << endl;
    return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_35577488/article/details/115064462
今日推荐