DP玄学优化——斜率优化

——以此博客来悼念我在\(QBXT\)懵逼的时光


\(rqy\; tql\) (日常%\(rqy\)


概念及用途

斜率优化是\(DP\)的一种较为常用的优化(据说在高中课本里稍有提及),它可以用于优化这样的一种\(DP\)式子
\[dp[i]=a[i]+\max(y_j-k_ix_j)\;\;\; j\in[1,i-1]\]

原理

以下均以上面的\(DP\)方程为例

如果我们将上式中的\((x_j,y_j)\)画到坐标系里,然后画一条斜率为\(k_i\)的直线,则这条直线为的方程为\(y=k_ix+(y_j-k_ix_j)\),所以\(y_j-k_ix_j\)这一部分就是该直线的截距。所以在\(k_i\)确定的情况下,我们要求一个最大值,即该直线的斜率最大,那就相当于要找直线从上向下平移,所碰到的第一个点,这时,它的截距是最大的。

如果我们把这样的点筛出来,就能大大优化时间复杂度。
gif示意图

最后筛出来的点就是图中最后出现的红色点

当两个点在候选队列里,此时再加进来一个点,如果是这样的一种情况:

那我们可以把中间的那个点删去,因为删去它不会影响我们的最优解,他的截距显然要更大一些

如果是这样的一种情况

则三个点暂时全部保留,因为在某些时刻,它们都可能成为最优解

实现方式

我们可以用双端队列来处理,删掉不优的,加进去暂时较优的。在这只发一下代码中比较重要的部分吧

bool check(int a,int b,int c){
    return (y(a)-y(c))*(x(b)-x(c))-(y(b)-y(c))*(x(a)-x(c))>=0;
}
/*
省略中间部分
*/
for(int i=1;i<=n;i++){
        while(h<t&&y(q[h])-k(i)*x(q[h])<y(q[h+1])-k(i)*x(q[h+1])) h++;
        dp[i]=(sum[i]-l)*(sum[i]-l)+y(q[h])-k(i)*x(q[h]);
        while(h<t&&check(q[t-1],q[t],i)) t--;
        q[++t]=i;
    }

代码中的\(x\)函数、\(y\)函数、\(k\)函数、\(check\)函数请按需要更改,模板差不多就是这样了

例题

HNOI2008 玩具装箱TOY

若果我们在输入时将所有的\(C_i\)\(L+1\),原公式可以化为\((S-L)^2\)

\(f_i\)表示前\(i\)个数的最小花费
\(f_i=min(f_j+(S_i-S_j-L)^2)\;\; 1 {<=}j{<=}i-1\)
\(f_i=min(f_j+(S_i-L)^2-2(S_i-L)S_j+S_j^2)\)
\(\;\;\;\;=(S_i-L)^2+min(y_j-k_ix_j)\)
\(y_j=f_j+S_j^2,k_i=2(S_i-L),x_j=S_j\)

所以我们就可用斜率优化来\(A\)掉这道题了

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define ll long long
int read(){
    int k=0,f=1; char c=getchar();
    for(;c<'0'||c>'9';c=getchar())
      if(c=='-') f=-1;
    for(;c>='0'&&c<='9';c=getchar())
      k=(k<<3)+(k<<1)+c-48;
    return k*f;
}
int n;
ll dp[50010],sum[50010],l;
inline ll x(int i){ return sum[i]; }
inline ll y(int i){ return dp[i]+sum[i]*sum[i]; }
inline ll k(int i){ return 2LL*(sum[i]-l); }
inline bool check(int a,int b,int c){
    return (y(a)-y(c))*(x(b)-x(c))-(y(b)-y(c))*(x(a)-x(c))>=0;
}
ll q[100010],h=1,t=1;
int main(){
    n=read(),l=read()+1;
    for(int i=1;i<=n;i++) sum[i]=sum[i-1]+read()+1;
    for(int i=1;i<=n;i++){
        while(h<t&&y(q[h])-k(i)*x(q[h])>=y(q[h+1])-k(i)*x(q[h+1])) h++;
        dp[i]=(sum[i]-l)*(sum[i]-l)+y(q[h])-k(i)*x(q[h]);
        while(h<t&&check(q[t-1],q[t],i)) t--;
        q[++t]=i;
    }
    cout<<dp[n];
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/wxl-Ezio/p/9420890.html