C++ 拉格朗日插值法优化 DP

拉格朗日插值法

简介

众所周知, n n 个点 ( x i , y i ) (x_i, y_i) (任意两个点横坐标不相等)可以确定一个 n 1 n - 1 次多项式函数 y = f ( x ) y = f(x) 。拉格朗日插值法可以根据这 n n 个点求出这个多项式 f ( x ) f(x) 。当然,实际应用中通常求出横坐标为 k k 的点在该( n n 个点确定的)多项式函数上对应的纵坐标的值,代码实现中我们也只考虑这一问题。

一个直观的想法是利用待定系数法设 f ( x ) = a n 1 x n 1 + + a 0 x 0 f(x) = a_{n- 1} x ^ {n - 1} + \cdots + a_0 x ^ 0 ,然后带入 n n 个点得到一个 n n 元一次方程组,然后用 高斯消元 解得系数。但这个方法和拉格朗日插值法相比有两个问题:一是时间复杂度为 O ( n 3 ) O(n^3) ,而拉格朗日法时间复杂度为 n 2 n^2 ;二就是系数可能解出小数,还可能很大,而拉格朗日法可以支持取模,并跳过“系数”这一中间步骤,直接求值。

拉格朗日插值法

我们以二次函数为例,看一看拉格朗日插值法的具体过程:已知 3 3 个点 ( x 1 , y 1 ) (x_1, y_1) ( x 2 , y 2 ) (x_2, y_2) ( x 3 , y 3 ) (x_3, y_3) ,求 f ( x ) f(x)

拉格朗日(Joseph-Louis Lagrange,1736 ~ 1813)的做法非常巧妙地避开了解多元方程的过程:
f 1 ( x ) f_1(x) 表示经过点 ( x 1 , 1 ) (x_1, 1) ( x 2 , 0 ) (x_2, 0) ( x 3 , 0 ) (x_3, 0) 的二次函数;
f 2 ( x ) f_2(x) 表示经过点 ( x 1 , 0 ) (x_1, 0) ( x 2 , 1 ) (x_2, 1) ( x 3 , 0 ) (x_3, 0) 的二次函数;
f 3 ( x ) f_3(x) 表示经过点 ( x 0 , 1 ) (x_0, 1) ( x 2 , 0 ) (x_2, 0) ( x 3 , 1 ) (x_3, 1) 的二次函数。
那么 f ( x ) = y 1 f 1 ( x ) + y 2 f 2 ( x ) + y 3 f 3 ( x ) f(x) = y_1 \cdot f_1(x) + y_2 \cdot f_2(x) + y_3 \cdot f_3(x)

原因很简单,每个子函数确保经过一个点而不经过另外两个点。

而子函数的求法很简单,以 f 1 ( x ) f_1(x) 为例:
f 1 ( x ) = 0 f_1(x) = 0 的两根为 x = x 2 x = x_2 x = x 3 x = x_3 ,于是设 f 1 ( x ) = k ( x x 2 ) ( x x 3 ) f_1(x) = k (x - x_2) (x - x_3) ,再带入点 ( x 1 , 1 ) (x_1, 1) ,得到 k = 1 ( x 1 x 2 ) ( x 1 x 3 ) k = \frac{1}{(x_1 - x_2)(x_1 - x_3)} ,于是 f 1 ( x ) = ( x x 2 ) ( x x 3 ) ( x 1 x 2 ) ( x 1 x 3 ) f_1(x) = \frac{(x - x_2) (x - x_3)}{(x_1 - x_2)(x_1 - x_3)}

求高次函数与求二次函数的方法同理,可得
f i ( x ) = 1 j n , j i ( x x j ) ( x i x j ) f ( x ) = 1 i n f i ( x ) \begin{aligned} f_i(x) &= \prod_{1 \leq j \leq n, j \neq i} \frac{(x - x_j)}{(x_i - x_j)} \\ f(x) &= \sum_{1 \leq i \leq n} f_i(x) \end{aligned}
于是,想求 f ( k ) f(k) 的值,将 k k 代入上式即可,时间复杂度 O ( n 2 ) O(n^2) n n 为次数)。

