传送门
题目大意
给你三个 到 的排列 , , 。
称三元组
是合法的,当且仅当存在一个下标集合
满足:
问合法三元组的数量。
。
思路
我们不妨先考虑只有两维的情况该怎么做。我们不妨将二元组以
为关键字排序,然后从小到大枚举
。对于二元组
,合法的
的个数为:
不妨看成:
也就是说,当维数为 时这个问题就是个二维偏序问题。
推广到三维,似乎没法推广到三维?不过注意到,在二维情况下合法的二元组只与一个或者两个下标有关。显然在三维情况下合法的三元组也只与至多三个下标有关,也就是包含至少一维最大值的下标。
当只与一个下标有关,也就是下标集合中只有一个元素时,方案数显然为 。当与两个下标集合有关,也就是下标集合中有两个元素时,发现对于两个下标,只要不是其中一个每一维都比另一个大,那么它们就是合法的,这就是一个三维偏序问题。
当下标集合大小为 时,要求每个元素都恰好有一个是三个中的最大值。显然这个东西没法直接套用偏序问题,我们不妨考虑计算不合法的方案数。第一种情况是只需要一个下标时用了三个下标,我们需要在三维都比它小的里面选两个,直接套用三维偏序即可;第二种情况是只需要两个下标时用了三个下标,即有一个其中两维最大,另一个其中一维最大,最后一个没有任何一维是最大值。
我们可以枚举其中两维,另一维不管,然后计算有多少个大小为三的下标集合满足某一个在这两维上都是最大值。把答案加起来后,可以发现算了三次三维都比另外两个大的大小为三的下标集合,算了一次我们真的想要的东西。三维都比另外两个大的大小为三的下标集合的数量我们已经算过了,所以我们就求得了我们想要的东西。
参考代码
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <cstring>
#include <cassert>
#include <cctype>
#include <climits>
#include <ctime>
#include <iostream>
#include <algorithm>
#include <vector>
#include <string>
#include <stack>
#include <queue>
#include <deque>
#include <map>
#include <set>
#include <bitset>
#include <list>
#include <functional>
using LL = long long;
using ULL = unsigned long long;
using std::cin;
using std::cout;
using std::endl;
using INT_PUT = LL;
INT_PUT readIn()
{
INT_PUT a = 0; bool positive = true;
char ch = getchar();
while (!(ch == '-' || std::isdigit(ch))) ch = getchar();
if (ch == '-') { positive = false; ch = getchar(); }
while (std::isdigit(ch)) { a = a * 10 - (ch - '0'); ch = getchar(); }
return positive ? -a : a;
}
void printOut(INT_PUT x)
{
char buffer[20]; int length = 0;
if (x < 0) putchar('-'); else x = -x;
do buffer[length++] = -(x % 10) + '0'; while (x /= 10);
do putchar(buffer[--length]); while (length);
putchar('\n');
}
const int maxn = int(1e5) + 5;
int n;
struct Triple
{
int a, b, c;
bool operator<(const Triple& y) const
{
return a < y.a;
}
} triples[maxn];
LL ans;
struct BIT
{
int c[maxn];
static inline int lowbit(int x) { return x & -x; }
void add(int pos, int val)
{
while (pos <= n)
{
c[pos] += val;
pos += lowbit(pos);
}
}
int query(int pos)
{
int ret = 0;
while (pos)
{
ret += c[pos];
pos ^= lowbit(pos);
}
return ret;
}
void clear(int pos)
{
while (pos <= n)
{
if (c[pos]) c[pos] = 0;
else break;
pos += lowbit(pos);
}
}
} bit;
int idx[maxn];
int temp[maxn];
int A[maxn];
void cdq3(int l, int r)
{
if (l == r)
return;
int mid = (l + r) >> 1;
cdq3(l, mid);
cdq3(mid + 1, r);
int i = l, j = mid + 1, k = l;
while (k <= r)
{
if (j > r || (i <= mid && triples[idx[i]].b <= triples[idx[j]].b))
{
bit.add(triples[idx[i]].c, 1);
temp[k++] = idx[i++];
}
else
{
LL t = bit.query(triples[idx[j]].c);
ans -= t;
A[idx[j]] += t;
temp[k++] = idx[j++];
}
}
for (i = l; i <= mid; i++)
bit.clear(triples[idx[i]].c);
for (i = l; i <= r; i++)
idx[i] = temp[i];
}
void run()
{
n = readIn();
for (int i = 1; i <= n; i++)
triples[i].a = readIn();
for (int i = 1; i <= n; i++)
triples[i].b = readIn();
for (int i = 1; i <= n; i++)
triples[i].c = readIn();
ans += n + (LL)n * (n - 1) / 2;
std::sort(triples + 1, triples + 1 + n);
for (int i = 1; i <= n; i++)
idx[i] = i;
cdq3(1, n);
LL X = 0;
for (int i = 1; i <= n; i++)
{
LL t = bit.query(triples[i].b - 1);
X += t * (t - 1) / 2;
bit.add(triples[i].b, 1);
}
std::memset(bit.c, 0, sizeof(bit.c));
for (int i = 1; i <= n; i++)
{
LL t = bit.query(triples[i].c - 1);
X += t * (t - 1) / 2;
bit.add(triples[i].c, 1);
}
std::memset(bit.c, 0, sizeof(bit.c));
std::sort(triples + 1, triples + 1 + n,
[](const Triple& x, const Triple& y)
{
return x.b < y.b;
});
for (int i = 1; i <= n; i++)
{
LL t = bit.query(triples[i].c - 1);
X += t * (t - 1) / 2;
bit.add(triples[i].c, 1);
}
for (int i = 1; i <= n; i++)
ans += (LL)A[i] * (A[i] - 1);
printOut(ans + (LL)n * (n - 1) * (n - 2) / 6 - X);
}
int main()
{
#ifndef LOCAL
freopen("subset.in", "r", stdin);
freopen("subset.out", "w", stdout);
#endif
run();
return 0;
}
总结
首先要想到把三元组对应到大小为 1 / 2 / 3 的下标集合上去,然后要想到如何计算直接用三维偏序无法计算的部分。