图文详解KMP算法(C++,Java实现)

问题提出:

给定一段长度为n的文本字符串txt和一段长度为m的模式字符串pat,在文本字符串txt中找到一个和该模式字符串pat相同的子字符串。找到则返回匹配的起始位置,没有返回-1。

解决方法:

  • 朴素字符串匹配算法
  • KMP算法

朴素字符串匹配算法开始

朴素字符串匹配算法:针对pat与txt,使用指针i跟踪txt,使用指针j跟踪pat。对于每个i,首先将j重置为0并不断将他增大,直至找到了一个不匹配的字符或者模式结束(j==m)为止,伪代码如下:

// search
 1 m=pat.length
 2 n=txt.length
 3 for i=0 to n-m
 4     for j=0 to m-1
 5        if txt[i+j] != pat[j]
 6            break
 7     if j==m return i
 8 return -1

在上述匹配的过程中,当pat的j位置与txt的i+j位置不匹配时,i+j回退到i+1位置,j回退到0位置,因此在最坏的情况下,每次都比较到pat的最后一个位置才发现不匹配,那么时间复杂度就是O(nm),匹配过程简化如下:

初始:

第1次比较完,未命中:(j从0开始比较到3,i从0开始比较到3)

第2次比较完,未命中:(j从0开始比较到3,i从1开始比较到4)

第3次比较完,未命中:(j从0开始比较到3,i从2开始比较到5)

后续未命中方式都是一致的,直接展示最后一步...

最后1次比较完,命中,返回9(角标从0计算):(j从0开始比较到3,i从9开始比较到12)

当回退前(j >0)的时,我么知道pat的pat[0...j-1]与txt[i...i+j-1]是匹配的,那么是否可以利用这一点,并通过某个数组帮助我们计算在这种情况下具体需要回退多少,而不是让j一下子回退到0,也就是避免“搜索位置移过已经比较过且匹配的位置”,从而保证提高比较的效率。实际上,这种算法已经被提出和数学证明,它就是KMP算法。

KMP算法

由Knuth,Morris和Pratt三人于1977年联合发表,将时间复杂度降为O(n+m)

算法思路:

  • 构造模式串pat的next[]数组,记录最少回退下的回退下标位置,使得匹配串txt在匹配过程中角标不断后移,避免txt串回退造成低效率;

  • kmp在next[]数组帮助下完成,具体参见kmp算法图解;

  • next[]数组是kmp完成的一大关键,其效率也会影响最终算法的整体效率,使用以下定义(第一张图片),在实现时使用dp思想,更精确来说是有限状态机思想,结合跳跃回退建成,时间复杂度为O(m),m时模式串pattern的长度。

pat串next的数组的定义与建立过程

next数组的定义:

对于上述的pat串,建立next数组,过程如下:

过程1:index等于1时,没有这样的k存在,next[0] = -1;

过程2:index等于2时,pat[0...0] = pat[1...1],因此k=0,即next[1] = 0;

过程3:index等于3时,pat[0...1] = pat[1...2],因此k=1,即next[3] = 1;

过程4:index等于4时,没有这样的k存在,next[3] = -1;

最终:

是不是看着挺简单的,留个难一点的例子pat = "ababbabbabbabab",留作练习(答案如下)。

那么如何通过代码来高效的构建next数组

1、暴力枚举法构建next数组

很明显,如果采用最简单的暴力枚举方法建立,时间复杂度是O(m^3),如下图所示 

现在假设j处于最后一个位置(黑色箭头所示),如果每一步都匹配,但是匹配过程是逐步增加的,红色和蓝色箭头所示,因此对于位置j,需要比较的次数最多是1+2+3+...+j-1=j*(j-1)/2,即O(j^2)。而我们的j位置(黑色箭头的取值范围0-m), 因此最坏情况大致是1^2+2^2+...+m^2 = m(m+1)(2m+1)/6,即O(m^3),显然通过暴力枚举方式是非常低效的。

2、利用状态转移机制高效建立next数组

正如我们前面所说,我们要利用前面匹配的串的状态记录来推导计算后面位置的值,状态求解从j=0开始,现在给出一般情况,要求的位置j状态处于较后的位置,如下图所示:

这就一种有限状态转移机制,求解当前状态,利用上一个状态,上一个状态能确定本次状态,完成状态计算,否则,利用上一个状态的前一个状态,看能否确定本次状态,直到没有前一个状态可是使用。

伪代码如下:

// buildNext 
 1 m=pat.size()
 2 let next[0...m] be a new array and be initialized to -1
 3 k=0
 4 for j=1 to m-1
 5    k=next[j-1]
 6    while k>=0 and pat[k+1]≠pat[j]
 7        k=next[k]
 8    if pat[k+1]==pat[j]
 9        next[j]=k+1
10 return next

注意:以上伪代码有个for循环和一个while循环,理论上时间复杂度是O(m^2),但不完全准确,大欧表示法是上界表示法,只能说正确,但不够精确。分析下,这里是回退,对于每一个回退,k的回退次数不会高于k的增长次数,k的最大增长次数不会大于m,因为这里是有限状态机制,本次k影响下一次的回退次数,这里说个极端最差情况,最后k为0,针对增加m次减小m次,一增一减,那么最终循环次数是2*m。在算法导论中使用摊还分析的聚合方法进行分析得出时间复杂度Θ(m)。

kmp算法search算法