模板

洛谷 P4781 【模板】拉格朗日插值

#include <bits/stdc++.h>

const int MOD = 998244353;
const int MAXN = 2000;

int Mul(const int &a, const int &b) {
    return (long long)a * b % MOD;
}

int Inv(int x) {
    int y = MOD - 2, ret = 1;
    while (y) {
        if (y & 1)
            ret = Mul(ret, x);
        x = Mul(x, x);
        y >>= 1;
    }
    return ret;
}

int N, K, X[MAXN + 5], Y[MAXN + 5];

int main() {
    scanf("%d%d", &N, &K); // 求 f(K)
    for (int i = 1; i <= N; i++)
        scanf("%d%d", &X[i], &Y[i]);
    int Ans = 0;
    for (int i = 1; i <= N; i++) {
        int x = Y[i], y = 1;
        for (int j = 1; j <= N; j++)
            if (j != i) {
                x = Mul(x, (K - X[j] + MOD) % MOD);
                y = Mul(y, (X[i] - X[j] + MOD) % MOD);
            }
        Ans = (Ans + Mul(x, Inv(y))) % MOD;
    }
    printf("%d", Ans);
}

DP 优化

思路

如果没有接触过可能很难想到这个与 DP 的联系。事实上,我们可以将某一维的 DP 看作一个函数,即令 f i ( j ) = d p [ i ] [ j ] f_i(j) = dp[i][j] (注意这个 f i ( j ) f_i(j) 与上文中的“子函数”没有关系)那么,如果我们要求的 d p [ i ] [ j ] dp[i][j] 中的 j j 值很大(例如 j = 1 0 9 j = 10^9 ),我们就可以只计算 d p [ i ] [ 1 ] , d p [ i ] [ 2 ] , , d p [ i ] [ p + 1 ] dp[i][1], dp[i][2], \cdots, dp[i][p + 1] p p f i ( x ) f_i(x) 的次数),并用点 ( 1 , d p [ i ] [ 1 ] ) (1, dp[i][1]) ( 2 , d p [ i ] [ 2 ] ) (2, dp[i][2]) ,…, ( p + 1 , d p [ i ] [ p + 1 ] ) (p + 1, dp[i][p + 1]) 确定多项式 f i ( x ) f_i(x) ,并快速求得 f i ( j ) f_i(j) ,即 d p [ i ] [ j ] dp[i][j] ,时间复杂度为 O ( p 2 ) O(p^2)

这类优化的难点在于要准确地计算 p p 的值,即 f i ( x ) f_i(x) 的次数,接下来通过例题讲解如何计算 p p

例题一

洛谷 P4463 [集训队互测2012] calc

分析

发现我们只需要计算所有递增的合法序列的值之和,然后乘上 n ! n! 即为答案,因为每种递增的合法序列任意打乱顺序仍然是合法的,并且原先就不同,打乱后也一定不同。

d p [ i ] [ j ] dp[i][j] 表示:长度 i i 所含元素值不超过 j j 递增的合法序列的值之和,考虑在第 i i 个位置放元素 j j 还是放其他小于 j j 的元素,本质即为一个背包问题,则 d p [ i ] [ j ] = j d p [ i 1 ] [ j 1 ] + d p [ i ] [ j 1 ] dp[i][j] = j \cdot dp[i - 1][j - 1] + dp[i][j - 1] 答案为 d p [ n ] [ k ] dp[n][k] ,然后发现 k 1 0 9 k \leq 10^9 ,不可能直接 DP。

按照上文中的方法,我们令 f n ( i ) = d p [ n ] [ i ] f_n(i) = dp[n][i] ,所求的就是 f n ( k ) f_n(k) 。接下来求出多项式 f n ( x ) f_n(x) 的次数 p p ,然后我们就只需要 DP 出 d p [ n ] [ 1 ] dp[n][1] d p [ n ] [ p + 1 ] dp[n][p + 1] ,再用拉格朗日插值法就能算出 f n ( k ) f_n(k) 了。

