版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_39972971/article/details/89814560
【题目链接】
【思路要点】
- 所求的 即为 前 项的系数。
- 记 ,那么 ,其中 的次数为 ,其中 的次数为 。
- 考虑计算模数为 时的 ,最后用中国剩余定理合并答案。
- 比较 两端的系数,有 。
- 找到最小的 ,使得 不是 的倍数,则有
- 由于 不是 的倍数,我们可以将等式两侧乘以其乘法逆元。
- 可以发现,对于 ,上式中 前的系数均为 的倍数,因此,若让等式右侧的 从高次向低次再次代入上式,可以使得对于 ,得式中 前的系数均为 的倍数,重复 次即可让 的计算式中只包含 中的项,以及 中次数小于 的项,从而直接计算答案。
- 计算模数为 时的 的时间复杂度为 。
- 总时间复杂度 。
【代码】
#include<bits/stdc++.h> using namespace std; const int MAXN = 4e6 + 5; const int MAXQ = 4e3 + 5; typedef long long ll; typedef long double ld; typedef unsigned long long ull; template <typename T> void chkmax(T &x, T y) {x = max(x, y); } template <typename T> void chkmin(T &x, T y) {x = min(x, y); } template <typename T> void read(T &x) { x = 0; int f = 1; char c = getchar(); for (; !isdigit(c); c = getchar()) if (c == '-') f = -f; for (; isdigit(c); c = getchar()) x = x * 10 + c - '0'; x *= f; } template <typename T> void write(T x) { if (x < 0) x = -x, putchar('-'); if (x > 9) write(x / 10); putchar(x % 10 + '0'); } template <typename T> void writeln(T x) { write(x); puts(""); } int n, m, d, P; vector <pair <int, int>> p; int power(int x, int y) { if (y == 0) return 1; int tmp = power(x, y / 2); if (y % 2 == 0) return 1ll * tmp * tmp % P; else return 1ll * tmp * tmp % P * x % P; } void exgcd(int a, int b, int &x, int &y) { if (b == 0) { x = 1, y = 0; return; } int q = a / b, r = a % b; exgcd(b, r, y, x); y -= q * x; } int inv(int x, int P) { int a = 0, b = 0; exgcd(x, P, a, b); return (a % P + P) % P; } int crt(int a, int P, int invP, int b, int Q, int invQ) { int Mod = P * Q; return (1ll * a * Q % Mod * invQ + 1ll * b * P % Mod * invP) % Mod; } int a[MAXN], b[MAXN], c[MAXN], res[MAXN]; void calcab(int ea, int eb, int m) { int lft = 1; vector <int> e(p.size()); for (int i = 0; i < ea && i < m + MAXQ; i++) { int tmp = ea - i, tnp = i + 1; for (unsigned i = 0; i < p.size(); i++) { while (tmp % p[i].first == 0) { tmp /= p[i].first; e[i]++; } while (tnp % p[i].first == 0) { tnp /= p[i].first; e[i]--; } } lft = 1ll * lft * tmp % P * inv(tnp, P) % P; int res = lft; for (unsigned i = 0; i < p.size(); i++) res = 1ll * res * power(p[i].first, e[i]) % P; a[i] = res; } lft = 1; e.clear(), e.resize(p.size()); for (int i = 0; i < eb; i++) { int tmp = eb - i, tnp = i + 1; for (unsigned i = 0; i < p.size(); i++) { while (tmp % p[i].first == 0) { tmp /= p[i].first; e[i]++; } while (tnp % p[i].first == 0) { tnp /= p[i].first; e[i]--; } } lft = 1ll * lft * tmp % P * inv(tnp, P) % P; int res = lft; for (unsigned i = 0; i < p.size(); i++) res = 1ll * res * power(p[i].first, e[i]) % P; b[i] = res; } } void factor(int x) { for (int i = 2; i * i <= x; i++) if (x % i == 0) { int cnt = 0; while (x % i == 0) { x /= i; cnt++; } p.emplace_back(i, cnt); } if (x != 1) p.emplace_back(x, 1); } void solve(int p, int k, int P) { int pos = 0; while (b[pos] % p == 0) pos++; int mul = inv(b[pos], P); static int coef[MAXQ], mula[MAXQ], func[MAXQ]; memset(coef, 0, sizeof(coef)); memset(mula, 0, sizeof(mula)); memset(func, 0, sizeof(func)); for (int i = 0; i <= d - 1; i++) { if (i != pos) coef[d - 1 - i] = func[d - 1 - i] = 1ll * mul * (P - b[i] % P) % P; else mula[d - 1] = mul; } pos = d - 1 - pos; int Limit = pos + d * (k + 1); for (int e = 2, v = p * p; e <= k; e++, v *= p) for (int i = Limit; i >= pos; i--) { if (coef[i] % v == 0) continue; int tmp = coef[i]; coef[i] = 0; mula[i + d - 1 - pos] = (mula[i + d - 1 - pos] + 1ll * tmp * mul) % P; for (int j = 0; j <= d - 1; j++) coef[i - pos + j] = (coef[i - pos + j] + 1ll * tmp * func[j]) % P; } while (Limit > pos && mula[Limit] == 0) Limit--; for (int i = 0; i <= m - 1; i++) { int res = 0; for (int j = pos; j <= Limit; j++) res = (res + 1ll * mula[j] * a[i + j - pos]) % P; for (int j = 1; j <= pos && j <= i; j++) res = (res + 1ll * coef[pos - j] * c[i - j]) % P; c[i] = (res + P) % P; } } int main() { read(n), read(m), read(d), read(P); factor(P), calcab(n * d, d, m); int Mod = 1; for (auto x : p) { int Q = power(x.first, x.second); if (p.size() == 1) Q = P; solve(x.first, x.second, Q); int invMod = inv(Mod % Q, Q); int invQ = inv(Q % Mod, Mod); for (int i = 0; i <= m - 1; i++) res[i] = crt(res[i], Mod, invMod, c[i], Q, invQ); Mod *= Q; } int ans = 0; for (int i = 0; i <= m - 1; i++) ans ^= res[i]; writeln(ans); return 0; }