当前访问的txt串和pat串对应位置txt[i]与pat[j]匹配,i和j均向前移1位,进行下一轮比较;

当前访问的txt串和pat串对应位置txt[i]与pat[j]不匹配且j大于0(即j可以回退),利用next数组确定j的回退位置,然后进行下一轮比较;

当上述两者均不满足,说明txt[i] != pat[0],直接将i向前移动1位,进行下一轮比较。

示意图如下:

伪代码如下:

// search
// txt pat
 1 n=txt.size, m=pat.txt, i=0, j=0
 2 while i<n and j<m
 3     if txt[i]==pat[j]
 4         i++, j++
 5     else if j>0
 6         j=next[j-1]+1
 7     else i++
 8 return j==m ? i-m : -1

时间复杂度的分析同样采用摊还分析的聚合方法,为Θ(m)。

案例模拟:

接下来使用上述kmp算法的search函数和next数组模拟下开头的例子

初始状态:

第1次比较完,未命中:(j从0开始比较到3,i从0开始比较到3)

利用next数组进行回退,下一个j = next[3 - 1] + 1 = 2

第2次比较完,未命中:(j从2开始比较到3,i从3开始比较到4)

利用next数组进行回退,下一个j = next[3 - 1] + 1 = 2

第3次比较完,未命中:(j从2开始比较到3,i从4开始比较到5)

利用next数组进行回退,下一个j = next[3 - 1] + 1 = 2

...

最后1次比较完,命中,返回9:(j从2开始比较到3,i从11开始比较到12)

完整的代码C++版本

#include <iostream>
#include <vector>
using namespace std;

// liyang  2019-12-12
// update: 2020-09-14

class KMP {
public:
    // 在文本字符串txt中寻找模式字符串pat
    // 找到:返回第一次匹配时pat在txt中的角标位置
    // 没找到:返回-1
    static int indexOf(const string txt, const string pat) {
        int n = txt.size(); int m = pat.size();
        if (m <= 0) {
            cout << "模式字符串pat输入不合法!" << endl;
            return -1;
        }
        if (n < m) {
            cout << "文本字符串txt的长度小于模式字符串的长度,不合法!" << endl;
            return -1;
        }
        vector<int> next = buildNext(pat, m);
        return search(txt, pat, next);
    }
private:
    // 创建pat字符串的next数组
    static vector<int> buildNext(const string& pat, const int& m) {
        vector<int> next(m, -1);
        for (int j = 1, k; j < m; j++) {
            k = next[j - 1];
            while (k >= 0 && pat[k + 1] != pat[j]) {
                k = next[k];
            }
            if (pat[k + 1] == pat[j]) {
                next[j] = k + 1;
            }
        }
        return next;
    }
    // 利用next数组进行真正的查找操作
    static int search(const string& txt, const string& pat, const vector<int>& next) {
        int n = txt.size(); int m = pat.size();
        int i = 0, j = 0;
        while (i < n && j < m) {
            if (txt[i] == pat[j]) { // 匹配
                i++; j++;
            }
            else if (j > 0) {       // 不匹配,j可以回退
                j = next[j - 1] + 1;
            }
            else {                  // 不匹配,j不可以回退(j==0),只能i前进
                i++;
            }
        }
        return j == m ? i - m : -1;
    }
};

int main(int argc, const char* argv[]) {
    string txt = "aaaaaaaaaaab";
    string pat = "aab";
    int index = KMP::indexOf(txt, pat);
    int n = txt.size();
    cout << index << ": " << (n - 3 == index ? "correct" : "wrong") << endl;
    system("pause");
    return 0;
}
9: correct
请按任意键继续. . .

完整版的Java代码

package com.ly.kmp;

/**
 * @author Young  2020-09-14
 * update:2020-09-15
 */

public class KMP {
    public static int indexOf(String txt, String pat) {
        int n = txt.length(); int m = pat.length();
        if (n < m || m == 0) {
            System.out.println("字符串输入有问题,请核对后重新输入!");
            return -1;
        }
        int[] next = buildNext(pat);
        return search(txt, pat, next);
    }

    private static int[] buildNext(String pat) {
        int m = pat.length();
        int[] next = new int[m];
        next[0] = -1;
        int k;
        for (int j = 1; j < m; j++) {
            k = next[j - 1];
            while (k >= 0 && pat.charAt(k + 1) != pat.charAt(j)) {
                k = next[k];
            }
            if (pat.charAt(k + 1) == pat.charAt(j)) {
                next[j] = k + 1;
            } else {
                next[j] = -1;
            }
        }
        return next;
    }

    private static int search(String txt, String pat, int[] next) {
        int n = txt.length(); int m = pat.length();
        int i = 0; int j = 0;
        while (i < n && j < m) {
            if (txt.charAt(i) == pat.charAt(j)) {
                i++; j++;
            } else if (j > 0) {
                j = next[j - 1] + 1;
            } else {
                i++;
            }
        }
        return j == m ? i - m : -1;
    }

    public static void main(String[] args) {
        String txt = "aaaaaaaaaaaab";
        String pat = "aaab";
        int index = KMP.indexOf(txt, pat);
        int n = txt.length();
        System.out.println(index + ": " + ((n - 4) == index ? "correct" : "wrong"));
    }
}
9: correct

Process finished with exit code 0

参考资料:

浙大数据结构

算法导论

算法4

猜你喜欢

转载自blog.csdn.net/weixin_41876385/article/details/103654084