字典树, kmp, AC自动机

最近学习了AC自动机, 所以就想总结一下自己的心得

学习AC自动机之前要先学会字典树和kmp, 这里我都会简单提一下。

  字典树, 字典树恰如其名, 真的是和字典非常的类似, 当我们在字典上查找单词的时候都是先找第一个字母确定大概范围,然后再找其它字母逐渐缩小围最后找到单词或者发现字典上没有该单词。 字典树也是这样, 一个个的插入或查找字母, 逐层向下插入或者查找字母, 在插入时, 如果没找到某个字母, 为其新增节点, 查询时, 如果没找到该字母, 则返回0, 表明没有出现过该字符串。 下面代码说明:

  

 1 #include<bits/stdc++.h>
 2 
 3 using namespace std;
 4 
 5 ///maxNodes是最大的节点数, 一般设置为字符串最大数量乘以字符串最大长度
 6 
 7 ///type是字符的种类, 如果只有大写或者小写字母, 可以设置为26
 8 
 9 const int maxNodes = 100000 + 20;
10 
11 const int type = 128;
12 
13 ///trie就是字典树, endWithNode[i]表示以节点i结尾的单词的数量
14 
15 ///当然endWithNode[i]还可以有其他的含义
16 
17 int trie[maxNodes][type], endWithNode[maxNodes];
18 
19 ///root是根节点, 一般是0或1, numberOfNodes是节点的数量
20 
21 ///unusedMark用来标记还没有用过的节点
22 
23 int root, numberOfNodes, unusedMark;
24 
25 ///newNode用来申请新的节点, 申请之前先初始化该节点
26 
27 int newNode() {
28     
29     for (int i = 0; i < type; i++)
30         
31         trie[numberOfNodes][i] = unusedMark;
32     
33     endWithNode[numberOfNodes] = 0;
34     
35     return numberOfNodes++;
36     
37 }
38 
39 ///字典树的初始化, 并不需要把所有的节点都初始化, 节点等到要用的在初始化
40 
41 void init() {
42     
43     numberOfNodes = 0;
44     
45     unusedMark = -1;
46     
47     root = newNode();
48     
49 }
50 
51 ///插入字符串
52 
53 void Insert(char buf[]) {
54     
55     int len = strlen(buf);
56     
57     int now = root;
58 
59     for (int i = 0; i < len; i++) {
60             
61         int x = buf[i];
62         
63         if (trie[now][x] == unusedMark) trie[now][x] = newNode();
64     
65         now = trie[now][x];
66         
67     }
68     
69     endWithNode[now]++;
70     
71 }
72 
73 ///查询字符串, 返回值为该字符串出现的次数
74 
75 int Query(char buf[]) {
76     
77     int len = strlen[buf];
78     
79     int now = root;
80 
81     for (int i = 0; i < len; i++) {
82             
83         int x = buf[i];
84     
85         if (trie[now][x] == unusedMark) return 0;
86         
87         now = trie[now][x];
88         
89     }
90     
91     return endWithNode[now];
92     
93 }

  然后是kmp, kmp简单来说就是关键字搜索

  常规的字符串搜索的时间复杂度为O (m * n), n为主串长度, m为模式串长度, 方法如下:设主串为s = asdcbfahnaof, 模式串为t = nao

 1 int len = s.length(), len2 = t.length(), i = 0, j = 0;
 2 
 3 while (i < len && j < len2) {
 4 
 5     if(s[i] == t[j]) {
 6 
 7         i++; j++;
 8 
 9     } else {
10 
11         i = i - j + 1;
12 
13         j = 0;
14 
15     }
16 
17 }
18 
19 return j == len2 ? i - j + 1 : -1;

这种方法, 众生平等, 无论主串和模式串是什么都是很慢的, 但是, 如果s = abcdabceabcdabcd, t = abcdabcd的话, 当i = 7, j = 7时, 就已经失配了, 此时如果是上面的做法, 就会有很多无意义的重复

j会一直为0, i会从1加到8, 此时才开始接下来有意义的操作, 而kmp算法则是让i 保持不变, 变的是j

 

 1 while (i < len && j < len2) {
 2 
 3     if(j == -1 || s[i] == t[j]) {
 4 
 5         i++;
 6 
 7         j++;
 8 
 9     } else {
10 
11         j = next[j];
12 
13     }
14 
15 }

