CF1042E Vasya and Magic Matrix 期望dp + 推公式

传送门

文章目录

题意:

在这里插入图片描述

思路:

将矩阵中的数放到数组里排序,就是一个比较明显的期望 d p dp dp了。
定义 f [ i ] f[i] f[i]表示从第 i i i个出发的期望得分,所以转移方程也比较好写了: f [ i ] = ∑ ( f [ j ] + ( x [ i ] − x [ j ] ) 2 + ( y [ i ] − y [ j ] ) 2 ) c n t f[i]=\frac{\sum(f[j]+(x[i]-x[j])^2+(y[i]-y[j])^2)}{cnt} f[i]=cnt(f[j]+(x[i]x[j])2+(y[i]y[j])2)
但是这样有个问题,如果直接转移的话那就是 O ( n 2 m 2 ) O(n^2m^2) O(n2m2)的了,看到平方这些东西我们会本能的拆开,所以我们考虑化简式子来进行 O ( 1 ) O(1) O(1)转移。
对于分母: ∑ f [ j ] + ( x [ i ] 2 + y [ i ] 2 ) ∗ c n t + ∑ ( x [ j ] 2 + y [ j ] 2 ) − 2 ∗ x [ i ] ∗ ∑ x [ j ] − 2 ∗ y [ i ] ∗ ∑ y [ j ] \sum f[j]+(x[i]^2+y[i]^2)*cnt+\sum(x[j]^2+y[j]^2)-2*x[i]*\sum x[j]-2*y[i]* \sum y[j] f[j]+(x[i]2+y[i]2)cnt+(x[j]2+y[j]2)2x[i]x[j]2y[i]y[j]
我们发现我们只需要维护 ∑ f [ j ] , ∑ ( x [ j ] 2 + y [ j ] 2 ) , ∑ x [ j ] , ∑ y [ j ] \sum f[j],\sum(x[j]^2+y[j]^2),\sum x[j],\sum y[j] f[j],(x[j]2+y[j]2),x[j],y[j]即可实现 O ( 1 ) O(1) O(1)转移了。
实现的时候需要等相同的值都算完了才能加上他们的贡献,且由于每个点的 x , y x,y x,y都不同,需要单独计算,一开始以为一样就在代码的 79 79 79行鬼使神差的写成了 f [ i ] = f [ i − 1 ] f[i]=f[i-1] f[i]=f[i1],真是老笨蛋啦。

//#pragma GCC optimize(2)
#include<cstdio>
#include<iostream>
#include<string>
#include<cstring>
#include<map>
#include<cmath>
#include<cctype>
#include<vector>
#include<set>
#include<queue>
#include<algorithm>
#include<sstream>
#include<ctime>
#include<cstdlib>
#define X first
#define Y second
#define L (u<<1)
#define R (u<<1|1)
#define pb push_back
#define mk make_pair
#define Mid (tr[u].l+tr[u].r>>1)
#define Len(u) (tr[u].r-tr[u].l+1)
#define random(a,b) ((a)+rand()%((b)-(a)+1))
#define db puts("---")
using namespace std;

//void rd_cre() { freopen("d://dp//data.txt","w",stdout); srand(time(NULL)); }
//void rd_ac() { freopen("d://dp//data.txt","r",stdin); freopen("d://dp//AC.txt","w",stdout); }
//void rd_wa() { freopen("d://dp//data.txt","r",stdin); freopen("d://dp//WA.txt","w",stdout); }

typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int,int> PII;

const int N=1000010,mod=998244353,INF=0x3f3f3f3f;
const double eps=1e-6;

int n,m,tot;
LL f[N],pref,prex,prey,prex2,prey2,pre;
struct Node
{
    
    
    LL x,y,val;
    bool operator < (const Node &w) const
    {
    
    
        return val<w.val;
    }
}a[N];

LL qmi(LL a,LL b)
{
    
    
    LL ans=1;
    while(b)
    {
    
    
        if(b&1) ans=ans*a%mod;
        a=a*a%mod;
        b>>=1;
    }
    return ans%mod;
}

int main()
{
    
    
//	ios::sync_with_stdio(false);
//	cin.tie(0);

    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++) for(int j=1;j<=m;j++) {
    
     LL x; scanf("%lld",&x); a[++tot]={
    
    i,j,x}; }
    sort(a+1,a+1+tot);
    int cnt=0; a[0].val=-1;
    for(int i=2;i<=tot;i++)
    {
    
    
        if(a[i].val!=a[i-1].val)
        {
    
    
            int pos=i-1,val=a[i-1].val;
            while(pos>=1&&a[pos].val==val) (prex+=a[pos].x)%=mod,(prey+=a[pos].y)%=mod,(prex2+=a[pos].x*a[pos].x)%=mod,(prey2+=a[pos].y*a[pos].y)%=mod,(pre+=f[pos])%=mod,pos--,cnt++;
            f[i]=((pre+(a[i].x*a[i].x+a[i].y*a[i].y)*cnt+prex2+prey2-2*a[i].x*prex-2*a[i].y*prey)%mod+mod)%mod*qmi(cnt,mod-2)%mod;
        }
        else f[i]=((pre+(a[i].x*a[i].x+a[i].y*a[i].y)*cnt+prex2+prey2-2*a[i].x*prex-2*a[i].y*prey)%mod+mod)%mod*qmi(cnt,mod-2)%mod;
    }
    LL ans=-1;
    int sx,sy; cin>>sx>>sy;
    for(int i=1;i<=tot;i++) if(a[i].x==sx&&a[i].y==sy) {
    
     ans=f[i]; break; }
    cout<<ans%mod<<endl;


	return 0;
}
/*
1 4
1 1 2 1
1 3
*/