bzoj 4609 [Wf2016]Branch Assignment(dp 凸优化(wqs二分)| 决策单调性优化)

在这里插入图片描述


实际上给的单向边而不是双向边,先处理出每个点到 b + 1 b + 1 b+1 的最短距离 d a d_a da b + 1 b + 1 b+1 到每个点的最短距离 d b d_b db。 预处理权值 s = d a + d b s = d_a + d_b s=da+db,集合划分后的每一个点的贡献为: s [ i ] ∗ ( s i z e − 1 ) s[i] * (size - 1) s[i](size1),其中 s i z e size size 为划分到的集合的大小。 显然最优情况是值连续的划分到一组,对 s s s 进行排序,就可以在序列上做线性 dp。

考虑 d p [ i ] [ j ] dp[i][j] dp[i][j] 表示前 j j j 个分成 i i i 组的最小总距离。转移方程为 dp[k][i] = dp[k - 1][j] + (i - j - 1) * (sum[i] - sum[j]) ,总复杂度为 O ( n 3 ) O(n^3) O(n3),考虑优化:

1、不考虑选择分 k 组这个限制,肯定尽可能多分组最优,给每个分组加上一个代价 x x x,每分一次组都要加上 x x x 的代价,显然 x x x 越大,最优解的分组数越少, x x x 越小, 最优解的分组数越大,满足单调性,考虑二分这个代价 x x x,然后做没有限制的 dp 的复杂度是 O ( n 2 ) O(n^2) O(n2),二分的右边界要大一点,大到最优解可能只分一次组。(从凸包的角度考虑可能不容易看出)

2、由于权值比较小的点划分到的 s i z e size size 肯定更大,决策具有单调性,利用这个单调性,当计算 dp[k][i] 时,转移范围只要枚举 [ i − ⌊ i k ⌋ , i − 1 ] [i - \lfloor\frac{i}{k}\rfloor, i - 1] [iki,i1],因为最后这一块的大小肯定小于等于平均值, k k k 次计算后,对于每一个 i i i,计算所有的 d p [ k ] [ i ] dp[k][i] dp[k][i] 的复杂度是 i log ⁡ i i \log i ilogi,最后总复杂度为 n 2 log ⁡ n n^2\log n n2logn,这个 n 2 log ⁡ n n^2\log n n2logn 没有跑满,因此跑得比较快。

由于决策转移点具有单调性,还可以实现到 O ( n 2 ) O(n^2) O(n2),且已经有 n log ⁡ 2 n n\log^2n nlog2n 的做法 (都不会)


wqs二分优化代码:

#include<bits/stdc++.h>
using namespace std;
const int maxn = 5e3 + 10;
#define pii pair<int,int>
#define fir first
#define sec second
typedef long long ll;
const ll inf = 1e15;
int n,b,s,r,a[maxn],vis[maxn];
ll sum[maxn],d[maxn],t[maxn];
vector<pii> g[maxn],h[maxn];
ll dp[maxn],tp[maxn],lst[maxn];
void spfa1(int s) {
    
    
	queue<int> q;
	for (int i = 1; i <= n; i++)
		d[i] = inf;
	memset(vis,0,sizeof vis);
	d[s] = 0;
	q.push(s);
	while (!q.empty()) {
    
    
		int top = q.front();
		q.pop();
		vis[top] = 0;
		for (auto it : g[top]) {
    
    
			if (d[it.fir] > d[top] + it.sec) {
    
    
				d[it.fir] = d[top] + it.sec;
				if (!vis[it.fir]) {
    
    
					q.push(it.fir);
					vis[it.fir] = 1;
				}
			} 
		}
	}
}
void spfa2(int s) {
    
    
	queue<int> q;
	for (int i = 1; i <= n; i++)
		t[i] = inf;
	memset(vis,0,sizeof vis);
	t[s] = 0;
	q.push(s);
	while (!q.empty()) {
    
    
		int top = q.front();
		q.pop();
		vis[top] = 0;
		for (auto it : h[top]) {
    
    
			if (t[it.fir] > t[top] + it.sec) {
    
    
				t[it.fir] = t[top] + it.sec;
				if (!vis[it.fir]) {
    
    
					q.push(it.fir);
					vis[it.fir] = 1;
				}
			} 
		}
	}
}
ll solve(ll x) {
    
    
	for (int i = 0; i <= b; i++)
		dp[i] = inf, lst[i] = tp[i] = 0;
	dp[0] = 0;
	for (int i = 1; i <= b; i++) {
    
    
		for (int j = lst[i]; j < i; j++) {
    
    
			if (dp[j] + (i - j - 1) * (sum[i] - sum[j]) + x < dp[i]) {
    
    
				dp[i] = dp[j] + (i - j - 1) * (sum[i] - sum[j]) + x;
				tp[i] = tp[j] + 1;	
			} else if (dp[j] + (i - j - 1) * (sum[i] - sum[j]) + x == dp[i]) {
    
    
				if (tp[i] < tp[j] + 1)
					tp[i] = tp[j] + 1;
			}
		}
	}
	return tp[b];
}
int main() {
    
    
	scanf("%d%d%d%d",&n,&b,&s,&r);
	for (int i = 1; i <= r; i++) {
    
    
		int u,v,w; scanf("%d%d%d",&u,&v,&w);
		g[u].push_back(pii(v,w));
		h[v].push_back(pii(u,w));
	}
	spfa1(b + 1); spfa2(b + 1);
	for (int i = 1; i <= b; i++)
		sum[i] = d[i] + t[i];
	sort(sum + 1,sum + b + 1);
	for (int i = 1; i <= b; i++)
		sum[i] += sum[i - 1];
	ll l = 0, r = 1ll << 48;
	while (l < r) {
    
    
		ll mid = l + r >> 1;
		if (solve(mid) < s) r = mid;
		else l = mid + 1;
	}
	solve(l - 1);
	printf("%lld\n",dp[b] - s * (l - 1));
	return 0;
}

