BZOJ1396: 识别子串 SAM+线段树

版权声明:本文是蒟蒻写的,转载。。。随便吧 https://blog.csdn.net/xgc_woker/article/details/88048484

Description
对于一个字符串 $S$,位置 $x$ 的识别子串 $T=S(i,j)$ 需满足:
1. $i \le x \le j$;
2. $T$ 在 $S$ 中只出现过一次。
对每个位置 $x$,求出其最短识别子串的长度。


Sample Input
agoodcookcooksgoodfood


Sample Output
1
2
3
3
2
2
3
3
2
2
3
3
2
1
2
3
3
2
1
2
3
4


用 SAM 和 SA 做法区别不大,作用都是对每个后缀 $l$ 求出它只出现一次的最短前缀长度 $mx$(即在别处也出现过的最长前缀长度加一)。
对于 $[l+mx, n]$ 区间内的点 $i$,子串 $S(l,i)$ 只出现一次,贡献为 $i-l+1$;
对于 $[l, l+mx-1]$ 区间内的点,贡献为 $mx$。对所有后缀的贡献取最小值,用线段树维护区间取 $\min$ 即可。


#include <ctime>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>

using namespace std;
typedef long long LL;
// Returns the larger of two ints (on a tie the two values are equal anyway).
int _max(int x, int y) {
	if(x > y) return x;
	return y;
}
// Returns the smaller of two ints.
int _min(int x, int y) {
	if(x < y) return x;
	return y;
}
// Reads one decimal integer from stdin.
// Any non-digit characters before the number are skipped; if a '-' is among
// them the result is negated. Returns 0 when end of input is reached before a
// digit is seen. The result of getchar() is kept in an int so that EOF stays
// distinguishable from real characters; the previous version stored it in a
// char and looped forever at end of input, because EOF compares below '0'.
int read() {
	int s = 0, f = 1;
	int ch = getchar();
	while(ch != EOF && (ch < '0' || ch > '9')) {
		if(ch == '-') f = -1;
		ch = getchar();
	}
	while(ch >= '0' && ch <= '9') s = s * 10 + ch - '0', ch = getchar();
	return s * f;
}

// Segment-tree node covering string positions [l, r].
//   lc, rc : indices of the children inside tr[] (-1 for a leaf)
//   c      : minimum of the values applied to this range by change()
//   lazy   : pending change() minimum, not yet pushed to the children
//   g      : minimum of the values applied to this range by clg()
//   lag    : pending clg() minimum, not yet pushed to the children
struct tnode {
	int l, r;
	int lc, rc;
	int c, lazy;
	int g, lag;
} tr[200010];
int tot; // number of nodes allocated in tr[] so far
// Suffix-automaton state.
//   id     : an end position (1-based) of this state's substrings; getr()
//            raises it to the largest id found among its suffix-link children
//   len    : length of the longest substring in this state
//   fa     : suffix link (0 for the root, state 1)
//   son[c] : transition on letter 'a' + c (0 = none)
struct SAM {
	int id, len, fa, son[26];
} t[200010]; int cnt, root, last;
// R[] (occurrence count per state) and id[] (states sorted by len) are both
// indexed by SAM state. A SAM over n characters has up to 2n states, so both
// need the same capacity as t[]. R used to be sized 100010 and overflowed for
// long inputs. s[] is the counting-sort bucket indexed by length 0..n.
int R[200010], s[100010], id[200010];
char ss[100010]; // input string stored 1-based in ss[1..n]
int n;

// Appends letter c (0..25), which sits at 1-based position hh of the input,
// to the suffix automaton t[]. The new state records hh as its end position.
void insam(int c, int hh) {
	int cur = ++cnt;
	t[cur].len = t[last].len + 1;
	t[cur].id = hh;
	int p = last;
	// Walk up the suffix links, adding the missing c-transition to cur.
	for(; p != 0 && t[p].son[c] == 0; p = t[p].fa) t[p].son[c] = cur;
	if(p == 0) {
		t[cur].fa = 1;
	} else {
		int q = t[p].son[c];
		if(t[q].len == t[p].len + 1) {
			t[cur].fa = q;
		} else {
			// q holds strings that are too long: split it with a clone.
			int clone = ++cnt;
			t[clone] = t[q];
			t[clone].len = t[p].len + 1;
			t[q].fa = clone;
			t[cur].fa = clone;
			// Redirect transitions that pointed at q to the clone.
			for(; p != 0 && t[p].son[c] == q; p = t[p].fa) t[p].son[c] = clone;
		}
	}
	last = cur;
}

// Builds the suffix automaton of ss[1..n], with state 1 as the root.
void bt_SAM() {
	cnt = 1;
	last = 1;
	for(int pos = 1; pos <= n; ++pos) {
		int letter = ss[pos] - 'a';
		insam(letter, pos);
	}
}

