洛谷2619/bzoj2654 Tree(凸优化+MST)

bzoj的数据是真的水。。
qwq
由于本人还有很多东西不是很理解
qwq
所以这里只写一个正确的做法。

首先,我们会发现,对于你选择白色边的数目,随着数目的上涨,斜率是单调升高的。

那么这时候我们就可以考虑凸优化,也就是\(wqs\)二分来满足题目中所述的正好\(k\)条边的限制。

我们\(erf\)一个\(mid\),然后让每一个白边的权值都加上\(mid\),然后跑\(MST\),看最后的选的白色边数,是否是大于等于\(k\)的,如果是,就调大\(l\),否则调小\(r\)

由于最小生成树选择边的时候可能有一些玄学的错误,所以我们在\(sort\)的时候,对于权值相等的边,我们优先选择白边。

那么通过\(erf\),之后,我们就能得到一个上界,也就是在当前的偏移量下,我们最多的选和1相连的边的个数。

根据\(clj\)的官方题解,这里有两个引理

对于一个图,如果存在一个最小生成树,它的白边的数量是\(x\),那么就称\(x\)是最小合法白边数。所有的最小合法白边数形成一个区间\([l,r]\)
(因为题目保证有解,所以我们只需要找到最小的\(r\)即可)

那么经过这个\(erf\),我们就能得到一个最小的\(r\)

那么我们应该怎么求整个\(MST\)的权值呢,我们会发现,对于权值相等的白边和黑边,由于题目保证有解,所以一定是会存在相互替代的关系的。
那我们可以按照之前的最小生成树的策略选白边,将其记为\(val\),最后输出\(val-k*ans\)\(ans\)表示最后的\(mid\)
为什么是\(k\)而不是具体的选的边的数目呢?

因为题目要求正好选择\(k\)条,而我们这里实际上是把多余的白边都直接视为黑边来做了
qwqwq
那么这个题就能解决了
qwqwqwqwq
但是我根据CF125E那个题,有一个比较特殊的做法,但是套到这个这个题,我并不是很理解。qwq
这个坑还是之后再填吧

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<queue>
#include<vector>
#include<map>
#include<vector>
#define mk make_pair
#define pb push_back
#define ll long long
#define int long long
using namespace std;
inline int read()
{
   int x=0,f=1;char ch=getchar();
   while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();}
   while (isdigit(ch)){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
   return x*f;
} 
const int maxn = 4e5+1e2;
struct Edge{
    int u,v,w;
    int col;
}; 
Edge e[maxn];
int n,m;
int ans;
int l=-200,r=200;
int fa[maxn];
int find(int x)
{
    if (fa[x]!=x) fa[x]=find(fa[x]);
    return fa[x];
}
int k;
bool cmp(Edge a,Edge b)
{
    if (a.w==b.w) return a.col<b.col;
    return a.w<b.w;
} 
int solve()
{
    sort(e+1,e+1+m,cmp);
    int tot=0;
    for (int i=1;i<=m;i++)
    {
        int f1 = find(e[i].u);
        int f2 = find(e[i].v);
        if (f1==f2) continue;
        //if(tot==k && e[i].col==0) continue;
        if (e[i].col==0) ++tot;
        fa[f1]=fa[f2];
    }
    return tot;
}
signed main()
{
  n=read(),m=read();k=read();
  for (int i=1;i<=m;i++)
  {
    e[i].u=read()+1;
    e[i].v=read()+1;
    e[i].w=read();
    e[i].col=read();
  }
  while(l<=r)
  {
     int mid = (l+r) >> 1;
     for (int i=1;i<=n;i++) fa[i]=i;
     for (int i=1;i<=m;i++)
     {
        if (e[i].col==0) e[i].w+=mid; 
     }
     int tmp = solve();
     if (tmp<k)
     {
        r=mid-1;
     }
     else l=mid+1,ans=mid;
     for (int i=1;i<=m;i++) 
     {
        if (e[i].col==0) e[i].w-=mid;
     }
  }
  for (int i=1;i<=n;i++) fa[i]=i;
  for (int i=1;i<=m;i++)
  if (e[i].col==0) e[i].w+=ans;
  sort(e+1,e+1+m,cmp);
  int tot=0,val=0;
  for (int i=1;i<=m;i++)
 {
        int f1 = find(e[i].u);
        int f2 = find(e[i].v);
        if (f1==f2) continue;
        if (e[i].col==0) ++tot;
        fa[f1]=fa[f2];
        val+=e[i].w;
  }
  cout<<val-k*ans;
  return 0;
}

猜你喜欢

转载自www.cnblogs.com/yimmortal/p/10202290.html