接下来推导 f n ( x ) f_n(x) 的次数,令 g ( n ) g(n) 表示多项式 f n ( x ) f_n(x) 的次数:
d p [ i ] [ j ] = j d p [ i 1 ] [ j 1 ] + d p [ i ] [ j 1 ] f i ( j ) = j f i 1 ( j 1 ) + f i ( j 1 ) f i ( j ) f i ( j 1 ) = j f i 1 ( j 1 ) \begin{aligned} dp[i][j] &= j \cdot dp[i - 1][j - 1] + dp[i][j - 1] \\ f_i(j) &= j \cdot f_{i - 1}(j - 1) + f_i(j - 1) \\ f_i(j) - f_i(j - 1) &= j \cdot f_{i - 1}(j - 1) \end{aligned} f i ( x ) = i = 0 g ( n ) a i x i f_i(x) = \sum\limits_{i = 0}^{g(n)} a_i x ^i ,将 j j j 1 j - 1 暴力代入 f i ( j ) f i ( j 1 ) f_i(j) - f_i(j - 1) 这个式子,发现 a g ( i ) j g ( n ) a_{g(i)} j^{g(n)} 这个最高次项被消掉了(代入后有关最高次项的部分仅为 a g ( i ) j g ( i ) a g ( i ) ( j 1 ) g ( i ) a_{g(i)} j^{g(i)} - a_{g(i)} (j - 1)^{g(i)} )!

于是得到 f i ( j ) f i ( j 1 ) f_i(j) - f_i(j - 1) 的次数为 g ( i ) 1 g(i) - 1 ,又因为 j f i 1 ( j 1 ) j \cdot f_{i - 1}(j - 1) 的次数为 g ( i 1 ) + 1 g(i - 1) + 1 ,所以
g ( i ) 1 = g ( i 1 ) + 1 g ( i ) = g ( i 1 ) + 2 \begin{aligned} g(i) - 1 &= g(i - 1) + 1\\ g(i) &= g(i - 1) + 2 \end{aligned} 又因为 g ( 0 ) = 0 g(0) = 0 f 0 ( x ) = d p [ 0 ] [ x ] = 1 f_0(x) = dp[0][x] = 1 )所以 g ( n ) = 2 n g(n) = 2n ,证得 f n ( x ) f_n(x) 的次数为 2 n 2n

然后我们只需要用朴素的 DP 求得 d p [ n ] [ 1 ] dp[n][1] d p [ n ] [ 2 ] dp[n][2] ,…, d p [ n ] [ 2 n + 1 ] dp[n][2n + 1] (注意点数要求比次数多一才能得到正确的多项式),并用拉格朗日插值法求得 d p [ n ] [ k ] dp[n][k] 即可。

代码

#include <bits/stdc++.h>

const int MAXN = 500;

int N, K, P;
int Dp[MAXN + 5][2 * MAXN + 1 + 5];

int Add(int a, const int &b) {
    a += b; return (a >= P) ? (a - P) : a;
}

int Mul(const int &a, const int &b) {
    return (long long)a * b % P;
}

int Inv(int x) {
    int y = P - 2, ret = 1;
    while (y) {
        if (y & 1)
            ret = Mul(ret, x);
        x = Mul(x, x);
        y >>= 1;
    }
    return ret;
}

int main() {
    scanf("%d%d%d", &K, &N, &P);
    int M = 2 * N + 1;
    for (int i = 0; i <= M; i++)
        Dp[0][i] = 1;
    for (int i = 1; i <= N; i++)
        for (int j = i; j <= M; j++)
            Dp[i][j] = Add(Dp[i][j - 1], Mul(Dp[i - 1][j - 1], j));
    int Ans = 0, Fac = 1;
    for (int i = 1; i <= N; i++)
        Fac = Mul(Fac, i);
    for (int i = 1; i <= M; i++) {
        int x = Dp[N][i], y = 1;
        for (int j = 1; j <= M; j++)
            if (i != j) {
                x = Mul(x, (K >= j) ? (K - j) : (K - j + P));
                y = Mul(y, (i >= j) ? (i - j) : (i - j + P));
            }
        Ans = Add(Ans, Mul(x, Inv(y)));
    }
    printf("%d", Mul(Ans, Fac));
    return 0;
}

