题目链接:http://acm.fzu.edu.cn/problem.php?pid=2302
题目大意:给一个有n个结点的环,每个结点都有权值。现在要将这个环分为k段,每一段的价值为权值和的平方,整个环的价值为各段的价值之和。现在要求出如何分才能使整个环的价值最小。
题目思路:由于题目给出的是环,我们可以考虑枚举断点,先把环变成一条链,然后再进行dp求解。
我们现在令表示长度为 i 的链分成 j 段时最小的价值为多少,表示前 i 个结点的权值之和。这样我们就可以得到如下的状态转移方程:
但是这个状态转移方程的复杂度是的,再加上枚举断点的复杂度,整体的复杂度为。
这个复杂度显然是不合理的,我们现在考虑用斜率优化来降低复杂度。
解设对于 ,如果取 k 是优于 p 的话,必然满足如下条件:
,这个式子可以转化为:
我们再令,。
再令。
所以当时,取 k 是要优于p 的。
现在如果有,同时满足,那么 k 以后的点对于答案就无法造成更优的影响。这样我们就可以移除一些对答案无法造成更优影响的点,来降低复杂度。
接着就利用这个式子和单调队列来维护斜率即可
1、构造一个单调队列来维护前 i 个数对后续解的影响。
2、对第 i 位求解时,如果队列中存在a[1],a[2],a[3]三个元素,当时,就说明选择a[2] 是要优于a[1]的,我们就将a1出队,直到满足时,就可以对dp值进行更新;
3、对于第 i 位的ci入队时,假设队列中从头到尾已经有元素a[1],a[2],a[2]。那么当d要入队的时候,我们维护队列的上凸性质,即如果,那么就将a[3]点删除。直到找到为止,并将d点加入在该位置中。
整体的时间复杂度就降为的了。
具体实现看代码:
#include <cstdio>
#include <cstring>
#include <queue>
#include <iostream>
#include <vector>
#include <algorithm>
#define fi first
#define se second
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define pb push_back
#define MP make_pair
#define lowbit(x) x&-x
#define clr(a) memset(a,0,sizeof(a))
#define _INF(a) memset(a,0x3f,sizeof(a))
#define FIN freopen("in.txt","r",stdin)
#define IOS ios::sync_with_stdio(false)
#define debug(x) cout<<"["<<x<<"]"<<endl
using namespace std;
typedef long long ll;
typedef pair<int, int>pii;
typedef pair<ll, ll>pll;
const int maxn = 200 + 7;
const int inf = 0x3f3f3f3f;
int n, k, T;
int a[maxn], tmp[maxn], head, tail;
int sum[maxn], q[maxn], dp[maxn][maxn];
int get_up(int i, int j, int k) {
return dp[i][j] + sum[j] * sum[j] - (dp[i][k] + sum[k] * sum[k]);
}
int get_down(int j, int k) {
return 2 * (sum[j] - sum[k]);
}
int get_dp(int i, int j, int k) {
return dp[i][k] + (sum[j] - sum[k]) * (sum[j] - sum[k]);
}
int main() {
//freopen("in.txt", "r", stdin);
cin >> T;
while (T--) {
scanf("%d%d", &n, &k);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
int ans = inf;
for (int f = 1; f <= n; f++) {
for (int j = 1; j <= n; j++) {
int cnt = f + j;
if (cnt > n) cnt -= n;
tmp[j] = a[cnt];
sum[j] = sum[j - 1] + tmp[j];
}
memset(dp, 0x3f, sizeof(dp));
dp[0][0] = 0;
for (int j = 1; j <= k; j++) {
head = tail = 0;
q[tail++] = 0;
for (int i = 1; i <= n; i++) {
while (head + 1 < tail && get_up(j - 1, q[head + 1], q[head]) < sum[i] * get_down(q[head + 1], q[head])) head++;
dp[j][i] = min(dp[j][i], get_dp(j - 1, i, q[head]));
while (head + 1 < tail && get_up(j - 1, i, q[tail - 1]) * get_down(q[tail - 1], q[tail - 2]) <= get_up(j - 1, q[tail - 1], q[tail - 2])*get_down(i, q[tail - 1])) tail--;
q[tail++] = i;
}
}
ans = min(ans, dp[k][n]);
}
cout << ans << endl;
}
return 0;
}