haar_lib/algo/
kmp.rs

//! Knuth-Morris-Pratt algorithm
2
/// Knuth-Morris-Pratt pattern matcher.
///
/// The pattern is preprocessed once in [`KMP::new`]; afterwards [`KMP::matches`]
/// reports every (possibly overlapping) occurrence of the pattern in a text.
#[derive(Clone, Debug)]
pub struct KMP<T> {
    /// The pattern being searched for.
    pat: Vec<T>,
    /// `border[k]` is the length of the longest proper prefix of `pat[..k]`
    /// that is also a suffix of `pat[..k]`. `border[0]` and `border[1]` are 0.
    border: Vec<usize>,
}

impl<T: PartialEq> KMP<T> {
    /// Builds a matcher for `pat`.
    ///
    /// **Time complexity O(|pat|)**
    pub fn new(pat: Vec<T>) -> Self {
        let m = pat.len();
        let mut border = vec![0; m + 1];

        // `k` is the border length of the prefix ending just before `pat[i]`.
        let mut k = 0;
        for i in 1..m {
            // Fall back to shorter borders until one can be extended by `pat[i]`.
            while k > 0 && pat[i] != pat[k] {
                k = border[k];
            }
            if pat[i] == pat[k] {
                k += 1;
            }
            border[i + 1] = k;
        }

        Self { pat, border }
    }

    /// Returns the start index of every occurrence of the pattern in `s`,
    /// in increasing order. Overlapping occurrences are all reported.
    ///
    /// An empty pattern occurs at every offset, so `0..=s.len()` is returned
    /// in that case.
    ///
    /// **Time complexity O(|s|)**
    pub fn matches(&self, s: &[T]) -> Vec<usize> {
        let m = self.pat.len();
        if m == 0 {
            return (0..=s.len()).collect();
        }

        let mut ret = Vec::new();
        // Number of leading pattern elements matched against the text suffix
        // ending at the current position.
        let mut k = 0;
        for (pos, c) in s.iter().enumerate() {
            while k > 0 && self.pat[k] != *c {
                k = self.border[k];
            }
            if self.pat[k] == *c {
                k += 1;
            }
            if k == m {
                ret.push(pos + 1 - m);
                // Continue from the longest border so overlapping matches are found;
                // `border[m] < m`, which keeps `pat[k]` in bounds on the next step.
                k = self.border[m];
            }
        }

        ret
    }
}
66
#[cfg(test)]
mod tests {
    use super::*;
    use test_case::test_case;

    /// Checks the KMP matcher against the quadratic reference implementation.
    #[test_case("aaa", "aaaaaaaa")]
    #[test_case("ab", "abracadabra")]
    #[test_case("abab", "abababab")]
    #[test_case("abcd", "abc")]
    fn test(pat: &str, s: &str) {
        let kmp = KMP::new(pat.as_bytes().to_owned());
        let indices = kmp.matches(s.as_bytes());

        assert_eq!(indices, bruteforce(pat, s));
    }

    /// Reference implementation: tries every start offset and compares bytes.
    ///
    /// Returns an empty list when `pat` is longer than `s` instead of
    /// underflowing, and compares byte slices so that offsets inside
    /// multi-byte characters cannot cause a slicing panic.
    fn bruteforce(pat: &str, s: &str) -> Vec<usize> {
        let pat = pat.as_bytes();
        let s = s.as_bytes();

        match s.len().checked_sub(pat.len()) {
            Some(last) => (0..=last)
                .filter(|&i| &s[i..i + pat.len()] == pat)
                .collect(),
            None => Vec::new(),
        }
    }
}