就拿当前这种情况来说, j 会跳到3, 然后跳到-1, 为什么呢?因为, s 和 t是在s[7]这里不匹配的, 但是前面7个字符是匹配的, 因此, 如果能在t里面找到最长的一段后缀与相同长度的前缀相等, 那么就可以令j直接跳过去, 然后开始匹配了

还是这个例子, t[0 - 2] = t[5 - 7], t[5 - 7] 与 s[5 - 7] 相等, 那么t[0 - 2] 与 s[5 - 7] 也是相等的, 既然这样的话, 那我们就只需要比较 t[3] 和 s[8] 即可, 这样的话就可以节省很多时间, 那么关键来了, 我们如何知道j 在失配之后应该跳到哪个位置呢?

其实, 也很简单

 1 int i = 0, j = -1;
 2 
 3 int next[len2];
 4 
 5 next[0] = -1;
 6 
 7 while (i < len) {
 8 
 9     if (j == -1 || t[i] == t[j]) {
10 
11         j++;
12 
13         i++;
14 
15         next[i] = j;
16 
17     } else {
18 
19         j = next[j];
20 
21     }
22 
23 }

就拿t = abcdabce来说, a, b, c, d第一次出现, next[1 - 4] = 0, 这是当然的吧, 没有相同的前后缀当然要重新开始咯, i = 4, j = 0时, next[5] = 1, 因为t[0] 与 t[4]相等, 这个效果就是, 当t[5]失配时, j(不是求next数组时的j) 跳到1, 然后开始继续匹配, 然后嘛, next[6] = 2, next[7] = 3, next[8] = 0; next[8] = 0, 我们来好好分析一下, 当i = 6, j = 2时, t[6] = t[2], 所以, t[7] = 3, 但是, t[7] != t[3] && 3 != -1, 所以j = next[3] = 0, 下一个循环时, 0 != -1 && t[7] != t[0], 所以j = next[0] = -1, 再下一个循环, j = -1, 所以, next[8] = 0;接下来, j == len2, 循环结束, next数组求出来为next[0 - 8] = {-1, 0, 0, 0, 0, 1, 2, 3, 0};

设s = aaabaaabaaab, t = aaaa;那么next[0 - 4] = {-1, 0, 1, 2, 3}; 然后呢, 当i = 3, j = 3时, j就会依次从3减到 -1, 然后嘛, 到了i = 7, j = 3; i = 11, j = 3的时候都会这样无意义的重复, 很浪费时间的, 这个锅当然是next数组来背了, 而实际上, 也确实是next数组的求法出了点问题

 1 if (j == -1 || t[i] == t[j]) {
 2 
 3     if(t[++i] == t[++j]) {
 4 
 5         next[i] = next[j];
 6 
 7     } else {
 8 
 9         next[i] = j;
10 
11     }
12 
13 }

改一下这部分就好了, 分析:

  1. j = -1, i = 0 => next[1] = next[0] = -1;
  2. t[0] = t[1] = ‘a’=> next[2] = next[1] = -1;
  3. t[1] = t[2] = ‘a’ => next[3] = next[2] = -1;
  4. t[2] = t[3] = ‘a’, 但是t[4] = ‘\0’ = 0, t[3] = ‘a’, t[4] = j = 3;

回看上面, j是不会等len2的, 所以, t[len2]也可以不用求, 但是强迫症看着难受, 而且, 在某些时候, t[len2]是有作用的, 所以我还是求了。

