wqs二分入门题,这题可以直接 dp,且复杂度不会比 wqs 二分高。
问题具有一个性质:限制你建 m 个 post office,但是显然建的越多代价越少,代价和建的 post office 的数量在二维平面上是一个上凸包的形状。
考虑在给每个 post office 加一个代价 x x x ,每建一个 post office 都要付出额外 x x x 的代价,显然 x x x 越大,最优解建的 post office 越少,反之越多,即具有单调性。
可以二分 x x x,然后在不考虑建几个 post office 的限制下进行 dp,求出最小代价 以及对应的建立的 post office 的数量。
考虑如何 dp:首先考虑,在 [l,r] 区间内建一座 post office,建在哪最优。
画一下可以发现是建在中间的点 m i d mid mid 最优,设刚开始建在 x x x( x < m i d x < mid x<mid), x x x 的左边有 a 个点,右边有 b 个点,且 d i s ( x , x + 1 ) = d dis(x, x + 1) = d dis(x,x+1)=d,如果将post ofiice 右移到下一个点,对答案的贡献改变了: ( a + 1 ) ∗ d − b ∗ d (a + 1) * d - b * d (a+1)∗d−b∗d,当 a + 1 ≤ b a + 1 \leq b a+1≤b 时都可以移动 x x x,当 x > m i d x > mid x>mid 时可以类似的方法证明。
容易列出转移方程: d p [ i ] = d p [ j ] + w ( j , i ) dp[i] = dp[j] + w(j,i) dp[i]=dp[j]+w(j,i),其中 w ( j , i ) w(j,i) w(j,i) 表示在 [j,i] 建一个 post office 的最小代价。
预处理 w ( j , i ) w(j,i) w(j,i):可以发现如果 w ( j , i − 1 ) w(j,i - 1) w(j,i−1) 已经求出, w ( j , i ) = w ( j , i − 1 ) + a [ j ] − a [ ⌊ i + j 2 ⌋ ] w(j,i) = w(j,i - 1) + a[j] - a[\lfloor\frac{i+j}{2}\rfloor] w(j,i)=w(j,i−1)+a[j]−a[⌊2i+j⌋]
复杂度为 n 2 log v n^2\log v n2logv, v v v 较大,可以取所有 a a a 的和。
代码:
#include<iostream>
#include<string.h>
#include<stdio.h>
#include<algorithm>
using namespace std;
const int maxn = 1e3 + 10;
const int inf = 0x3f3f3f3f;
int n,m,a[maxn],sum;
int dp[maxn],w[maxn][maxn],d[maxn];
int solve(int x) {
memset(dp,inf,sizeof dp);
memset(d,0,sizeof d);
dp[0] = 0; d[0] = 0; //在取得最优的情况下尽可能的建post office
for (int i = 1; i <= n; i++) {
for (int j = 0; j < i; j++) {
if (dp[j] + w[j + 1][i] + x < dp[i]) {
dp[i] = dp[j] + w[j + 1][i] + x;
d[i] = d[j] + 1;
} else if (dp[j] + w[j + 1][i] + x == dp[i]) {
if (d[j] + 1 > d[i])
d[i] = d[j] + 1;
}
}
}
return d[n];
}
int main() {
scanf("%d%d",&n,&m);
for (int i = 1; i <= n; i++) {
scanf("%d",&a[i]);
sum += a[i];
}
sort(a + 1,a + n + 1);
for (int i = 1; i <= n; i++)
for (int j = i + 1; j <= n; j++)
w[i][j] = w[i][j - 1] + a[j] - a[i + j >> 1];
memset(dp,0,sizeof dp);
memset(d,0,sizeof d);
int l = 0, r = sum;
while (l < r) {
int mid = l + r >> 1;
if (solve(mid) < m) r = mid;
else l = mid + 1;
}
solve(l - 1);
printf("%d\n",dp[n] - m * (l - 1));
return 0;
}