版权声明:本文为博主原创文章,转载请著名出处 http://blog.csdn.net/u013534123 https://blog.csdn.net/u013534123/article/details/81904126
大致题意:给你一个n*m的01矩阵,现在要让你每一行和每一列都去掉一个数字,而且要求相邻两行之间去掉数字的位置的绝对值要小于等于k。现在问你删除之后的矩形最多有几种。
首先,我们一行一行考虑,对于同一行,显然是看有多少个块,有多少个块就有多少个方案。然后对于整个矩阵来说,任意位置(i,j)可以从上一行的(j-k,j+k)之间转移过来。dp[i][j]表示不考虑重复的情况下,处理到第i行,且第i行删掉第j个位置的方案数。那么,显然根据之前说的转移区间,有转移方程:
但是,这里面会有重复,因为一块里面删掉任意一个都是一样的。所以我们考虑要删掉这些计算重复的。那么这重复的具体来说是多少呢?我们不妨设重复的为ss[i][j],表示第i行第j个位置,与其同一行的第j-1个位置重复的部分。我们考虑,两个相邻的位置j和j-1的相交区间是[j-k,j-1+k],于是我们这个ss[i][j]也要从上一行的这一个区间转移过来,当然了,还要保证j和j-1要在同一个块,否则j和j-1是不会有重复的。具体来说:
这样,我们就解决的这个问题,对于位置(i,j)来说,处理到第i行,且第i行删掉j的方案数就是dp[i][j]-ss[i][j]。
这个dp暴力的话,时间复杂度是O(N^3)的,但是这个显然是可以用前缀和优化一下。如此复杂度可以到O(N^2),然后为了节省空间,我的代码里面用了滚动数组。具体见代码:
#include<bits/stdc++.h>
#define LL long long
#define mod 998244353
#define pb push_back
#define lb lower_bound
#define ub upper_bound
#define INF 0x3f3f3f3f
#define sf(x) scanf("%d",&x)
#define sc(x,y,z) scanf("%d%d%d",&x,&y,&z)
#define clr(x,n) memset(x,0,sizeof(x[0])*(n+5))
#define IO ios::sync_with_stdio(0);cin.tie(0);cout.tie(0)
#define file(x) freopen(#x".in","r",stdin),freopen(#x".out","w",stdout)
using namespace std;
const int N = 2e3 + 10;
int ss[2][N],s[2][N];
char str[N][N];
int main()
{
int T; sf(T);
while(T--)
{
int n,m,k;
sc(n,m,k);
for(int i=1;i<=n;i++)
{
str[i][0]='9';
scanf("%s",str[i]+1);
}
int cur=1,pre=0;
clr(s[pre],m); clr(ss[pre],m);
for(int i=1;i<=m;i++)
{
s[0][i]=1; ss[0][i]=(str[1][i]==str[1][i-1]);
}
for(int i=2;i<=n;i++)
{
clr(ss[cur],m); clr(s[cur],m);
for(int j=1;j<=m;j++)
{
s[pre][j]=(s[pre][j-1]+s[pre][j])%mod;
ss[pre][j]=(ss[pre][j-1]+ss[pre][j])%mod;
}
for(int j=1;j<=m;j++)
{
int l=max(1,j-k),r=min(m,j+k);
s[cur][j]=(s[pre][r]-s[pre][l-1]+mod)%mod;
s[cur][j]=(s[cur][j]-(ss[pre][r]-ss[pre][l]+mod)%mod+mod)%mod;
if (str[i][j]!=str[i][j-1]) {ss[cur][j]=0;continue;}
r=min(m,j+k-1); ss[cur][j]=(s[pre][r]-s[pre][l-1]+mod)%mod;
ss[cur][j]=(ss[cur][j]-(ss[pre][r]-ss[pre][l]+mod)%mod+mod)%mod;
}
swap(cur,pre);
}
LL ans=s[pre][1]%mod;
for(int i=2;i<=m;i++)
ans=(ans+s[pre][i]-ss[pre][i]+mod)%mod;
printf("%lld\n",ans);
}
return 0;
}