决策单调性优化:

#include<bits/stdc++.h>
using namespace std;
const int maxn = 5e3 + 10;
#define pii pair<int,int>
#define fir first
#define sec second
typedef long long ll;
const ll inf = 1e15;
int n,b,s,r,a[maxn],vis[maxn];
ll sum[maxn],d[maxn],t[maxn];
vector<pii> g[maxn],h[maxn];
ll dp[maxn],tp[maxn];
void spfa1(int s) {
    
    
	queue<int> q;
	for (int i = 1; i <= n; i++)
		d[i] = inf;
	memset(vis,0,sizeof vis);
	d[s] = 0;
	q.push(s);
	while (!q.empty()) {
    
    
		int top = q.front();
		q.pop();
		vis[top] = 0;
		for (auto it : g[top]) {
    
    
			if (d[it.fir] > d[top] + it.sec) {
    
    
				d[it.fir] = d[top] + it.sec;
				if (!vis[it.fir]) {
    
    
					q.push(it.fir);
					vis[it.fir] = 1;
				}
			} 
		}
	}
}
void spfa2(int s) {
    
    
	queue<int> q;
	for (int i = 1; i <= n; i++)
		t[i] = inf;
	memset(vis,0,sizeof vis);
	t[s] = 0;
	q.push(s);
	while (!q.empty()) {
    
    
		int top = q.front();
		q.pop();
		vis[top] = 0;
		for (auto it : h[top]) {
    
    
			if (t[it.fir] > t[top] + it.sec) {
    
    
				t[it.fir] = t[top] + it.sec;
				if (!vis[it.fir]) {
    
    
					q.push(it.fir);
					vis[it.fir] = 1;
				}
			} 
		}
	}
}
ll solve() {
    
    
	for (int i = 0; i <= b; i++)
		tp[i] = dp[i] = inf;
	tp[0] = 0;
	for (int k = 1; k <= s; k++) {
    
    
		for (int i = 1; i <= b; i++) {
    
    
			for (int j = i - i / k; j <= i - 1; j++)		// i / k 是平均每个块的大小 
				dp[i] = min(dp[i],tp[j] + (i - j - 1) * (sum[i] - sum[j]));
		}
		for (int i = 0; i <= b; i++)
			tp[i] = dp[i], dp[i] = inf;
	}
	return tp[b]; 
}
int main() {
    
    
	scanf("%d%d%d%d",&n,&b,&s,&r);
	for (int i = 1; i <= r; i++) {
    
    
		int u,v,w; scanf("%d%d%d",&u,&v,&w);
		g[u].push_back(pii(v,w));
		h[v].push_back(pii(u,w));
	}
	spfa1(b + 1); spfa2(b + 1);
	for (int i = 1; i <= b; i++)
		sum[i] = d[i] + t[i];
	sort(sum + 1,sum + b + 1);
	for (int i = 1; i <= b; i++)
		sum[i] += sum[i - 1];
	printf("%lld\n",solve());
	return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_41997978/article/details/104931386
今日推荐