BZOJ1997——次小生成树(严格次小生成树)

传送门

次小生成树什么的就不想讲了

这儿有个

神仙的讲解

我只需要贴代码就是了

#include<bits/stdc++.h>
using namespace std;
#define ll long long
inline int read(){
	char ch=getchar();
	int res=0;
	while(!isdigit(ch)) ch=getchar();
	while(isdigit(ch)) res=(res<<1)+(res<<3)+(ch^48),ch=getchar();
	return res;
}
int n,m,cnt,s,minn=1e9,fa[100005],f[100005][20],d1[100005][20],d2[100005][20],adj[100005],nxt[600005],to[600005],val[600005],dep[100005];
ll ans;
struct edge{
	int u,v,len,vis;
}e[300005];
inline void addedge(int u,int v,int w){
	nxt[++cnt]=adj[u],adj[u]=cnt,to[cnt]=v,val[cnt]=w;
	nxt[++cnt]=adj[v],adj[v]=cnt,to[cnt]=u,val[cnt]=w;
}
inline bool cmp(edge a,edge b){
	return a.len<b.len;
}
inline int find(int x){
	return fa[x]==x?x:fa[x]=find(fa[x]);
}
inline void dfs(int point,int fa)
{
    for(int i=1;i<=16;i++)
      if(dep[point] >= 1<<i)
      {
        f[point][i] = f[f[point][i-1]][i-1];
        d1[point][i] = max(d1[point][i-1] , d1[f[point][i-1]][i-1]);
        if(d1[point][i-1] == d1[f[point][i-1]][i-1]) d2[point][i] = max(d2[point][i-1] , d2[f[point][i-1]][i-1]);
        else
        {
          d2[point][i] = min(d1[point][i-1] , d1[f[point][i-1]][i-1]);
          d2[point][i] = max(d2[point][i] , d2[f[point][i-1]][i-1]);
        }
      }
      else break;
    for(int u=adj[point];u;u=nxt[u])
    {
      int e=to[u];
      if(e == fa) continue;
      f[e][0] = point;
      dep[e] = dep[point] + 1;
      d1[e][0] = val[u];
      dfs(e,point);
    }
}
inline int lca(int x,int y){
	if(dep[x]<dep[y]) swap(x,y);
	int del=dep[x]-dep[y];
	for(int i=17;i>=0;i--){
		if(del>=(1<<i)) del-=(1<<i),x=f[x][i];
	}
	if(x==y) return x;
	for(int i=17;i>=0;i--){
		if(f[x][i]!=f[y][i])x=f[x][i],y=f[y][i];
	}
	return f[x][0];
}
inline void search(int u,int fa,int len){
	int max1=0,max2=0,del=dep[u]-dep[fa];
	for(int i=17;i>=0;i--){
		if(del>=(1<<i)){
			del-=1<<i;
			if(d1[u][i]>max1)max2=max1,max1=d1[u][i];
			max2=max(max2,d2[u][i]);
			u=f[u][i];
		}
	}
	if(len!=max1)minn=min(minn,len-max1);
	else minn=min(minn,len-max2);
}
inline void solve(int x,int y,int z){
	int g=lca(x,y);
	search(x,g,z),search(y,g,z);
}
int main(){
	n=read(),m=read();
	for(int i=1;i<=n;i++) fa[i]=i;
	for(int i=1;i<=m;i++){
		e[i].u=read(),e[i].v=read(),e[i].len=read();
	}
	sort(e+1,e+m+1,cmp);
	for(int i=1;i<=m;i++){
		int f1=find(e[i].u),f2=find(e[i].v);
		if(f1!=f2){
			ans+=e[i].len;
			fa[f1]=f2;
			e[i].vis=1;
			addedge(e[i].u,e[i].v,e[i].len);
			s++;
			if(s==n-1)break;
		}
	}
	dfs(1,0);
	for(int i=1;i<=m;i++){
		if(!e[i].vis) solve(e[i].u,e[i].v,e[i].len);
	}
	cout<<ans+minn<<'\n';
}

猜你喜欢

转载自blog.csdn.net/qq_42555009/article/details/83147914
今日推荐