下面是kmp的完整代码

 

 1 #include<bits/stdc++.h>
 2 
 3 using namespace std;
 4 
 5 int* GetNext(string t) {
 6 
 7     int len2 = t.length(), i = 0, j = -1;
 8 
 9     int *next = new int[len2 + 1];
10 
11     next[0] = -1;
12 
13     while (i < len2) {
14 
15         if (j == -1 || t[i] == t[j]) {
16 
17             if (t[++i] == t[++j]) {
18 
19                 next[i] = next[j];
20 
21             } else {
22 
23                 next[i] = j;
24 
25             }
26 
27         } else {
28 
29             j = next[j];
30 
31         }
32 
33     }
34 
35     return next;
36 
37 }
38 
39 int KMP(string s, string t) {
40 
41     int len = s.length(), len2 = t.length(), i = 0, j = 0;
42 
43     int *next = GetNext(t);
44 
45     while (i < len && j < len2) {
46 
47         if(j == -1 || s[i] == t[j]) {
48 
49             i++;
50 
51             j++;
52 
53         } else {
54 
55             j = next[j];
56 
57         }
58 
59     }
60 
61     return j == len2 ? i - j + 1 : -1;
62 
63 }
64 
65 int main() {
66 
67     string a, b;
68 
69     while (cin >> a >> b) {
70 
71         cout << KMP(a, b) << endl;
72 
73     }
74 
75     return 0;
76 
77 }

 

  字典树和kmp讲完之后, 就是AC自动机了, AC自动机其实就是在字典树上跑kmp(是不是不知道我说的是什么意思, 没错, 我一开始也是迷迷糊糊的, 完全不知道在说啥), 作用就是同时对n个模式串进行搜索, 不用想也知道肯定比进行n次kmp要快

  AC自动机分为三步:

  第一步、 对n个模式串建立一颗字典树

  第二步、 找到字典树上每个节点的fail指针(其实就是next数组)

  第三步、 喜闻乐见的搜索环节

    

  事前准备就是这么多, 建议写成结构体, 不知为何, 我感觉结构体会比较快, 可能是我的错觉吧

 

  下面是喜闻乐见的代码环节:

  第一步、 建树

  

 1     int newNode() {
 2         
 3         for (int i = 0; i < type; i++)  trie[numberOfNodes][i] = unusedMark;
 4         
 5         tail[numberOfNodes] = unusedMark;
 6         
 7         return numberOfNodes++;
 8         
 9     }
