BZOJ 1920 Luogu P4217 [CTSC2010]产品销售 (模拟费用流、线段树)

题目链接

(bzoj) https://www.lydsy.com/JudgeOnline/problem.php?id=1920
(luogu) https://www.luogu.org/problem/P4217

题解

模拟费用流。
首先可以建出下面这样的图:
对于每一天\(i\)建一个点,另新建源汇\(S,T\).

(1) \(S\)\(i\)\((D_i,0)\) (表示订单)

(2) \(i\)\(i+1\)\((+\inf,C_i)\) (拖延订单)

(3) \(i+1\)\(i\)\((+\inf,M_i)\) (拖延产品相当于把订单提前)

(4) \(i\)\(T\)\((U_i,P_i)\) (生产产品)
求最小费用最大流(也就是在与源点相连的边满流的情况下求最小费用)。

下面考虑如何模拟:
费用流的两条重要性质——与源点相连的边不会被退流;如果不是每次选一条全局最短路进行增广,而是按任意顺序枚举每条和源点相连的必须流的边进行增广,答案也是对的。
在费用流的建图中,我们要给每条横向边(\(i\)\(i+1\)之间的边)建立反向边,边权为正向边权的相反数。
假设\(a,b\)分别为边\(i\rightarrow i+1, i+1\rightarrow i\)的边权,若\(i\rightarrow i+1\)的流量不为\(0\), 那么当从\(i+1\)流到\(i\)时, 找最短路就会走那条费用为\(-a\)的边,同时给\(i\rightarrow i+1\)的流量\(-1\), 直到流量变为\(0\)为止。\(i+1\rightarrow i\)的边同理。
在这里有一个非常妙的思路——从左往右增广。
从左到右枚举每个和源点相连的点,分别尝试向左增广和向右增广,取代价较小者。
向右增广时,找到这个点出发的最长路,然后计算流量为当前点剩余流量与最长路的那个点到\(T\)剩余流量的最小值。
如果我们向右增广了\(x\)的流量,那么会导致增广路上每个点从右往左反向边的流量增加\(x\).
向左增广时,因为每条横边有一个“阈值”,初始时为在它左边的点向右增广给反向边增加的流量,流过的流量小于等于阈值时其边权为负,大于阈值时边权为正(注意这个阈值不同于容量,多于阈值的流完全可以流过去!我在这里错了好几次),因此我们考虑在保证所有边边权不变的情况下一次性增广,因此增广的流量为当前点剩余流量、最长路的那个点到\(T\)剩余流量、当前点和最长路上点阈值的最小值三者之最小值。
如果我们向左增广了\(x\)的流量,那么会导致增广路上每条边的阈值减少\(x\), 这时候如果某条边的阈值变成了\(0\), 那么要修改这条边的阈值。
最后,如果增广过程中某个和\(T\)相连的点流量变为\(0\), 还要将其删除。
所以我们要维护一个数据结构支持上述操作(具体维护方法下面再讲)。
下面分析复杂度,从左往右增广的妙处就是它能保证一条横向边的阈值在增广左侧的点的时候一直增加,增广右侧的点的时候一直减少,这样跨越\(0\)的次数就是常数次,边权变化也是常数次。
每次增广会导致一条与源点相连的边流完,或者一条与汇点相连的边流完,或者一条横向边边权改变,所以总操作次数为线性。

最后考虑如何用数据结构维护。
开三棵线段树,分别维护往左增广的费用、往右增广的费用以及中间横边的流量。
第二棵线段树,维护往右增广的费用,由于往右增广的边权是不变的,因此直接维护每个点到最右边点的从左往右边权和即可,支持删除(改为\(+\inf\))、查询最小值及其位置。
第一棵线段树也是如此,维护每个点到最左边点的从右往左边权和,支持修改、删除、查询最小值及其位置。
第三棵线段树,维护每条横边的流量。首先要支持区间加、查询最小值及其位置,然后是最棘手的操作——所有的位置初始为\(0\), 如果任何时刻任何位置经过先增加又减少之后变成了\(0\), 那么要在事件发生时快速枚举出这些位置,并在第一棵线段树上进行对应操作。这个我用的方法是pushdown的时候如果某个儿子区间最小值为\(0\), 那么就继续递归这个儿子,直到找到叶子节点,把为零的叶子节点赋成\(+\inf\), 并在第一棵线段树上做对应操作。但还有一个棘手的问题,就是如何区分初值\(0\)和经过修改之后先增加后减少变成了\(0\)? 我一开始使用的是如果标记的值不为\(0\)视为修改过,但是这样是错的,因为有可能若干次修改之后该位置上标记又变回了\(0\). 所以我又单独记录了一个新的标记\(f\), 表示是否修改过。当且仅当\(f\)为真且值为\(0\)时执行对应操作。

