「ZJOI2019」开关

考虑进行容斥。记 F ( z ) F(z) 是先假设一直按下去, n n 次后恰好所有灯都达到状态的概率的指母函数,则有

F ( z ) = i = 1 n e p i z + ( 1 ) s i e p i z 2 F(z) = \prod_{i=1}^n \frac{\mathrm e^{p_iz} + (-1)^{s_i} \mathrm e^{-p_iz}}2

注意此处的 p i p_i 是概率,即原本的 p i p i \frac{p_i}{\sum p_i}

而我们想要扣掉第一次达到状态后再次返回到原状态的情况,记 G ( z ) G(z) 是回到原点的概率的指母函数

G ( z ) = i = 1 n e p i z + e p i z 2 G(z) = \prod_{i=1}^n \frac{\mathrm e^{p_iz} + \mathrm e^{-p_iz}}2

如果其对应的 OGF f ( z ) = 0 + F ( z t ) e t d t f(z) = \int_0^{+\infty} F(zt)\mathrm e^{-t} \mathrm d t g ( z ) g(z) ,显然 h = f / g h=f/g 就是第一次到达状态的概率的母函数, h ( 1 ) h'(1) 就是答案。

F ( z ) = a w e w z F(z) = \sum a_w \mathrm e^{wz} ,则 f ( z ) = a w 1 w z f(z) = \sum \frac{a_w}{1- wz} ,因此我们只需要背包出 F F 的表示方法,考虑 h = f g g f g 2 h'=\frac{f'g-g'f}{g^2} ,但是 f , g f, g 中含有 1 1 z \frac 1{1-z} 项,应当上下乘以 1 z 1-z ,然后计算即可。推导可得 ( 1 z 1 w z ) z = 1 = 1 w 1 \left.\left( \frac{1-z}{1-wz} \right)'\right|_{z=1} = \frac1{w-1}

M = p i M = \sum p_i ,复杂度为 Θ ( n M ) \Theta(nM) ,若用分治 FFT 或者多项式 exp 优化则可以做到 Θ ( M log 2 M ) \Theta(M\log^2 M) Θ ( M log M ) \Theta(M\log M)

#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cmath>
#include <ctime>
#include <cctype>

#include <algorithm>
#include <random>
#include <bitset>
#include <queue>
#include <functional>
#include <set>
#include <map>
#include <vector>
#include <chrono>
#include <iostream>
#include <limits>
#include <numeric>

#define LOG(FMT...) fprintf(stderr, FMT)

using namespace std;

typedef long long ll;
typedef unsigned long long ull;

const int N = 110, M = 50010, P = 998244353;

int n, m;
int s[N], p[N], a[M], b[M];

int norm(int x) { return x >= P ? x - P : x; }

void exGcd(int a, int b, int& x, int& y) {
  if (!b) {
    x = 1;
    y = 0;
    return;
  }
  exGcd(b, a % b, y, x);
  y -= a / b * x;
}

int inv(int a) {
  int x, y;
  exGcd(a, P, x, y);
  return norm(x + P);
}

int calc(int* arr) {
  int ret = 0;
  for (int i = 0; i < m; ++i) {
    int q = norm(2LL * i * inv(m) % P + P - 1);
    ret = (ret + arr[i] * (ll)inv(norm(P + q - 1))) % P;
  }
  return ret;
}

int main() {
  scanf("%d", &n);
  for (int i = 1; i <= n; ++i)
    scanf("%d", &s[i]);
  for (int i = 1; i <= n; ++i)
    scanf("%d", &p[i]);
  a[0] = 1;
  b[0] = 1;
  for (int i = 1; i <= n; ++i) {
    m += p[i];
    if (s[i] == 1) {
      for (int j = m; j >= p[i]; --j)
        a[j] = norm(P + a[j - p[i]] - a[j]);
      for (int j = p[i] - 1; j >= 0; --j)
        a[j] = norm(P - a[j]);
    } else
      for (int j = m; j >= p[i]; --j)
        a[j] = norm(a[j - p[i]] + a[j]);
    for (int j = m; j >= p[i]; --j)
      b[j] = norm(b[j - p[i]] + b[j]);
  }
  int f = calc(a), g = calc(b);
  int ans = norm(f + P - g);
  printf("%d\n", ans);
  return 0;
}

猜你喜欢

转载自blog.csdn.net/EI_Captain/article/details/89600317