Codeforces1444 B. Divide and Sum(组合数学)

题意:

在这里插入图片描述
数据范围:n<=2e5

解法:

∑ i = 1 n ∣ p ( i ) − q ( i ) ∣ = ∑ i = 1 n m a x ( p ( i ) , q ( i ) ) − m i n ( p ( i ) , q ( i ) ) m a x ( p i , q i ) 一 定 是 前 n 大 的 数 , m i n ( q , p i ) 一 定 是 前 n 小 的 数 . 如 果 不 是 这 样 : 1. 某 个 m a x ( p i , q i ) 和 m i n ( p i , q i ) 都 是 前 n 大 , 那 么 会 导 致 某 个 m a x ( p j , q j ) 和 m i n ( p j , q j ) 都 是 前 n 小 , 这 样 的 匹 配 在 p 非 递 减 , q 非 递 增 的 排 序 方 式 下 是 不 存 在 的 . 2. 某 个 m a x ( p i , q i ) 和 m i n ( p i , q i ) 都 是 前 n 小 , 那 么 会 导 致 某 个 m a x ( p j , q j ) 和 m i n ( p j , q j ) 都 是 前 n 大 . 这 样 的 匹 配 在 p 非 递 减 , q 非 递 增 的 排 序 方 式 下 是 不 存 在 的 . 因 此 无 论 怎 么 排 列 , f 函 数 的 值 = 前 n 大 − 前 n 小 . 总 排 列 数 为 C ( 2 n , n ) , 乘 上 f 函 数 的 值 就 是 答 案 . \sum_{i=1}^n|p(i)-q(i)|=\sum_{i=1}^nmax(p(i),q(i))-min(p(i),q(i))\\ max(pi,qi)一定是前n大的数,min(q,pi)一定是前n小的数.\\ 如果不是这样:\\ 1.某个max(pi,qi)和min(pi,qi)都是前n大,那么会导致某个max(pj,qj)和min(pj,qj)都是前n小,\\ 这样的匹配在p非递减,q非递增的排序方式下是不存在的.\\ 2.某个max(pi,qi)和min(pi,qi)都是前n小,那么会导致某个max(pj,qj)和min(pj,qj)都是前n大.\\ 这样的匹配在p非递减,q非递增的排序方式下是不存在的.\\ 因此无论怎么排列,f函数的值=前n大-前n小.\\ 总排列数为C(2n,n),乘上f函数的值就是答案. i=1np(i)q(i)=i=1nmax(p(i),q(i))min(p(i),q(i))max(pi,qi)n,min(q,pi)n.:1.max(pi,qi)min(pi,qi)n,max(pj,qj)min(pj,qj)n,p,q.2.max(pi,qi)min(pi,qi)n,max(pj,qj)min(pj,qj)n.p,q.,f=nn.C(2n,n),f.
解释一下为什么不存在:
在这里插入图片描述
假设上图中的绿点是两个非法匹配点,假设是两个前n大,
那么根据排序规则,p中的绿点一定在右边,q中的绿点一定在左边,
因为前n大只有n个,所以这两个点一定不会对应上,即不存在这样的匹配。

code:

#include <bits/stdc++.h>
using namespace std;
#define int long long
const int maxm=1e6+5;
const int mod=998244353;
int fac[maxm],inv[maxm];
int a[maxm];
int n;
int ppow(int a,int b,int mod){
    
    
    int ans=1%mod;a%=mod;
    for(;b;b>>=1,a=a*a%mod)if(b&1)ans=ans*a%mod;
    return ans;
}
void init(){
    
    
    fac[0]=1;
    for(int i=1;i<maxm;i++)fac[i]=fac[i-1]*i%mod;
    inv[maxm-1]=ppow(fac[maxm-1],mod-2,mod);
    for(int i=maxm-2;i>=0;i--)inv[i]=inv[i+1]*(i+1)%mod;
}
int C(int n,int m){
    
    
    if(m<0||m>n)return 0;
    return fac[n]*inv[m]%mod*inv[n-m]%mod;
}
signed main(){
    
    
    ios::sync_with_stdio(0);
    init();
    cin>>n;
    for(int i=1;i<=n*2;i++)cin>>a[i];
    sort(a+1,a+1+n*2);
    int mis=0,mas=0;
    for(int i=1;i<=n;i++){
    
    
        mis=(mis+a[i])%mod;
    }
    for(int i=n+1;i<=n*2;i++){
    
    
        mas=(mas+a[i])%mod;
    }
    int ans=C(n*2,n)*(mas-mis)%mod;
    ans=(ans%mod+mod)%mod;
    cout<<ans<<endl;
    return 0;
}

猜你喜欢

转载自blog.csdn.net/weixin_44178736/article/details/112792923