时间复杂度\(O(n\log n)\).

(我估计我的写法又麻烦了……QAQ)

代码

#include<cstdio>
#include<cstdlib>
#include<iostream>
#include<utility>
#include<cassert>
#define llong long long
#define pli pair<llong,int>
#define mkpr make_pair
using namespace std;
 
void read(int &x)
{
    int f=1;x=0;char s=getchar();
    while(s<'0'||s>'9'){if(s=='-')f=-1;s=getchar();}
    while(s>='0'&&s<='9'){x=x*10+s-'0';s=getchar();}
    x*=f;
}
 
const int N = 1e5;
const llong INF = 10000000000000ll;
llong a[N+3],b[N+3],w1[N+3],w2[N+3],c[N+3];
llong dl[N+3],dr[N+3];
int n;
 
struct SegmentTree1
{
    struct SgTNode
    {
        llong tag; pli mini;
        SgTNode() {mini = mkpr(INF,0); tag = 0;}
    } sgt[(N<<2)+3];
    void update(pli &x,pli y) {if(y.first<x.first) x = y;}
    void pushdown(int u)
    {
        assert(u<(N<<2));
        llong tag = sgt[u].tag;
        if(tag)
        {
            sgt[u<<1].mini.first += tag; sgt[u<<1].tag += tag;
            sgt[u<<1|1].mini.first += tag; sgt[u<<1|1].tag += tag;
            sgt[u].tag = 0;
        }
    }
    void pushup(int u)
    {
        assert(u<(N<<2));
        sgt[u].mini = mkpr(INF,0);
        update(sgt[u].mini,sgt[u<<1].mini);
        update(sgt[u].mini,sgt[u<<1|1].mini);
    }
    void build(int u,int le,int ri,llong a[])
    {
        assert(u<(N<<2));
        if(le==ri) {sgt[u].mini = mkpr(a[le],le); return;}
        int mid = (le+ri)>>1;
        build(u<<1,le,mid,a);
        build(u<<1|1,mid+1,ri,a);
        pushup(u);
    }
    void addval(int u,int le,int ri,int lb,int rb,llong x)
    {
        assert(u<(N<<2));
        if(le>=lb && ri<=rb) {sgt[u].mini.first += x; sgt[u].tag += x; return;}
        pushdown(u);
        int mid = (le+ri)>>1;
        if(lb<=mid) addval(u<<1,le,mid,lb,rb,x);
        if(rb>mid) addval(u<<1|1,mid+1,ri,lb,rb,x);
        pushup(u);
    }
    pli querymin(int u,int le,int ri,int lb,int rb)
    {
        assert(u<(N<<2));
        if(le>=lb && ri<=rb) {return sgt[u].mini;}
        pushdown(u);
        int mid = (le+ri)>>1; pli ret = mkpr(INF,0);
        if(lb<=mid) update(ret,querymin(u<<1,le,mid,lb,rb));
        if(rb>mid) update(ret,querymin(u<<1|1,mid+1,ri,lb,rb));
        pushup(u);
        return ret;
    }
} sgt1,sgt2;
 