void getr() {
	for(int i = 1; i <= cnt; i++) s[t[i].len]++;
	for(int i = 1; i <= n; i++) s[i] += s[i - 1];
	for(int i = cnt; i >= 1; i--) id[s[t[i].len]--] = i;
	for(int p = 1, i = 1; i <= n; i++) p = t[p].son[ss[i] - 'a'], R[p] = 1;
	for(int i = cnt; i >= 1; i--) R[t[id[i]].fa] += R[id[i]], t[t[id[i]].fa].id = _max(t[t[id[i]].fa].id, t[id[i]].id);
}

void bt(int l, int r) {
	int now = ++tot;
	tr[now].l = l, tr[now].r = r;
	tr[now].lc = tr[now].rc = -1;
	tr[now].c = tr[now].lazy = tr[now].g = tr[now].lag = 999999999;
	if(l < r) {
		int mid = (l + r) / 2;
		tr[now].lc = tot + 1; bt(l, mid);
		tr[now].rc = tot + 1; bt(mid + 1, r);
	}
}

void update(int now) {
	if(tr[now].lazy == 999999999 && tr[now].lag == 999999999) return ;
	int lc = tr[now].lc, rc = tr[now].rc;
	tr[lc].c = _min(tr[lc].c, tr[now].lazy);
	tr[rc].c = _min(tr[rc].c, tr[now].lazy);
	tr[lc].g = _min(tr[lc].g, tr[now].lag);
	tr[rc].g = _min(tr[rc].g, tr[now].lag);
	tr[lc].lazy = _min(tr[lc].lazy, tr[now].lazy);
	tr[rc].lazy = _min(tr[rc].lazy, tr[now].lazy);
	tr[lc].lag = _min(tr[lc].lag, tr[now].lag);
	tr[rc].lag = _min(tr[rc].lag, tr[now].lag);
	tr[now].lazy = tr[now].lag = 999999999;
}

void change(int now, int l, int r, int c) {
	if(tr[now].l == l && tr[now].r == r) {
		tr[now].c = _min(tr[now].c, c);
		tr[now].lazy = _min(tr[now].lazy, c);
		return ;
	} update(now); int mid = (tr[now].l + tr[now].r) / 2;
	if(r <= mid) change(tr[now].lc, l, r, c);
	else if(l > mid) change(tr[now].rc, l, r, c);
	else change(tr[now].lc, l, mid, c), change(tr[now].rc, mid + 1, r, c);
}

void clg(int now, int l, int r, int c) {
	if(tr[now].l == l && tr[now].r == r) {
		tr[now].g = _min(tr[now].g, c);
		tr[now].lag = _min(tr[now].lag, c);
		return ;
	} update(now); int mid = (tr[now].l + tr[now].r) / 2;
	if(r <= mid) clg(tr[now].lc, l, r, c);
	else if(l > mid) clg(tr[now].rc, l, r, c);
	else clg(tr[now].lc, l, mid, c), clg(tr[now].rc, mid + 1, r, c);
}

// Point query for position p, starting at node now. It descends to p's leaf,
// pushing pending tags down along the path, and returns
// min(g - p + 1, c) stored at that leaf.
int findm(int now, int p) {
	while(tr[now].l != tr[now].r) {
		update(now);
		int mid = (tr[now].l + tr[now].r) / 2;
		now = (p <= mid) ? tr[now].lc : tr[now].rc;
	}
	return _min(tr[now].g - tr[now].l + 1, tr[now].c);
}

// Reads the string, builds the SAM and the segment tree, and prints one
// answer per position. For every state i that occurs exactly once
// (R[i] == 1), let e = t[i].id be its end position and fa its suffix link:
//   * positions in [e - t[fa].len, e] can use the unique substring of length
//     t[fa].len + 1 ending at e, which change() records;
//   * positions j in [1, e - t[fa].len] can use S(j, e). clg() records e, and
//     findm() turns it into the length e - j + 1.
int main() {
	// Bounded read: ss[100010] leaves room for 100008 characters after the
	// 1-based offset plus the terminator. An unbounded %s could overflow it.
	// On empty input nothing is printed, as before.
	if(scanf("%100008s", ss + 1) != 1) return 0;
	n = strlen(ss + 1);
	bt_SAM();
	getr();
	bt(1, n);
	for(int i = 1; i <= cnt; i++) if(R[i] == 1) {
		int fa = t[i].fa;
		change(1, t[i].id - t[fa].len, t[i].id, t[fa].len + 1);
		clg(1, 1, t[i].id - t[fa].len, t[i].id);
	}
	for(int i = 1; i <= n; i++) printf("%d\n", findm(1, i));
	return 0;
}

猜你喜欢

转载自blog.csdn.net/xgc_woker/article/details/88048484