题目大意
给定数列
,设其排列后的数列为
,要求对于任意的
,满足
。某个合法排列的价值为
,求最大价值。
题解
神题qaq。
首先,好好思考题目中那个乱七八糟的条件是啥,也就是说让
的
值出现在第
个之后。我们考虑最后
的排列方式,显然对于任意
,使
的
值出现在
的
值之前。
于是我们就得到了若干个限制
数列的关系,于是对于每个
,从
向
连一条边,得到了一个图。如果这个图中存在环的话显然无解,否则这个图就是以0为根的树。
考虑
最小的那个点,显然如果它的father被选了,下一个选的必然是它。因此它俩必然在数列中连续,于是把它们用并查集合起来,计算贡献。
然而这样操作之后就变成了多个连通块比较,我们如何选择最小的连通块呢?
考虑任意两个连通块
,如果
在
之前更优,则必然有
,消一下就变成了
。
于是用优先队列维护最小值就行了,复杂度
。
#include <bits/stdc++.h>
namespace IOStream {
const int MAXR = 10000000;
char _READ_[MAXR], _PRINT_[MAXR];
int _READ_POS_, _PRINT_POS_, _READ_LEN_;
inline char readc() {
#ifndef ONLINE_JUDGE
return getchar();
#endif
if (!_READ_POS_) _READ_LEN_ = fread(_READ_, 1, MAXR, stdin);
char c = _READ_[_READ_POS_++];
if (_READ_POS_ == MAXR) _READ_POS_ = 0;
if (_READ_POS_ > _READ_LEN_) return 0;
return c;
}
template<typename T> inline void read(T &x) {
x = 0; register int flag = 1, c;
while (((c = readc()) < '0' || c > '9') && c != '-');
if (c == '-') flag = -1; else x = c - '0';
while ((c = readc()) >= '0' && c <= '9') x = x * 10 - '0' + c;
x *= flag;
}
template<typename T1, typename ...T2> inline void read(T1 &a, T2&... x) {
read(a), read(x...);
}
inline int reads(char *s) {
register int len = 0, c;
while (isspace(c = readc()) || !c);
s[len++] = c;
while (!isspace(c = readc()) && c) s[len++] = c;
s[len] = 0;
return len;
}
inline void ioflush() { fwrite(_PRINT_, 1, _PRINT_POS_, stdout), _PRINT_POS_ = 0; fflush(stdout); }
inline void printc(char c) {
if (!c) return;
_PRINT_[_PRINT_POS_++] = c;
if (_PRINT_POS_ == MAXR) ioflush();
}
inline void prints(const char *s, char c) {
for (int i = 0; s[i]; i++) printc(s[i]);
printc(c);
}
template<typename T> inline void print(T x, char c = '\n') {
if (x < 0) printc('-'), x = -x;
if (x) {
static char sta[20];
register int tp = 0;
for (; x; x /= 10) sta[tp++] = x % 10 + '0';
while (tp > 0) printc(sta[--tp]);
} else printc('0');
printc(c);
}
template<typename T1, typename ...T2> inline void print(T1 x, T2... y) {
print(x, ' '), print(y...);
}
}
using namespace IOStream;
using namespace std;
typedef long long ll;
typedef pair<int, int> P;
const int MAXN = 500005;
struct Node {
ll sum; int sz, rt;
bool operator<(const Node &nd) const { return sum * nd.sz > sz * nd.sum; }
};
priority_queue<Node> pq;
int ww[MAXN], par[MAXN], sz[MAXN], fa[MAXN], n;
ll sum[MAXN];
int find(int x) { return x == par[x] ? x : par[x] = find(par[x]); }
void merge(int x, int y) {
x = find(x), y = find(y);
if (x == y) return;
par[x] = y, sz[y] += sz[x], sum[y] += sum[x];
}
int main() {
read(n);
for (int i = 0; i <= n; i++) par[i] = i;
for (int i = 1; i <= n; i++) {
int t; read(t); fa[i] = t;
if (find(i) == find(t)) return puts("-1") * 0;
par[find(i)] = find(t);
}
ll res = 0;
for (int i = 1; i <= n; i++) {
par[i] = i, sz[i] = 1; read(sum[i]);
pq.push((Node) { sum[i], 1, i });
}
par[0] = 0, sz[0] = 1;
for (int i = 1; i <= n; i++) {
Node d = pq.top(); pq.pop();
if (!d.rt || sz[d.rt] != d.sz) { --i; continue; }
int p = find(fa[d.rt]);
res += sum[d.rt] * sz[p];
merge(d.rt, p);
pq.push((Node) { sum[p], sz[p], p });
}
printf("%lld\n", res);
return 0;
}