万径人踪灭 HYSBZ - 3160

版权声明:本文为博主原创文章,未经博主允许必须转载。 https://blog.csdn.net/qq_35950004/article/details/85230685

求不连续回文子序列数量。
首先FFT求回文子序列数量,再用manachar算出连续回文子串数量。
完了。。。。。。

#define maxn 300005
#define LL long long
#define mod 1000000007
using namespace std;

const double Pi = 3.1415926535897932384626433832795;
int n,lgn,f[maxn];
char s[maxn];

struct cplx
{
	double r,i;
	cplx(double r=0,double i=0):r(r),i(i){}
	cplx operator +(const cplx &B)const{ return cplx(r+B.r,i+B.i); }
	cplx operator -(const cplx &B)const{ return cplx(r-B.r,i-B.i); }
	cplx operator *(const cplx &B)const{ return cplx(r*B.r-i*B.i,r*B.i+i*B.r); }
	cplx conj(){ return cplx(r,-i); }
}A[maxn],w[maxn]={1},B[maxn];
int r[maxn];

inline void FFT(cplx A[maxn],int lgn,int tp)
{
	int n = 1<<lgn;
	for(int i=0;i<n;i++) if(i < r[i]) swap(A[i],A[r[i]]);
	for(int L=2;L<=n;L<<=1)
	{
		int l = L>>1;w[1] = cplx(cos(Pi/l),sin(Pi/l)*tp);
		for(int i=2;i<l;i++) w[i] = w[i-1] * w[1];
		for(int st=0;st<n;st+=L)
			for(int k=0;k<l;k++)
			{
				cplx tmp = w[k] * A[st+k+l];
				A[st+k+l] = A[st+k] - tmp , A[st+k] = A[st+k] + tmp;
			}
	}
	if(tp==-1) 
		for(int i=0;i<n;i++)
			A[i].r/=n;
}

inline int Pow(int base,int k)
{
	int ret = 1;
	for(;k;k>>=1,base=1ll*base*base%mod) if(k&1) ret=1ll*ret*base%mod;
	return ret;
}

inline int manachar()
{
	static char ns[maxn]={};
	int tot = 0;
	ns[++tot] = '$';
	for(int i=0;i<n;i++)
		ns[++tot] = '*' , ns[++tot] = s[i];
	ns[++tot] = '*', ns[++tot] = '&';
	
	static int len[maxn] = {};
	int mx=0,o=0,ans=0;
	for(int i=0;i<tot;i++)
	{
		if(i<mx) len[i] = min(len[2*o-i],mx-i);
		else len[i]=1;
		for(;ns[i-len[i]]==ns[i+len[i]];len[i]++);
		if(i+len[i] > mx) mx = i + len[i] , o = i;
		ans = (ans + len[i]/2) % mod;
	}
	return ans;
}

int main()
{
	scanf("%s",s);
	n = strlen(s);
	for(lgn=0;2*n>=(1<<lgn);lgn++);
	for(int i=0,len=1<<lgn;i<len;i++) r[i] = (r[i>>1]>>1) + ((i&1)<<(lgn-1));
		
	for(int i=0;i<n;i++) 
		if(s[i] == 'a') 
			A[i].r=1;
	FFT(A,lgn,1);
	for(int i=0,len=1<<lgn;i<len;i++)
		B[i] = A[i] * A[i];
	FFT(B,lgn,-1);
	for(int i=0;i<2*n;i++) f[i] += int(B[i].r+1.5)/2;
		
	memset(A,0,sizeof A);
	for(int i=0;i<n;i++) if(s[i] == 'b') A[i].r=1;
	FFT(A,lgn,1);
	for(int i=0,len=1<<lgn;i<len;i++)B[i]=A[i]*A[i];
	FFT(B,lgn,-1);
	int ans = 0;
	for(int i=0;i<2*n;i++)
	{
		f[i] += int(B[i].r+1.5)/2;
	}
	
	for(int i=0;i<2*n;i++)
	{
		f[i] = Pow(2,f[i]) - 1;
		ans = (f[i] + ans) % mod;
	}
	printf("%d\n",(ans - manachar()+mod)%mod);
}

猜你喜欢

转载自blog.csdn.net/qq_35950004/article/details/85230685