[洛谷P3301][BZOJ3129][SDOI2013]方程(扩展Lucas+容斥)

Solution

  • 先考虑\(n_1=0\)的情况
  • 那么只要考虑形如\(X_i>=A_i\)的限制
  • 注意求的是正整数解的个数,即对于\(i>n_2\)\(X_i>=1(A_i=1)\)
  • \(\sum_{i=1}^{n}B_i=m\)非负整数解的个数为\(C(m+n-1,m)\)
  • 解释:序列共\(m+n-1\)个位置,选\(n-1\)个位置出来当隔板,把序列分为长度之和为\(m\)\(n\)段(可能存在长度为\(0\)的段,即隔板相邻的情况)
  • 现在为了满足这些限制,令\(B_i=X_i-A_i\),则\(B_i\)非负整数解的个数就是原题的合法解的个数
  • 那么\(m\)要减掉\(\sum_{i=1}^{n}A_i\)
  • 考虑\(n_1>0\)的情况,用总方案数\(-\)存在\(X_i>=A_i+1(1<=i<=n_1)\)的情况
  • 即考虑容斥:不考虑前\(n_1\)个数的限制的方案数\(-\)\(n_1\)个数至少有\(1\)个不满足条件的方案数\(+\)\(n_1\)个数至少有\(2\)个不满足条件的方案数\(-\)……
  • 发现\(n,m\)很大,但任意一组数据的\(p\)都可以拆成\(\Pi_{i=1}^{k}pi^{qi}\),且\(p_i<=10007\),那么用扩展\(lucas\)求组合数取模即可

Code

#include <bits/stdc++.h>

using namespace std;

#define ll long long

template <class t>
inline void read(t & res)
{
    char ch;
    while (ch = getchar(), !isdigit(ch));
    res = ch ^ 48;
    while (ch = getchar(), isdigit(ch))
    res = res * 10 + (ch ^ 48);
}

const int o = 2000;
int a[o], b[o], pk, p, c[o], d[o], tst, n1, n2, n, m, ans, h[o], now, f[20][10010];
bool vis[o];
ll tot;

inline int exgcd(int a, int b, int &x, int &y)
{
    if (!b)
    {
        x = 1;
        y = 0;
        return a;
    }
    int ret = exgcd(b, a % b, x, y), tmp = x;
    x = y;
    y = tmp - a / b * y;
    return ret;
}

inline int ksm(int x, ll y)
{
    int res = 1;
    while (y)
    {
        if (y & 1) res = (ll)res * x % pk;
        y >>= 1;
        x = (ll)x * x % pk;
    }
    return res;
}

inline int fac(int n, int p, int k)
{
    if (n == 1 || n == 0) return 1;
    ll cnt = n / p, bl = n / pk, res = fac(n / p, p, k), i, tmp;
    tot += cnt;
    tmp = f[now][pk - 1];
    tmp = ksm(tmp, bl);
    res = res * tmp % pk;
    res = res * f[now][n % pk] % pk;
    return res;
}

inline int solve(int n, int m, int id)
{
    int p = c[id], k = d[id];
    pk = a[id];
    tot = 0;
    int ra = fac(m, p, k); ll ta = tot;
    tot = 0;
    int rb = fac(n - m, p, k); ll tb = tot;
    tot = 0;
    int rc = fac(n, p, k); ll tc = tot;
    ll t = tc - ta - tb;
    if (t < 0) t = (t % k + k) % k;
    int ia, ib, xxx;
    exgcd(ra, pk, ia, xxx);
    exgcd(rb, pk, ib, xxx);
    if (ia < 0) ia += pk;
    if (ib < 0) ib += pk;
    return (ll)rc * ia % pk * ib % pk * ksm(p, t) % pk;
}

inline void init()
{
    int i, s = sqrt(p), lp = p, j;
    for (i = 2; i <= s; i++)
    if (lp % i == 0)
    {
        int t = 0, r = 1;
        while (lp % i == 0) 
        {
            t++;
            r *= i;
            lp /= i;
        }
        a[++a[0]] = r; 
        c[a[0]] = i;
        d[a[0]] = t;
    }
    if (lp != 1) 
    {
        a[++a[0]] = lp;
        c[a[0]] = lp;
        d[a[0]] = 1;
    }
    for (i = 1; i <= a[0]; i++)
    {
        f[i][0] = 1;
        for (j = 1; j <= a[i]; j++)
        if (j % c[i]) f[i][j] = (ll)f[i][j - 1] * j % a[i];
        else f[i][j] = f[i][j - 1];
    }
}

inline int cc(ll n, ll m, int p)
{
    if (n < m || m < 0) return 0;
    int ans = 0, i;
    for (i = 1; i <= a[0]; i++) 
    {
        now = i;
        b[i] = solve(n, m, i);
    }
    for (i = 1; i <= a[0]; i++)
    {
        int mi = p / a[i], g, y, aa = a[i];
        exgcd(mi, aa, g, y);
        ans = (ans + (ll)mi * g % p * b[i] % p + p) % p;
    }
    return ans;
}

inline void add(int &x, int y)
{
    x += y;
    if (x >= p) x -= p;
}

inline void pd()
{
    int i, tm = m, cnt = 0;
    for (i = 1; i <= n1; i++)
    if (vis[i])
    {
        cnt++;
        tm -= h[i] + 1;
    }
    else tm--;
    if (!cnt) return;
    if (cnt & 1) add(ans, p - cc(tm + n - 1, tm, p));
    else add(ans, cc(tm + n - 1, tm, p));
} 

inline void dfs(int k)
{
    if (k == n1 + 1)
    {
        pd();
        return;
    }
    vis[k] = 0;
    dfs(k + 1);
    vis[k] = 1;
    dfs(k + 1);
}

int main()
{
    int i;
    read(tst); read(p);
    init();
    while (tst--)
    {
        read(n); 
        read(n1); 
        read(n2);
        read(m);
        int tmp = n1 + n2;
        for (i = 1; i <= tmp; ++i) read(h[i]);
        m -= n - n1 - n2;
        for (i = n1 + 1; i <= n2 + n1; i++) m -= h[i];
        int tm = m - n1;
        ans = cc(tm + n - 1, tm, p);
        dfs(1);
        printf("%d\n", ans);
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/cyf32768/p/12196441.html