struct SegmentTree2
{
    struct SgTNode
    {
        pli mini; llong tag; bool f;
        SgTNode() {mini = mkpr(0,0); tag = 0ll;}
    } sgt[(N<<2)+3];
    void update(pli &x,pli y) {if(y.first<x.first) x = y;}
    void pushup(int u)
    {
        assert(u<(N<<2));
        sgt[u].mini = mkpr(INF,0);
        update(sgt[u].mini,sgt[u<<1].mini);
        update(sgt[u].mini,sgt[u<<1|1].mini);
    }
    void pushdown(int u,int le,int ri)
    {
        assert(u<(N<<2));
        llong tag = sgt[u].tag;
        if(sgt[u].f)
        {
            int mid = (le+ri)>>1;
            if(le!=ri)
            {
                sgt[u<<1].mini.first += tag; sgt[u<<1].tag += tag; sgt[u<<1].f = true;
                sgt[u<<1|1].mini.first += tag; sgt[u<<1|1].tag += tag; sgt[u<<1|1].f = true;
            }
            sgt[u].tag = 0;
            if(sgt[u].mini.first!=0) return;
            if(le==ri) {sgt1.addval(1,1,n,1,le,b[le]+a[le]); sgt[u].mini.first = INF; return;}
            if(sgt[u<<1].mini.first==0) {pushdown(u<<1,le,mid);}
            if(sgt[u<<1|1].mini.first==0) {pushdown(u<<1|1,mid+1,ri);}
            pushup(u);
        }
    }
    void addval(int u,int le,int ri,int lb,int rb,llong x)
    {
        assert(u<(N<<2));
        if(le>=lb && ri<=rb) {sgt[u].mini.first += x; sgt[u].tag += x; sgt[u].f = true; pushdown(u,le,ri); return;}
        pushdown(u,le,ri); int mid = (le+ri)>>1;
        if(lb<=mid) addval(u<<1,le,mid,lb,rb,x);
        if(rb>mid) addval(u<<1|1,mid+1,ri,lb,rb,x);
        pushup(u);
    }
    pli querymin(int u,int le,int ri,int lb,int rb)
    {
        assert(u<(N<<2));
        if(le>=lb && ri<=rb) {return sgt[u].mini;}
        pushdown(u,le,ri); int mid = (le+ri)>>1; pli ret = mkpr(INF,0);
        if(lb<=mid) update(ret,querymin(u<<1,le,mid,lb,rb));
        if(rb>mid) update(ret,querymin(u<<1|1,mid+1,ri,lb,rb));
        pushup(u);
        return ret;
    }
} sgt3;
 
int main()
{
    scanf("%d",&n);
    for(int i=1; i<=n; i++) scanf("%lld",&w1[i]);
    for(int i=1; i<=n; i++) scanf("%lld",&w2[i]);
    for(int i=1; i<=n; i++) scanf("%lld",&c[i]);
    for(int i=1; i<n; i++) scanf("%lld",&b[i]);
    for(int i=1; i<n; i++) scanf("%lld",&a[i]);
    dl[1] = 0; for(int i=2; i<=n; i++) dl[i] = dl[i-1]+a[i-1];
    dr[n] = 0; for(int i=n-1; i>=1; i--) dr[i] = dr[i+1]-a[i];
    for(int i=1; i<=n; i++) dl[i]=dl[i]+c[i],dr[i]=dr[i]+c[i];
    sgt1.build(1,1,n,dr);
    sgt2.build(1,1,n,dl);
    for(int i=1; i<=n; i++) dl[i]-=c[i],dr[i]-=c[i];
    llong ans = 0ll;
    for(int i=1; i<=n; i++)
    {
        while(w1[i]>0)
        {
            pli vl = mkpr(INF,0),vr = mkpr(INF,0);
            pli tmp = sgt1.querymin(1,1,n,1,i);
            vl = mkpr(tmp.first-dr[i],tmp.second);
            tmp = sgt2.querymin(1,1,n,i,n);
            vr = mkpr(tmp.first-dl[i],tmp.second);
            if(vl.first<vr.first)
            {
                llong flow = min(w1[i],w2[vl.second]);
                if(vl.second<i)
                {
                    tmp = sgt3.querymin(1,1,n,vl.second,i-1);
                    flow = min(flow,tmp.first);
                    sgt3.addval(1,1,n,vl.second,i-1,-flow);
                }
                ans += flow*vl.first;
                w1[i] -= flow;
                w2[vl.second] -= flow;
                if(w2[vl.second]==0)
                {
                    sgt1.addval(1,1,n,vl.second,vl.second,INF);
                    sgt2.addval(1,1,n,vl.second,vl.second,INF);
                }
            }
            else
            {
                llong flow = min(w1[i],w2[vr.second]);
                ans += flow*vr.first;
                if(i<vr.second)
                {
                    sgt3.addval(1,1,n,i,vr.second-1,flow);
                }
                w1[i] -= flow;
                w2[vr.second] -= flow;
                if(w2[vr.second]==0)
                {
                    sgt1.addval(1,1,n,vr.second,vr.second,INF);
                    sgt2.addval(1,1,n,vr.second,vr.second,INF);
                }
            }
        }
        if(sgt3.querymin(1,1,n,i,i).first==0)
        {
            sgt3.addval(1,1,n,i,i,INF);
            sgt1.addval(1,1,n,1,i,b[i]+a[i]);
        }
    }
    printf("%lld\n",ans);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/suncongbo/p/11374733.html