CodeForces - 1307D Cow and Fields(最短路+思维)

题目链接:点击查看

题目大意:给出一个由 n 个点和 m 条边组成的无向图,其中有 k 个点被标记了,题目要求选出两个被标记的点,连接一条边,使得从点 1 到点 n 的最短路最大

题目分析:读完题后,大部分同学应该都会在脑中浮现出一个 n * n 的做法吧,那就是先用bfs求出 dis[ i ][ 0 ] 和 dis[ i ][ 1 ] ,分别表示从点 1 到点 i 的距离和从点 n 到点 i 的距离,然后两层循环枚举被标记的点,计算出 dis[ i ][ 0 ] + dis[ j ][ 1 ] + 1 的最大值就是答案了,思路确实没有问题,现在的问题是如何优化

先来考虑一个比较简单的事情,假如现在有两个点 i 和 j ,如果其建边的话,最短路可能是 1 -> i -> j -> n 或者 1 -> j -> i -> n,这样代表的距离也就是 dis[ i ][ 0 ] + dis[ j ][ 1 ] + 1 和 dis[ i ][ 1 ] + dis[ j ][ 0 ] + 1 了,因为题目要求的是最短路,所以很显然会选择更小的那一个,换句话说,当满足 dis[ i ][ 0 ] + dis[ j ][ 1 ] + 1 < dis[ i ][ 1 ] + dis[ j ][ 0 ] + 1 时,我们会选择前者,通过移项以及约分,不难化简到:dis[ i ][ 0 ] - dis[ i ][ 1 ] < dis[ j ][ 0 ] - dis[ j ][ 1 ],这个公式又代表什么呢?也就是说,当我们现在有两个被标记的点时,dis[ i ][ 0 ] - dis[ i ][ 1 ] 更小的这个点会放在前面,而另外一个点会放在后面

得出这个结论后,我们就可以先对 k 个点按照 dis[ i ][ 0 ] - dis[ i ][ 1 ] 从小到大的顺序排序,因为我们需要维护 dis[ i ][ 0 ] + dis[ j ][ 1 ] + 1 的最大值,我们暂且规定点 i 在点 j 的前面,也就是如果点 i 和点 j 建边的话,一定满足最短路是 1 -> i -> j -> n 的,这样我们O(n)枚举一遍位于后面的点 j ,然后找到点 j 前面的 dis[ i ][ 0 ] 的最大值,也就是前缀的最大值,这样可以保证相加之和是最大的,最后记得特判一下上限就好了,上限是原图的最短路

代码:

#include<iostream>
#include<cstdio>
#include<string>
#include<ctime>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<stack>
#include<climits>
#include<queue>
#include<map>
#include<set>
#include<sstream>
#include<unordered_map>
using namespace std;
     
typedef long long LL;
    
typedef unsigned long long ull;
     
const int inf=0x3f3f3f3f;

const int N=2e5+100;

vector<int>node[N];

int id[N],dis[N][2];

void bfs(int st,int pos)
{
	queue<int>q;
	q.push(st);
	dis[st][pos]=0;
	while(q.size())
	{
		int u=q.front();
		q.pop();
		for(auto v:node[u])
			if(dis[v][pos]>dis[u][pos]+1)
			{
				dis[v][pos]=dis[u][pos]+1;
				q.push(v);
			}
	}
}

bool cmp(int a,int b)
{
	return dis[a][0]-dis[a][1]<dis[b][0]-dis[b][1];
}
 
int main()
{
//#ifndef ONLINE_JUDGE
//  freopen("input.txt","r",stdin);
//    freopen("output.txt","w",stdout);
//#endif
//  ios::sync_with_stdio(false);
	memset(dis,inf,sizeof(dis));
	int n,m,k;
	scanf("%d%d%d",&n,&m,&k);
	for(int i=1;i<=k;i++)
		scanf("%d",id+i);
	while(m--)
	{
		int u,v;
		scanf("%d%d",&u,&v);
		node[u].push_back(v);
		node[v].push_back(u);
	}
	bfs(1,0);
	bfs(n,1);
	sort(id+1,id+1+k,cmp);
	int ans=0,mmax=dis[id[1]][0];
	for(int i=2;i<=k;i++)
	{
		ans=max(ans,mmax+dis[id[i]][1]+1);
		mmax=max(mmax,dis[id[i]][0]);
	}
	printf("%d\n",min(dis[n][0],ans));
	
	
	
	
	
	
	
	
	
	
	
	
	
	
    return 0;
}
发布了646 篇原创文章 · 获赞 20 · 访问量 2万+

猜你喜欢

转载自blog.csdn.net/qq_45458915/article/details/104377505
今日推荐