题目链接:点击查看
题目大意:给出一个由 n 个点和 m 条边组成的无向图,其中有 k 个点被标记了,题目要求选出两个被标记的点,连接一条边,使得从点 1 到点 n 的最短路最大
题目分析:读完题后,大部分同学应该都会在脑中浮现出一个 n * n 的做法吧,那就是先用bfs求出 dis[ i ][ 0 ] 和 dis[ i ][ 1 ] ,分别表示从点 1 到点 i 的距离和从点 n 到点 i 的距离,然后两层循环枚举被标记的点,计算出 dis[ i ][ 0 ] + dis[ j ][ 1 ] + 1 的最大值就是答案了,思路确实没有问题,现在的问题是如何优化
先来考虑一个比较简单的事情,假如现在有两个点 i 和 j ,如果其建边的话,最短路可能是 1 -> i -> j -> n 或者 1 -> j -> i -> n,这样代表的距离也就是 dis[ i ][ 0 ] + dis[ j ][ 1 ] + 1 和 dis[ i ][ 1 ] + dis[ j ][ 0 ] + 1 了,因为题目要求的是最短路,所以很显然会选择更小的那一个,换句话说,当满足 dis[ i ][ 0 ] + dis[ j ][ 1 ] + 1 < dis[ i ][ 1 ] + dis[ j ][ 0 ] + 1 时,我们会选择前者,通过移项以及约分,不难化简到:dis[ i ][ 0 ] - dis[ i ][ 1 ] < dis[ j ][ 0 ] - dis[ j ][ 1 ],这个公式又代表什么呢?也就是说,当我们现在有两个被标记的点时,dis[ i ][ 0 ] - dis[ i ][ 1 ] 更小的这个点会放在前面,而另外一个点会放在后面
得出这个结论后,我们就可以先对 k 个点按照 dis[ i ][ 0 ] - dis[ i ][ 1 ] 从小到大的顺序排序,因为我们需要维护 dis[ i ][ 0 ] + dis[ j ][ 1 ] + 1 的最大值,我们暂且规定点 i 在点 j 的前面,也就是如果点 i 和点 j 建边的话,一定满足最短路是 1 -> i -> j -> n 的,这样我们O(n)枚举一遍位于后面的点 j ,然后找到点 j 前面的 dis[ i ][ 0 ] 的最大值,也就是前缀的最大值,这样可以保证相加之和是最大的,最后记得特判一下上限就好了,上限是原图的最短路
代码:
#include<iostream>
#include<cstdio>
#include<string>
#include<ctime>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<stack>
#include<climits>
#include<queue>
#include<map>
#include<set>
#include<sstream>
#include<unordered_map>
using namespace std;
typedef long long LL;
typedef unsigned long long ull;
const int inf=0x3f3f3f3f;
const int N=2e5+100;
vector<int>node[N];
int id[N],dis[N][2];
void bfs(int st,int pos)
{
queue<int>q;
q.push(st);
dis[st][pos]=0;
while(q.size())
{
int u=q.front();
q.pop();
for(auto v:node[u])
if(dis[v][pos]>dis[u][pos]+1)
{
dis[v][pos]=dis[u][pos]+1;
q.push(v);
}
}
}
bool cmp(int a,int b)
{
return dis[a][0]-dis[a][1]<dis[b][0]-dis[b][1];
}
int main()
{
//#ifndef ONLINE_JUDGE
// freopen("input.txt","r",stdin);
// freopen("output.txt","w",stdout);
//#endif
// ios::sync_with_stdio(false);
memset(dis,inf,sizeof(dis));
int n,m,k;
scanf("%d%d%d",&n,&m,&k);
for(int i=1;i<=k;i++)
scanf("%d",id+i);
while(m--)
{
int u,v;
scanf("%d%d",&u,&v);
node[u].push_back(v);
node[v].push_back(u);
}
bfs(1,0);
bfs(n,1);
sort(id+1,id+1+k,cmp);
int ans=0,mmax=dis[id[1]][0];
for(int i=2;i<=k;i++)
{
ans=max(ans,mmax+dis[id[i]][1]+1);
mmax=max(mmax,dis[id[i]][0]);
}
printf("%d\n",min(dis[n][0],ans));
return 0;
}