uoj#209【UER #6】票数统计

题目

做UER的A题涨信心

首先我们注意到这个所谓的至少有一条正确在\(x\)\(y\)不相等的时候非常弱,当\(x<y\)时,只有可能是后\(y\)位用户有\(x\)个通过;当\(x>y\)时,只有可能是前\(x\)位用户有\(y\)个通过。也就是说这些信息都能被转化成一些用来限制前后缀和的信息。

\(pre_i\)表示序列的前缀和,对于一条前\(x\)位用户有\(y\)个通过的限制,我们可以拆成\(pre_x=y\);对于一条后\(y\)位用户有\(x\)个通过的信息,可以视为\(pre_n-pre_{n-y}=x\),即\(pre_{n-y}=pre_n-x\)

如果我们知道\(pre_n\)的值,那么就只剩下了一些前缀和的信息了。所以我们可以直接枚举\(pre_n\)的值。这些关于前缀和的限制又将整个序列分割成了一些区间,每个区间的区间和也都被限制好了,直接使用组合数把每个区间的方案算出来就好了,答案就是每一个区间组合数的乘积。

但是上述的做法均不能处理\(x=y\)的情况,当\(x=y\)的时候,意味着有一个长度为\(x\)的前缀或后缀全都是\(1\)。这个\(x\)越大限制性必然越强,于是我们只需要考虑最大的\(x=y\),满足了最大的\(x=y\)剩下的\(x=y\)必然也都满足了。

我们枚举这个\(x=y\)限制前缀还是限制后缀,限制前缀就拆成\(pre_x=x\),限制后缀就拆成\(pre_{n-x}=x\)。但是如果有一种方案既有一段全是\(1\)的前缀,也有一段全是\(1\)的后缀,就会被计算两次。所以我们把两条限制条件都加上,再减掉这样的方案就好了。

代码

#include<bits/stdc++.h>
#define re register
#define LL long long
#define max(a,b) ((a)>(b)?(a):(b))
inline int read() {
    char c=getchar();int x=0;while(c<'0'||c>'9') c=getchar();
    while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
const int maxn=5e3+5;
const int mod=998244353;
int T,n,m,M;
int fac[maxn],ifac[maxn],inv[maxn];
int a[maxn],b[maxn],c[maxn],d[maxn],t[2],pre[maxn];
inline int C(int n,int m) {
    return m>n?0:1ll*fac[n]*ifac[m]%mod*ifac[n-m]%mod;
}
inline int solve(int sum) {
    for(re int i=1;i<=t[0];i++) {
        if(pre[a[i]]!=-1&&pre[a[i]]!=c[i]) return 0;
        pre[a[i]]=c[i];
    }
    for(re int i=1;i<=t[1];i++) {
        if(pre[n-b[i]]!=-1&&pre[n-b[i]]!=sum-d[i]) return 0;
        pre[n-b[i]]=sum-d[i]; 
    }
    if(pre[0]!=-1&&pre[0]!=0) return 0;
    pre[0]=0;int l=0,tot=1;
    for(re int i=1;i<=n;i++) {
        if(pre[i]==-1) continue;
        if(pre[i]<pre[l]) return 0;
        tot=1ll*tot*C(i-l,pre[i]-pre[l])%mod;l=i;
    }
    return tot;
}
int main() {
    T=read();fac[0]=ifac[0]=inv[1]=1;
    for(re int i=1;i<maxn;i++) fac[i]=1ll*fac[i-1]*i%mod;
    for(re int i=2;i<maxn;i++) inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
    for(re int i=1;i<maxn;i++) ifac[i]=1ll*ifac[i-1]*inv[i]%mod;
    while(T--) {
        n=read(),m=read();int x,y;t[0]=t[1]=M=0;
        for(re int i=1;i<=m;i++) {
            x=read(),y=read();
            if(x<y) b[++t[1]]=y,d[t[1]]=x;
            if(x>y) a[++t[0]]=x,c[t[0]]=y;
            if(x==y) M=max(M,x);
        }
        int ans=0,now=M;
        for(re int i=1;i<=t[0];i++) now=max(now,c[i]);
        for(re int i=1;i<=t[1];i++) now=max(now,d[i]);
        for(re int i=now;i<=n;i++) {
            memset(pre,-1,sizeof(pre));
            pre[n]=i,pre[M]=M;
            ans=(ans+solve(i))%mod;
            if(!M) continue;
            memset(pre,-1,sizeof(pre));
            pre[n]=i,pre[n-M]=i-M;
            ans=(ans+solve(i))%mod;
            memset(pre,-1,sizeof(pre));
            pre[n]=i,pre[M]=M;pre[n-M]=i-M;
            if(M==n-M&&M!=i-M) continue;
            ans=(ans-solve(i)+mod)%mod;
        }
        printf("%d\n",ans);
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/asuldb/p/11391021.html