10 
11     void init() {
12         
13         numberOfNodes = 0;
14         
15         unusedMark = -1;
16         
17         root = newNode();
18         
19     }
20 
21     void Insert(char s[]) {
22         
23         int len = strlen(s), now = root;
24 
25         for (int i = 0; i < len; i++) {
26                 
27             int x = s[i] - 'a';
28         
29             if (trie[now][x] == unusedMark) trie[now][x] = newNode();
30             
31             now = trie[now][x];
32             
33         }
34         
35         if (tail[now] == unusedMark)    tail[now] = 0;
36         
37         tail[now]++;
38         
39     }

  这一步和之前是一样的

  第二步, 构建fail指针, 找fail指针就两步, 先把root节点的子节点的fail指针指向root, 然后按BFS的方法遍历其他节点, 其他节点的fail指针的找寻方法有点绕, 我的舌头捋不直, 大佬们的说法是(设这个节点上的字母为C,沿着他父亲的失败指针走,直到走到一个节点,他的儿子中也有字母为C的节点。然后把当前节点的失败指针指向那个字母也为C的儿子。如果一直走到了root都没找到,那就把失败指针指向root), 读者如果无法理解的话, 请尝试看代码领悟

  

 1     void BuildFail() {
 2         
 3         queue<int> Q;
 4         
 5         fail[root] = root;
 6 
 7         ///把root节点的子节点的fail指针指向root
 8         for (int i = 0; i < type; i++)
 9             
10             if (trie[root][i] == unusedMark)    trie[root][i] = root;
11         
12             else {
13                     
14                 fail[trie[root][i]] = root;
15             
16                 Q.push(trie[root][i]);
17                 
18             }
19 
20         while (!Q.empty()) {
21                 
22             int now = Q.front();    Q.pop();
23 
24             for (int i = 0; i < type; i++)
25                 
26                 if (trie[now][i] == unusedMark) trie[now][i] = trie[fail[now]][i];
27             
28                 else {
29                         
30                     fail[trie[now][i]] = trie[fail[now]][i];
31                 
32                     Q.push(trie[now][i]);
33                     
34                 }
35                 
36         }
37         
38     }

  第三步、 搜索环节

  

 1 int Query(char buf[]) {
 2         
 3         int len = strlen(buf), now = root, ans = 0;
 4 
 5         for (int i = 0; i < len; i++) {
 6                 
 7             now = trie[now][buf[i] - 'a'];
 8         
 9             int temp = now;
10             
11             ///让temp沿着fail指针一直向上走, 并更新ans, 注意每个节点的tail值只能加一次
12             while (tail[temp] != unusedMark) {
13                     
14                 ans += tail[temp];
15             
16                 tail[temp] = unusedMark;
17                 
18                 temp = fail[temp];
19                 
20             }
21             
22         }
23 
24         return ans;
25     }

  下面给出完整代码(hdu - 2222的ac代码)

  

  1 #include<bits/stdc++.h>
  2 
  3 using namespace std;
  4 
  5 const int N = 10000 * 50 + 5;
  6 
  7 const int type = 26;
  8 
  9 struct AC_Automation {
 10     
 11     int trie[N][type], fail[N], tail[N];
 12     
 13     int root, unusedMark, numberOfNodes;
 14 
 15     int newNode() {
 16 
 17         for (int i = 0; i < type; i++)  trie[numberOfNodes][i] = unusedMark;
 18 
 19         tail[numberOfNodes] = unusedMark;
 20 
 21         return numberOfNodes++;
 22 
 23     }
 24 
 25     void init() {
 26 
 27         numberOfNodes = 0;
 28 
 29         unusedMark = -1;
 30 
 31         root = newNode();
 32 
 33     }
 34 
 35     void Insert(char s[]) {
 36 
 37         int len = strlen(s), now = root;
 38 
 39         for (int i = 0; i < len; i++) {
 40 
 41             int x = s[i] - 'a';
 42 
 43             if (trie[now][x] == unusedMark) trie[now][x] = newNode();
 44 
 45             now = trie[now][x];
 46 
 47         }
 48 
 49         if (tail[now] == unusedMark)    tail[now] = 0;
 50 
 51         tail[now]++;
 52 
 53     }
 54 
 55     void BuildFail() {
 56         
 57         queue<int> Q;
 58         
 59         fail[root] = root;
 60 
 61         ///把root节点的子节点的fail指针指向root
 62         for (int i = 0; i < type; i++)
 63             
 64             if (trie[root][i] == unusedMark)    trie[root][i] = root;
 65         
 66             else {
 67                     
 68                 fail[trie[root][i]] = root;
 69             
 70                 Q.push(trie[root][i]);
 71                 
 72             }
 73 
 74         while (!Q.empty()) {
 75                 
 76             int now = Q.front();    Q.pop();
 77 
 78             for (int i = 0; i < type; i++)
 79                 
 80                 if (trie[now][i] == unusedMark) trie[now][i] = trie[fail[now]][i];
 81             
 82                 else {
 83                         
 84                     fail[trie[now][i]] = trie[fail[now]][i];
 85                 
 86                     Q.push(trie[now][i]);
 87                     
 88                 }
 89                 
 90         }
 91         
 92     }
 93 
 94     int Query(char buf[]) {
 95         
 96         int len = strlen(buf), now = root, ans = 0;
 97 
 98         for (int i = 0; i < len; i++) {
 99                 
100             now = trie[now][buf[i] - 'a'];
101         
102             int temp = now;
103             
104             ///让temp沿着fail指针一直向上走, 并更新ans, 注意每个节点的tail值只能加一次
105             while (tail[temp] != unusedMark) {
106                     
107                 ans += tail[temp];
108             
109                 tail[temp] = unusedMark;
110                 
111                 temp = fail[temp];
112                 
113             }
114             
115         }
116 
117         return ans;
118     }
119     
120 } ac;
121 
122 char buf[1000005], s[55];
123 
124 int main() {
125     
126     int n, m;
127     
128     while (scanf("%d", &m) == 1) {
129             
130         while (m--) {
131             
132             ac.init();
133     
134             scanf("%d", &n);
135             
136             while (n--) {
137                     
138                 scanf("%s", s);
139             
140                 ac.Insert(s);
141                 
142             }
143             ac.BuildFail();
144             
145             scanf("%s", buf);
146             
147             printf("%d\n", ac.Query(buf));
148             
149         }
150         
151     }
152     
153     return 0;
154     
155 }

hdu - 2222的链接

UESTC - 1977的链接

猜你喜欢

转载自www.cnblogs.com/123zhh-helloworld/p/9758195.html
今日推荐