Concatenation with intersection(2800/z算法/树状数组/双指针)

题目:http://codeforces.com/contest/1313/problem/E
参考:http://codeforces.com/blog/entry/74146
题意:给定字符串a,b,s。求子串组合数使得 a [ l 1 , r 1 ] + b [ l 2 , r 2 ] = = s a[l1,r1]+b[l2,r2]==s ,要求 [ l 1 , r 1 ] , [ l 2 , r 2 ] [l1,r1],[l2,r2] 交集非空。
题解:用z算法求出a在s中的最长前缀 l c p lcp ,b在s中的最长后缀 l c s lcs 。从左到右枚举a字符串,取定 l 1 l_1 ,那么相应的 r 2 r_2 取值为 l 1 < = r 2 < = l 1 + m 2 l_1<=r_2<=l_1+m-2 ,我们把满足当前区间的所有 r 2 r_2 对应的最左端点 r 2 l c s r 2 r_2-lcs_{r_2} 都扔进树状数组上,分别统计数量与总和 c n t s u m cnt,sum
那么当前 l 1 l_1 为左端点的情况数有 l c s c n t s u m lcs*cnt-sum s u m sum 部分是取不到的情况,需要减去。这只是大概抽象的说明,代码实现和理论说明有出入,详见代码。

#include<bits/stdc++.h>
using namespace std;
const int maxn = 500010;
#define ll long long

int n,m;
char a[maxn],b[maxn];
char s[maxn*2],c[maxn*3];
int z[maxn*3];
int lcp[maxn];//lcp[i]表示a[i]开始,能匹配s的最长前缀 
int lcs[maxn];//lcs[i]表示b[i]开始,能匹配s的最长后缀 
//z算法,求解给定字符串每个位置能匹配自身的最长前缀 
void z_init(int len) {
	int l = 1,r = 1;
	z[1] = len;
	for(int i = 2;i <= len;i++) {
		if(i > r) {
			l = i;r = i;
			while(r<=len && c[r-i+1]==c[r]) r++;
			z[i] = r-l;r--;
		}else {
			int k = i-l+1;
			if(z[k]<r-i+1) z[i] = z[k];
			else {
				l = i;
				while(r<=len && c[r-i+1]==c[r]) r++;
				z[i] = r-l;r--;
			}
		}
	}
}
void init() {
	//求lcp 
	for(int i = 1;i <= m;i++) c[i] = s[i];
	c[m+1] = '#';
	for(int i = 1;i <= n;i++) c[m+1+i] = a[i];
	c[n+m+2] = '\0';
	z_init(n+m+1);
	for(int i = 1;i <= n;i++) lcp[i] = z[m+1+i];
	//求lcs
	for(int i = 1;i <= m;i++) c[i] = s[m+1-i];
	c[m+1] = '#';
	for(int i = 1;i <= n;i++) c[m+1+i] = b[n+1-i];
	c[n+m+2] = '\0';
	z_init(n+m+1);
	for(int i = 1;i <= n;i++) lcs[i] = z[m+1+n+1-i];
}
//Fenwick Tree
ll cnt[maxn*2],sum[maxn*2];
int lowbit(int x) {
	return x&(-x); 
} 
void add(int v) {
	int x = v;
	while(x <= n) sum[x]+=v,cnt[x]++,x+=lowbit(x);
}
void sub(int v) {
	int x = v;
	while(x <= n) sum[x]-=v,cnt[x]--,x+=lowbit(x);
}
ll get_sum(int x) {
	ll res = 0;
	while(x) res+=sum[x],x-=lowbit(x);
	return res;
}
ll get_cnt(int x) {
	ll res = 0;
	while(x) res+=cnt[x],x-=lowbit(x);
	return res;
} 
int main() {
	scanf("%d%d",&n,&m);
	scanf("%s%s%s",a+1,b+1,s+1);
	init();//puts("VE");
	ll ans = 0;
	/*
	*l1 <= r2 <= l1+m-2 as |r2-l1| <= m-1
	*initial l1 = 1
	*/
	for(int i = 1;i <= min(n,m-1);i++) add(max(1,m-lcs[i]));//puts("S");
	for(int i = 1,r;i <= n;i++) {
		r = min(m-1,lcp[i]);
		ans += 1LL*(r+1)*get_cnt(r)-get_sum(r);
		sub(max(1,m-lcs[i]));//delete case of "i as r2"
		if(i+m-1 <= n) add(max(1,m-lcs[i+m-1]));//add case of "i+m-1 as r2" 
	}
	printf("%I64d\n",ans); 
} 
发布了152 篇原创文章 · 获赞 2 · 访问量 6457

猜你喜欢

转载自blog.csdn.net/weixin_43918473/article/details/104655510