题意:
思路:
将矩阵中的数放到数组里排序,就是一个比较明显的期望 d p dp dp了。
定义 f [ i ] f[i] f[i]表示从第 i i i个出发的期望得分,所以转移方程也比较好写了: f [ i ] = ∑ ( f [ j ] + ( x [ i ] − x [ j ] ) 2 + ( y [ i ] − y [ j ] ) 2 ) c n t f[i]=\frac{\sum(f[j]+(x[i]-x[j])^2+(y[i]-y[j])^2)}{cnt} f[i]=cnt∑(f[j]+(x[i]−x[j])2+(y[i]−y[j])2)
但是这样有个问题,如果直接转移的话那就是 O ( n 2 m 2 ) O(n^2m^2) O(n2m2)的了,看到平方这些东西我们会本能的拆开,所以我们考虑化简式子来进行 O ( 1 ) O(1) O(1)转移。
对于分母: ∑ f [ j ] + ( x [ i ] 2 + y [ i ] 2 ) ∗ c n t + ∑ ( x [ j ] 2 + y [ j ] 2 ) − 2 ∗ x [ i ] ∗ ∑ x [ j ] − 2 ∗ y [ i ] ∗ ∑ y [ j ] \sum f[j]+(x[i]^2+y[i]^2)*cnt+\sum(x[j]^2+y[j]^2)-2*x[i]*\sum x[j]-2*y[i]* \sum y[j] ∑f[j]+(x[i]2+y[i]2)∗cnt+∑(x[j]2+y[j]2)−2∗x[i]∗∑x[j]−2∗y[i]∗∑y[j]
我们发现我们只需要维护 ∑ f [ j ] , ∑ ( x [ j ] 2 + y [ j ] 2 ) , ∑ x [ j ] , ∑ y [ j ] \sum f[j],\sum(x[j]^2+y[j]^2),\sum x[j],\sum y[j] ∑f[j],∑(x[j]2+y[j]2),∑x[j],∑y[j]即可实现 O ( 1 ) O(1) O(1)转移了。
实现的时候需要等相同的值都算完了才能加上他们的贡献,且由于每个点的 x , y x,y x,y都不同,需要单独计算,一开始以为一样就在代码的 79 79 79行鬼使神差的写成了 f [ i ] = f [ i − 1 ] f[i]=f[i-1] f[i]=f[i−1],真是老笨蛋啦。
//#pragma GCC optimize(2)
#include<cstdio>
#include<iostream>
#include<string>
#include<cstring>
#include<map>
#include<cmath>
#include<cctype>
#include<vector>
#include<set>
#include<queue>
#include<algorithm>
#include<sstream>
#include<ctime>
#include<cstdlib>
#define X first
#define Y second
#define L (u<<1)
#define R (u<<1|1)
#define pb push_back
#define mk make_pair
#define Mid (tr[u].l+tr[u].r>>1)
#define Len(u) (tr[u].r-tr[u].l+1)
#define random(a,b) ((a)+rand()%((b)-(a)+1))
#define db puts("---")
using namespace std;
//void rd_cre() { freopen("d://dp//data.txt","w",stdout); srand(time(NULL)); }
//void rd_ac() { freopen("d://dp//data.txt","r",stdin); freopen("d://dp//AC.txt","w",stdout); }
//void rd_wa() { freopen("d://dp//data.txt","r",stdin); freopen("d://dp//WA.txt","w",stdout); }
typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int,int> PII;
const int N=1000010,mod=998244353,INF=0x3f3f3f3f;
const double eps=1e-6;
int n,m,tot;
LL f[N],pref,prex,prey,prex2,prey2,pre;
struct Node
{
LL x,y,val;
bool operator < (const Node &w) const
{
return val<w.val;
}
}a[N];
LL qmi(LL a,LL b)
{
LL ans=1;
while(b)
{
if(b&1) ans=ans*a%mod;
a=a*a%mod;
b>>=1;
}
return ans%mod;
}
int main()
{
// ios::sync_with_stdio(false);
// cin.tie(0);
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++) for(int j=1;j<=m;j++) {
LL x; scanf("%lld",&x); a[++tot]={
i,j,x}; }
sort(a+1,a+1+tot);
int cnt=0; a[0].val=-1;
for(int i=2;i<=tot;i++)
{
if(a[i].val!=a[i-1].val)
{
int pos=i-1,val=a[i-1].val;
while(pos>=1&&a[pos].val==val) (prex+=a[pos].x)%=mod,(prey+=a[pos].y)%=mod,(prex2+=a[pos].x*a[pos].x)%=mod,(prey2+=a[pos].y*a[pos].y)%=mod,(pre+=f[pos])%=mod,pos--,cnt++;
f[i]=((pre+(a[i].x*a[i].x+a[i].y*a[i].y)*cnt+prex2+prey2-2*a[i].x*prex-2*a[i].y*prey)%mod+mod)%mod*qmi(cnt,mod-2)%mod;
}
else f[i]=((pre+(a[i].x*a[i].x+a[i].y*a[i].y)*cnt+prex2+prey2-2*a[i].x*prex-2*a[i].y*prey)%mod+mod)%mod*qmi(cnt,mod-2)%mod;
}
LL ans=-1;
int sx,sy; cin>>sx>>sy;
for(int i=1;i<=tot;i++) if(a[i].x==sx&&a[i].y==sy) {
ans=f[i]; break; }
cout<<ans%mod<<endl;
return 0;
}
/*
1 4
1 1 2 1
1 3
*/