6639. 【GDOI2020.5.16模拟】Minusk (MTT+多项式倍增点值)

Description:

https://gmoj.net/senior/#main/show/6639

题解:

考虑\(n!\)怎么算,经典做法:

\(v=\sqrt n\),当\(n~mod~v\neq 0\)时就\(n--\),最后加上这个的贡献就好了。

每一块可以看做\(\prod {i=1}^v (x+i)\),现在求\(x=(0..n/v-1)*v\)的点值,把多项式分治展开后多点求值即可。

时间复杂度:\(O(\sqrt n*log^2n)\)

这个题可以看做\(\frac{\sum_{i=1}^n \prod_{j \neq i} j}{n!}\),差不多的做。

当然这样就TLE了,这题卡的比较紧。

考虑zzq博客上的倍增点值做法:

假设有多项式\(A_n=\prod_{i=1}^n (x+i)\)

我们用一些\(n+1\)个点值来表示这个多项式,这题就用\((0..n)*v\)即可。

假设已知:\(A_n((0..n)*v)\)
思考如何求出\(A_n((0..2n)*v)\)

发现可以插值,设要求\(xv+c\)处的点值,\(c=(n+1)v,x \in [0,n]\)

\(A(xv+c)=\sum_{i=0}^n A(iv) \prod_{j \neq i} \frac{xv+c-jv}{iv-jv}\)

因为\(\forall j\in [0,n],xv+c-jv\neq 0\)

所以可以写成:

\(=(\prod_{j=0}^n (xv+c-jv))*(\sum_{i=0}^n \frac{(-1)^{n-i}}{i!*(n-i)!}*v^{-n}*\frac{1}{xv+c-iv})\)

前面可以写成区间积形式,处理前缀积和前缀逆元积即可。

后面的可以NTT。

知道了上面的变换,我们做两次就可以由\(A_n((0..n)v)\)推到\(A_{2n}((0..2n)v)\)

这样就可以倍增求点值了,不难发现复杂度是\(O(n~log~n)\)的。

这题就需要同时维护两个多项式,有点细节,还要MTT。

Code:

#include<bits/stdc++.h>
#define fo(i, x, y) for(int i = x, _b = y; i <= _b; i ++)
#define ff(i, x, y) for(int i = x, _b = y; i <  _b; i ++)
#define fd(i, x, y) for(int i = x, _b = y; i >= _b; i --)
#define ll long long
#define pp printf
#define hh pp("\n")
using namespace std;

ll n, k, mo;

ll ksm(ll x, ll y) {
	x %= mo;
	ll s = 1;
	for(; y; y /= 2, x = x * x % mo)
		if(y & 1) s = s * x % mo;
	return s;
}

#define db double
#define V vector<ll>
#define si size()
#define re resize

namespace mtt {
	const db pi = acos(-1);
	struct P {
		db x, y;
		P(db _x = 0, db _y = 0) { x = _x, y = _y;}
		P operator + (P b) { return P(x + b.x, y + b.y);}
		P operator - (P b){  return P(x - b.x, y - b.y);}
		P operator * (P b) { return P(x * b.x - y * b.y, x * b.y + y * b.x);}
	};
	
	const int nm = 1 << 19;
	
	P w[nm]; int r[nm];
	P c0[nm], c1[nm], c2[nm], c3[nm];
	
	void build() {
		for(int i = 1; i < nm; i *= 2) ff(j, 0, i)
			w[i + j] = P(cos(pi * j / i), sin(pi * j / i));
	}
	
	void dft(P *a, int n) {
		ff(i, 0, n) {
			r[i] = r[i / 2] / 2 + (i & 1) * (n / 2);
			if(i < r[i]) swap(a[i], a[r[i]]);
		} P b;
		for(int i = 1; i < n; i *= 2) for(int j = 0; j < n; j += 2 * i)
			ff(k, 0, i) b = a[i + j + k] * w[i + k], a[i + j + k] = a[j + k] - b, a[j + k] = a[j + k] + b;
	}
	void rev(P *a, int n) {
		reverse(a + 1, a + n);
		ff(i, 0, n) a[i].x /= n, a[i].y /= n;
	}
	P conj(P a) { return P(a.x, -a.y);}
	void fft(ll *a, ll *b, int n) {
		#define qz(x) ((ll) round(x))
		ff(i, 0, n) c0[i] = P(a[i] & 32767, a[i] >> 15), c1[i] = P(b[i] & 32767, b[i] >> 15);
		dft(c0, n); dft(c1, n);
		ff(i, 0, n) {
			P k, d0, d1, d2, d3;
			int j = (n - i) & (n - 1);
			k = conj(c0[j]);
			d0 = (k + c0[i]) * P(0.5, 0);
			d1 = (k - c0[i]) * P(0, 0.5);
			k = conj(c1[j]);
			d2 = (k + c1[i]) * P(0.5, 0);
			d3 = (k - c1[i]) * P(0, 0.5);
			c2[i] = d0 * d2 + d1 * d3 * P(0, 1);
			c3[i] = d0 * d3 + d1 * d2;
		}
		dft(c2, n); dft(c3, n); rev(c2, n); rev(c3, n);
		ff(i, 0, n) {
			a[i] = qz(c2[i].x) + (qz(c2[i].y) % mo << 30) + (qz(c3[i].x) % mo << 15);
			a[i] %= mo;
		}
	}
	ll a[nm], b[nm];
	V operator * (V p, V q) {
		int n0 = p.si + q.si - 1, n = 1;
		while(n < n0) n *= 2;
		ff(i, 0, n) a[i] = b[i] = 0;
		ff(i, 0, p.si) a[i] = p[i];
		ff(i, 0, q.si) b[i] = q[i];
		fft(a, b, n);
		p.re(n0);
		ff(i, 0, n0) p[i] = a[i];
		return p;
	}
}

