2020 Multi-University Training Contest 2 - E. New Equipments(网络流)

题目链接:New Equipments

题意:有n个工人,m台机器,每个工人有三个属性$a_i,b_i,c_i$,现在要把工人安排到机器上工作,一个工人只能安排到一个机器,一个机器上也只能安排一个工人,把第$i$个工人安排到第$j$个机器上的代价为$a_i \times j^2 + b_i \times j + c_i$,分别求安排[1 $\cdots$ n]个工人的最小代价

思路:由于每个工人的代价为二次函数,我们可以三分求出代价的最小值,然后向两边扩展,总共找出n个代价最小的机器,将这个工人和这n台机器之间连一条容量为1,费用为$a_i \times j^2 + b_i \times j + c_i$的边,对每个工人都这么操作,然后跑费用流即可,由于要分别求出安排[1 $\cdots$ n]个工人的最小代价,所以在残余网络上跑n次spfa即可

参考:https://blog.csdn.net/qq_43906000/article/details/107545485

#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>
#include <vector>
#include <queue>
#include <cmath>

using namespace std;

typedef long long ll;

const int N = 5010;
const int M = 10010;
const int NN = 55;
const ll INF = 1000000000000000010;

struct Edge {
    int to, nex;
    ll w, c;
};

Edge edge[2 * M];
int T, n, m, mx, p[NN][NN];
ll a[NN], b[NN], c[NN];
vector<int> alls;
int cnt, s, t, pre[N], lst[N], head[N], inq[N];
ll mc, dis[N], mf, f[N];
queue<int> q;

void init()
{
    mc = mf = 0;
    alls.clear();
    cnt = -1;
    memset(head, -1, sizeof(head));
}

void add_edge(int u, int v, ll w, ll c)
{
    edge[++cnt].to = v;
    edge[cnt].w = w;
    edge[cnt].c = c;
    edge[cnt].nex = head[u];
    head[u] = cnt;
}

int spfa(int s, int t)
{
    for (int i = 0; i < N; i++) {
        dis[i] = f[i] = INF;
        inq[i] = 0;
    }
    while (!q.empty()) q.pop();
    q.push(s);
    inq[s] = 1, dis[s] = 0, pre[t] = -1;
    while (!q.empty()) {
        int u = q.front();
        q.pop();
        inq[u] = 0;
        for (int i = head[u]; -1 != i; i = edge[i].nex) {
            int v = edge[i].to;
            ll c = edge[i].c, w = edge[i].w;
            if (0 == w || dis[v] <= dis[u] + c) continue;
            dis[v] = dis[u] + c;
            pre[v] = u;
            lst[v] = i;
            f[v] = min(f[u], w);
            if (1 == inq[v]) continue;
            inq[v] = 1;
            q.push(v);
        }
    }
    if (-1 == pre[t]) return 0;
    return 1;
}

void insert(int u, int v, ll w, ll c)
{
    add_edge(u, v, w, c);
    add_edge(v, u, 0, -c);
}

ll fx(int i, ll x)
{
    return x * x * a[i] + x * b[i] + c[i];
}

int fid(int x)
{
    return lower_bound(alls.begin(), alls.end(), x) - alls.begin() + 1;
}

int main()
{
    // freopen("in.txt", "r", stdin);
    // freopen("out.txt", "w", stdout);
    scanf("%d", &T);
    while (T--) {
        init();
        scanf("%d%d", &n, &mx);
        for (int i = 1; i <= n; i++) {
            scanf("%lld%lld%lld", &a[i], &b[i], &c[i]);
            int l = 1, r = mx, c = 1;
            while (l < r) {
                int lmid = floor(1.0 * (l + r) / 2);
                int rmid = floor(1.0 * (lmid + r) / 2);
                if (fx(i, lmid) <= fx(i, rmid)) r = rmid;
                else l = lmid;
            }
            p[i][c] = l;
            r = l + 1, l = l - 1;
            while (c < n) {
                if (fx(i, l) < fx(i, r)) {
                    if (l >= 1 && l <= mx) p[i][++c] = l--;
                    else p[i][++c] = r++;
                }
                else {
                    if (r >= 1 && r <= mx) p[i][++c] = r++;
                    else p[i][++c] = l--;
                }
            }
        }
        for (int i = 1; i <= n; i++)
            for (int k = 1; k <= n; k++) alls.push_back(p[i][k]);
        sort(alls.begin(), alls.end());
        alls.erase(unique(alls.begin(), alls.end()), alls.end());
        m = alls.size(), s = n + m + 1, t = s + 1;
        for (int i = 1; i <= n; i++) {
            insert(s, i, 1, 0);
            for (int k = 1; k <= n; k++) {
                insert(i, n + fid(p[i][k]), 1, fx(i, p[i][k]));
            }
        }
        for (int i = 1; i <= m; i++) insert(n + i, t, 1, 0);
        for (int i = 1; i <= n; i++) {
            if (!spfa(s, t)) break;
            mf += f[t];
            mc += f[t] * dis[t];
            int now = t;
            while (now != s) {
                edge[lst[now]].w -= f[t];
                edge[lst[now] ^ 1].w += f[t];
                now = pre[now];
            }
            printf("%lld", mc);
            printf(i == n ? "\n" : " ");
        }
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/zzzzzzy/p/13371030.html