题解 P4071 【[SDOI2016]排列计数】 (费马小定理求组合数 + 错排问题)

luogu题目传送门!

luogu博客通道!

这题要用到错排,先理解一下什么是错排:

问题:有一个数集A,里面有n个元素 a[i]。求,如果将其打乱,有多少种方法使得所有第原来的i个数a[i]不在原来的位置上。

可以简单这么理解:

数集(初始)
1

2

3

4

5

6

错排转化后(一种情况):

2

1

4

3

6

5

于是,我们设f[i]为数集中有总共i个数时,其错排的方案数有多少。 那么,经过大量的手摸, 我们来求一下递推式:

f[0] = f[2] = 1, f[1] = 0,这些是显而易见的。当i大于3以后,假设存在一个数字k,我们手摸一下n出现在第k位的情况,发现会有以下两种:

1、数字n刚好在第k位,则我们要求的就是剩下n - 2个数的错排。即f[i - 2]\ 2、数字n不在第k位,则我们要求的就是n - 1个数的错排,即f[i - 1] 又由于我们的k是属于区间[1, n)的,又有n - 1种取值。\ 所以,我们要再乘上n - 1.即为 f[i] = (i - 1) * (f[i - 1] + f[i - 2])

回归本题 。 题目翻译:求在长为n的全排列中,第i位恰好是i,且满足条件的个数刚好有m个。 如果反过来看,就是把排列中的m个数抽出来,使得剩下的n - m个数全都不在自己的位置上。

那么答案就很明显了,ans = C(n, m) * f[n - m];

代码来一波:

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define N 1000010
#define isdigit(c) ((c)>='0'&&(c)<='9')
const ll mod = (int)1e9 + 7;

inline ll read(){
    ll x = 0;
    char c = getchar();
    while(!isdigit(c)){
        c = getchar();
    }
    while(isdigit(c)){
        x = (x << 1) + (x << 3) + (c ^ '0');
        c = getchar();
    }
    return x;
}

ll n, m;
struct node{
    ll l, r;
} t[N];

ll inv[N], fac[N];
ll maxn = 0, maxm = 0;
ll f[N]; 

ll pow(ll a,ll b){//求a的 b次方 
    ll s = 1,temp = a;
    while(b){
        if(b & 1)s = (s * temp) % mod;
        temp = (temp * temp) % mod;
        b >>= 1;
    }
    return s % mod;
}

inline ll C(ll n, ll m){
    if(n < m) return 0;
    else return inv[n] * fac[m] % mod * fac[n-m] % mod;
}

void prepare(){
    inv[0] = fac[0] = 1;
    for(int i = 1;i <= maxn;i++)
        inv[i] = inv[i-1] * i % mod;
    fac[maxn] = pow(inv[maxn], mod - 2) % mod;//费马小定理求逆元
    for(int i = maxn - 1; i;i--)
        fac[i] = fac[i + 1] * (i + 1) % mod;  /*以上均是组合数求解*/
    f[1] = 0, f[2] = 1, f[0] = 1;
    for(int i = 3;i <= maxn; i++){
        f[i] = ((i - 1) * (f[i - 1] + f[i - 2] % mod)) % mod; /*错排处理*/
    }
    return ;
}

int main(){
    ll T = read();
    for(int i = 1;i <= T; i++){
        n = read(), m = read();
        t[i] = (node){n, m}; 
        maxn = max(maxn, n);/*先找最大值可以降低某些点的复杂度*/ 
    }
    prepare();
    for(int i = 1;i <= T; i++){
        printf("%lld\n", C(t[i].l, t[i].r) * f[t[i].l - t[i].r] % mod);

    } 
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/wondering-world/p/12822149.html
今日推荐