using mtt :: operator *;

int v;

namespace sub1 {
	const int N = 1e6 + 5;
	
	ll fac[N], nf[N], f[N], vf[N];
	
	V func(V a, int c) {
		int n = a.si - 1;
		
		fac[0] = 1; fo(i, 1, n) fac[i] = fac[i - 1] * i % mo;
		nf[n] = ksm(fac[n], mo - 2); fd(i, n, 1) nf[i - 1] = nf[i] * i % mo;
		
		f[0] = -n * v + c;
		fo(i, 1, 2 * n) f[i] = f[i - 1] * ((i - n) * v + c) % mo;
		vf[2 * n] = ksm(f[2 * n], mo - 2);
		fd(i, 2 * n, 1) vf[i - 1] = vf[i] * ((i - n) * v + c) % mo;
		
		ll inv_v = ksm(ksm(v, mo - 2), n);
		V p, q; p.re(n + 1); q.re(2 * n + 1);
		fo(i, 0, n) {
			p[i] = nf[i] * nf[n - i] % mo * ((n - i) % 2 ? -1 : 1) * inv_v % mo * a[i] % mo;
		}
		fo(i, 0, 2 * n) {
			q[i] = ksm((i - n) * v + c, mo - 2);
		}
		p = p * q;
		V w; w.re(n + 1);
		fo(i, 0, n) {
			w[i] = f[i + n];
			if(i > 0) w[i] = w[i] * vf[i - 1] % mo;
			w[i] = w[i] * p[i + n] % mo;
		}
		return w;
	}
}

using sub1 :: func;

#define pvv pair<V, V>
#define fs first
#define se second

pvv ch2(pvv a) {
	V b = a.fs, c = a.se;
	
	int n = b.si - 1;
	V d = func(b, (n + 1) * v);
	b.re(2 * n + 1);
	fo(i, n + 1, 2 * n) b[i] = d[i - (n + 1)];
	d = func(b, n / k);
	
	V e = func(c, (n + 1) * v);
	c.re(2 * n + 1);
	fo(i, n + 1, 2 * n) c[i] = e[i - (n + 1)];
	e = func(c, n / k);
	
	fo(i, 0, 2 * n) c[i] = (c[i] * d[i] + e[i] * b[i]) % mo;
	
	fo(i, 0, 2 * n) b[i] = b[i] * d[i] % mo;
	
	return pvv(b, c);
}

pvv jia1(pvv a) {
	int n = a.fs.si - 1;
	int m = n / k;
	V b = a.fs, c = a.se;
	b.re(n + 1 + k);
	fo(i, 0, n) b[i] = b[i] * ksm((i * v + m + 1), k) % mo;
	fo(j, 1, k) {
		b[n + j] = 1;
		fo(i, 1, m + 1) b[n + j] = b[n + j] * ksm((n + j) * v + i, k) % mo;
	}
	c.re(n + 1 + k);
	fo(i, 0, n) c[i] = (c[i] * ksm((i * v + m + 1), k) % mo + a.fs[i]) % mo;
	fo(j, 1, k) {
		ll s1 = 1, s2 = 0;
		fo(i, 1, m + 1) {
			ll w = ksm((n + j) * v + i, k) % mo;
			s2 = (s2 * w + s1) % mo;
			s1 = s1 * w % mo;
		}
		c[n + j] = s2;
	}
	return pvv(b, c);
}

pvv solve(int n) {
	if(n == 1) {
		V a; a.re(k + 1);
		fo(i, 0, k) a[i] = ksm(i * v + 1, k);
		V b; b.re(k + 1);
		fo(i, 0, k) b[i] = 1;
		return pvv(a, b);
	}
	pvv a = solve(n / 2);
	a = ch2(a);
	if(n % 2 == 1) a = jia1(a);
	return a;
}

const int N = 1e6 + 5;

ll p[N], q[N];

ll calc() {
	pvv a = solve(v);
	int m = n / v - 1;
	
	ll xs = 1;
	fo(i, 0, m) xs = xs * ksm(a.fs[i], mo - 2) % mo;
	
	p[0] = a.fs[0];
	fo(i, 1, m) p[i] = p[i - 1] * a.fs[i] % mo;
	q[m + 1] = 1;
	fd(i, m, 0) q[i] = q[i + 1] * a.fs[i] % mo;
	ll s = 0;
	fo(i, 0, m) {
		ll xs = 1;
		if(i > 0) xs = xs * p[i - 1] % mo;
		xs = xs * q[i + 1] % mo;
		s = (s + xs * a.se[i]) % mo;
	}
	s = s * xs % mo;
	return s;
}

int main() {
	freopen("minusk.in", "r", stdin);
	freopen("minusk.out", "w", stdout);
	mtt :: build();
	scanf("%lld %lld %lld", &n, &k, &mo);
	if(n <= 1e6) {
		ll ans = 0;
		fo(i, 1, n) ans = (ans + ksm(ksm(i, k), mo - 2)) % mo;
		pp("%lld\n", ans);
		return 0;
	}
	v = max(1, (int) ceil(sqrt((double) n / k)));
	ll ans = 0;
	while(n > 0 && n % v != 0) {
		ans = (ans + ksm(ksm(n, k), mo - 2)) % mo;
		n --;
	}
	ans = (ans + calc() + mo) % mo;
	pp("%lld\n", ans);
}

猜你喜欢

转载自www.cnblogs.com/coldchair/p/12932744.html