haar_lib/ds/
aho_corasick.rs

1//! Aho-Corasick法
2//!
3//! # Problems
4//! - <https://yukicoder.me/problems/no/430>
5//! - <https://atcoder.jp/contests/abc362/tasks/abc362_g>
6//! - <https://atcoder.jp/contests/abc268/tasks/abc268_h>
7//! - <https://judge.yosupo.jp/problem/aho_corasick>
8use std::{
9    collections::{HashMap, VecDeque},
10    hash::Hash,
11};
12
13/// [`AhoCorasick`]のノード
14pub struct Node<K> {
15    index: usize,
16    parent: Option<*mut Self>,
17    children: HashMap<K, *mut Self>,
18    failure_link: Option<*mut Self>,
19    rev_failure_links: Vec<*mut Self>,
20}
21
22impl<K: Copy + Hash + Eq> Node<K> {
23    fn new(index: usize) -> Self {
24        Self {
25            index,
26            parent: None,
27            children: HashMap::new(),
28            failure_link: None,
29            rev_failure_links: vec![],
30        }
31    }
32
33    /// ノード毎に割り当てられた番号を返す。
34    pub fn index(&self) -> usize {
35        self.index
36    }
37
38    /// 文字`c`で遷移する子ノードへの参照を返す。
39    pub fn child(&self, c: K) -> Option<&Self> {
40        self.children.get(&c).map(|&p| {
41            assert!(!p.is_null());
42            unsafe { &*p }
43        })
44    }
45
46    /// すべての子ノードへの遷移文字と参照へのイテレータを返す。
47    pub fn children(&self) -> impl Iterator<Item = (&K, &Self)> {
48        self.children.iter().map(|(k, &v)| (k, unsafe { &*v }))
49    }
50
51    /// 親ノードへの参照を返す。
52    pub fn parent(&self) -> Option<&Self> {
53        self.parent.map(|p| unsafe { &*p })
54    }
55
56    /// 子ノードへ遷移できないときに辿るべきノードへの参照を返す。
57    pub fn failure_link(&self) -> Option<&Self> {
58        self.failure_link.map(|p| {
59            assert!(!p.is_null());
60            unsafe { &*p }
61        })
62    }
63
64    /// failure_linkを逆に辿ったノードへの参照へのイテレータを返す。
65    pub fn rev_failure_links(&self) -> impl Iterator<Item = &Self> {
66        self.rev_failure_links.iter().map(|&p| {
67            assert!(!p.is_null());
68            unsafe { &*p }
69        })
70    }
71}
72
73fn index_of<K>(p: *mut Node<K>) -> usize {
74    assert!(!p.is_null());
75    unsafe { (*p).index }
76}
77
78fn child_of<K: Hash + Eq>(p: *mut Node<K>, c: K) -> Option<*mut Node<K>> {
79    assert!(!p.is_null());
80    unsafe { (*p).children.get(&c).copied() }
81}
82
83fn failure_link_of<K>(p: *mut Node<K>) -> Option<*mut Node<K>> {
84    assert!(!p.is_null());
85    unsafe { (*p).failure_link }
86}
87
88fn set_failure_link<K>(from: *mut Node<K>, to: *mut Node<K>) {
89    assert!(!from.is_null());
90    unsafe {
91        (*from).failure_link = Some(to);
92        (*to).rev_failure_links.push(from);
93    }
94}
95
96/// [`AhoCorasick`]を構築するための構造体。
97pub struct AhoCorasickBuilder<K> {
98    size: usize,
99    root: *mut Node<K>,
100    dict: Vec<Vec<K>>,
101    dict_index: Vec<Vec<usize>>,
102    nodes: Vec<*mut Node<K>>,
103}
104
105#[allow(clippy::new_without_default)]
106impl<K: Copy + Hash + Eq> AhoCorasickBuilder<K> {
107    /// [`AhoCorasickBuilder`]を生成する。
108    pub fn new() -> Self {
109        let root = Box::new(Node::new(0));
110        let root = Box::into_raw(root);
111
112        Self {
113            size: 1,
114            root,
115            dict: vec![],
116            dict_index: vec![],
117            nodes: vec![],
118        }
119    }
120
121    /// パターン`pat`を追加する。
122    pub fn add<I>(&mut self, pat: I)
123    where
124        I: IntoIterator<Item = K>,
125    {
126        let pat = pat.into_iter().collect::<Vec<_>>();
127        self.dict.push(pat.clone());
128
129        let mut cur = self.root;
130
131        for c in pat {
132            assert!(!cur.is_null());
133            if let Some(next) = child_of(cur, c) {
134                cur = next;
135            } else {
136                let new = Box::new(Node::new(self.size));
137                let new = Box::into_raw(new);
138
139                assert!(!cur.is_null());
140                unsafe { (*cur).children.insert(c, new) };
141                unsafe { (*new).parent = Some(cur) };
142
143                cur = new;
144                self.size += 1;
145            }
146        }
147
148        self.nodes.push(cur);
149
150        self.dict_index.resize(self.size, vec![]);
151        self.dict_index[index_of(cur)].push(self.dict.len() - 1);
152    }
153
154    /// [`AhoCorasick`]を構築する。
155    pub fn build(self) -> AhoCorasick<K> {
156        let mut dq = VecDeque::new();
157        dq.push_back(self.root);
158
159        while let Some(cur) = dq.pop_front() {
160            assert!(!cur.is_null());
161            for (&c, &next) in unsafe { (*cur).children.iter() } {
162                if cur == self.root {
163                    set_failure_link(next, cur);
164                } else {
165                    let mut i = failure_link_of(cur).unwrap();
166                    let mut j = self.root;
167
168                    loop {
169                        if let Some(t) = child_of(i, c) {
170                            j = t;
171                            break;
172                        }
173                        let Some(t) = failure_link_of(i) else {
174                            break;
175                        };
176                        i = t;
177                    }
178
179                    set_failure_link(next, j);
180                }
181
182                dq.push_back(next);
183            }
184        }
185
186        AhoCorasick {
187            size: self.size,
188            root: self.root,
189            dict: self.dict,
190            dict_index: self.dict_index,
191            nodes: self.nodes,
192        }
193    }
194}
195
196/// Aho-Corasick法
197pub struct AhoCorasick<K> {
198    size: usize,
199    root: *mut Node<K>,
200    dict: Vec<Vec<K>>,
201    dict_index: Vec<Vec<usize>>,
202    nodes: Vec<*mut Node<K>>,
203}
204
205#[allow(clippy::len_without_is_empty)]
206impl<K: Copy + Hash + Eq> AhoCorasick<K> {
207    /// ノード数を返す。
208    pub fn len(&self) -> usize {
209        self.size
210    }
211
212    /// Trie木の根ノードへの参照を返す。
213    pub fn root_node(&self) -> &Node<K> {
214        unsafe { &*self.root }
215    }
216
217    /// `index`番目に追加したパターンに対応するノードへの参照を返す。
218    pub fn node_of(&self, index: usize) -> &Node<K> {
219        assert!(!self.nodes[index].is_null());
220        unsafe { &*self.nodes[index] }
221    }
222
223    /// 文字列`s`がマッチするすべてのパターンを列挙する。
224    pub fn matches<I, F>(&self, s: I, mut proc: F)
225    where
226        I: IntoIterator<Item = K>,
227        F: FnMut(usize, std::ops::Range<usize>),
228    {
229        let mut cur = self.root;
230
231        for (i, c) in s.into_iter().enumerate() {
232            while cur != self.root && unsafe { !(*cur).children.contains_key(&c) } {
233                cur = failure_link_of(cur).unwrap();
234            }
235
236            cur = child_of(cur, c).unwrap_or(self.root);
237
238            let mut p = cur;
239
240            loop {
241                for &j in &self.dict_index[index_of(p)] {
242                    let len = self.dict[j].len();
243                    proc(j, i + 1 - len..i + 1);
244                }
245
246                let Some(q) = failure_link_of(p) else { break };
247                p = q;
248            }
249        }
250    }
251}
252
253#[cfg(test)]
254mod tests {
255    use std::ops::Range;
256
257    use super::*;
258
259    #[test]
260    fn test() {
261        let mut builder = AhoCorasickBuilder::new();
262
263        builder.add("ur".chars());
264        builder.add("et".chars());
265        builder.add("ur".chars());
266
267        let ac = builder.build();
268
269        let s = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.";
270        ac.matches(s.chars(), |index, range| {
271            let Range { start, end } = range;
272            println!(
273                "{} {}\x1b[m\x1b[1m\x1b[32m{}\x1b[m{}",
274                index,
275                s.get(start.saturating_sub(3)..start).unwrap(),
276                s.get(start..end).unwrap(),
277                s.get(end..end.saturating_add(3)).unwrap()
278            );
279        })
280    }
281}