例题二

CF995F Cowmpany Cowmpensation

题意:给定整数 n n D D 1 n 3000 1 \leq n \leq 3000 1 D 1 0 9 1 \leq D \leq 10^9 )以及一个 n n 个结点的树,要求给每个结点分配一个 [ 1 , D ] [1, D] 之间的整数作为权值,并且满足父亲结点权值大于等于儿子结点,求方案总数。

分析

d p [ u ] [ i ] dp[u][i] 表示:以 u u 为根的子树中,每个结点的权值都在 [ 1 , i ] [1,i] 内的方案数,同样是一个背包
d p [ u ] [ i ] = d p [ u ] [ i 1 ] + v  is a son of  u d p [ v ] [ i 1 ] dp[u][i] = dp[u][i - 1] + \sum_{v \text{ is a son of } u} dp[v][i - 1] g ( n ) g(n) 定义与上题类似,然后得到
g ( u ) 1 = v  is a son of  u g ( v ) g ( u ) = v  is a son of  u g ( v ) + 1 \begin{aligned} g(u) - 1 &= \sum_{v \text{ is a son of } u} g(v)\\ g(u) &= \sum_{v \text{ is a son of } u} g(v) + 1 \end{aligned} 注意边界 g ( v ) = [ v  is a leaf  ] g(v) = [v \text{ is a leaf }] ,因为对于一个叶子 u u d p [ u ] [ i ] = i dp[u][i] = i 。因此这就是一个子树大小的 DP 式,于是 g ( 1 ) = n g(1) = n ,暴力算得 d p [ 1 ] [ 1 ] dp[1][1] d p [ 1 ] [ 2 ] dp[1][2] ,…, d p [ 1 ] [ n + 1 ] dp[1][n + 1] ,再拉格朗日即可。

代码

#include <bits/stdc++.h>

const int MAXN = 3000;
const int MOD = 1000000007;

int N, D, M;
std::vector<int> G[MAXN + 5];

int Dp[MAXN + 5][MAXN + 5];

int Add(int a, const int &b) {
    a += b; return (a >= MOD) ? (a - MOD) : a;
}

int Mul(const int &a, const int &b) {
    return (long long)a * b % MOD;
}

int Inv(int x) {
    int y = MOD - 2, ret = 1;
    while (y) {
        if (y & 1)
            ret = Mul(ret, x);
        x = Mul(x, x);
        y >>= 1;
    }
    return ret;
}

void Dfs(int u) {
    for (int v: G[u])
        Dfs(v);
    for (int i = 1; i <= M; i++) {
        int tmp = 1;
        for (int v: G[u])
            tmp = Mul(tmp, Dp[v][i]);
        Dp[u][i] = Add(Dp[u][i - 1], tmp);
    }
}

int main() {
    scanf("%d%d", &N, &D);
    for (int i = 2; i <= N; i++) {
        int p; scanf("%d", &p);
        G[p].push_back(i);
    }
    M = N + 1;
    Dfs(1);
    int Ans = 0;
    for (int i = 1; i <= M; i++) {
        int x = Dp[1][i], y = 1;
        for (int j = 1; j <= M; j++)
            if (i != j) {
                x = Mul(x, (D >= j) ? (D - j) : (D - j + MOD));
                y = Mul(y, (i >= j) ? (i - j) : (i - j + MOD));
            }
        Ans = Add(Ans, Mul(x, Inv(y)));
    }
    printf("%d", Ans);
    return 0;
}

A trick

上面两题的“点”的横坐标有个规律:是连续的 p + 1 p + 1 个正整数。结合拉格朗日插值法的分子分母的特征,发现可以用前缀积和后缀积优化拉格朗日插值法的内层循环代码,使时间复杂度由 O ( p 2 ) O(p^2) 优化为 O ( p ) O(p) ,但是复杂度的瓶颈在于开头的朴素 DP,所以没有提这个方法。

猜你喜欢

转载自blog.csdn.net/C20190102/article/details/106693455