haar_lib/ds/
trie.rs

1//! Trie木
2//!
3//! # Problems
4//! - <https://atcoder.jp/contests/abc353/tasks/abc353_e>
5//! - <https://atcoder.jp/contests/abc377/tasks/abc377_g>
6use std::collections::HashMap;
7use std::hash::Hash;
8
9/// Trie木のノード
10#[derive(Clone, Debug)]
11pub struct TrieNode<T, K> {
12    /// ノードに格納している値
13    pub value: T,
14    children: HashMap<K, *mut Self>,
15}
16
17impl<T, K> TrieNode<T, K> {
18    fn new(value: T) -> Self {
19        Self {
20            value,
21            children: HashMap::default(),
22        }
23    }
24}
25
26impl<T, K: Copy + Hash + Eq> TrieNode<T, K> {
27    /// 子ノードへのキーと子ノードへの参照をもつイテレータを返す。
28    pub fn children_nodes(&self) -> impl Iterator<Item = (K, &Self)> {
29        self.children
30            .clone()
31            .into_iter()
32            .map(|(k, v)| (k, unsafe { &*v }))
33    }
34
35    /// 子ノードへのキーと子ノードへの可変参照をもつイテレータを返す。
36    pub fn children_nodes_mut(&mut self) -> impl Iterator<Item = (K, &mut Self)> {
37        self.children
38            .clone()
39            .into_iter()
40            .map(|(k, v)| (k, unsafe { &mut *v }))
41    }
42
43    fn add<I, FI, F1, F2>(
44        &mut self,
45        iter: &mut I,
46        mut init: FI,
47        mut proc: F1,
48        rproc: &mut F2,
49        prefix: &mut Vec<K>,
50    ) where
51        I: Iterator<Item = K>,
52        FI: FnMut(&Vec<K>) -> T,
53        F1: FnMut(&mut T, &Vec<K>),
54        F2: FnMut(&mut T, &Vec<K>),
55    {
56        proc(&mut self.value, prefix);
57        if let Some(c) = iter.next() {
58            prefix.push(c);
59
60            let next = if let Some(&next) = self.children.get(&c) {
61                next
62            } else {
63                let value = init(prefix);
64                let next = Box::new(Self::new(value));
65                let next = Box::into_raw(next);
66                self.children.insert(c, next);
67                next
68            };
69
70            assert!(!next.is_null());
71            unsafe { &mut *next }.add(iter, init, proc, rproc, prefix);
72            prefix.pop();
73        }
74        rproc(&mut self.value, prefix);
75    }
76}
77
78/// Trie木
79pub struct Trie<T, K> {
80    root: *mut TrieNode<T, K>,
81}
82
83impl<T, K: Copy + Hash + Eq> Trie<T, K> {
84    /// 値`value`を保持するルートのみをもつ[`Trie`]を構築する。
85    pub fn new(value: T) -> Self {
86        let root = Box::new(TrieNode::new(value));
87        let root = Box::into_raw(root);
88        Self { root }
89    }
90
91    /// Trie木の根ノードへの参照を返す。
92    pub fn root_node(&self) -> &TrieNode<T, K> {
93        assert!(!self.root.is_null());
94        unsafe { &*self.root }
95    }
96
97    /// Trie木の根ノードへの可変参照を返す。
98    pub fn root_node_mut(&mut self) -> &mut TrieNode<T, K> {
99        assert!(!self.root.is_null());
100        unsafe { &mut *self.root }
101    }
102
103    /// 列`s`をTrie木に追加する。
104    ///
105    /// - `init`: ノードが新しく追加されるときの初期値を決定する。
106    /// - `proc`: 行きがけ順に、ノードの値に処理をする。
107    /// - `rproc`: 帰りがけ順に、ノードの値に処理をする。
108    pub fn add<I, FI, F1, F2>(&mut self, s: I, init: FI, proc: F1, mut rproc: F2)
109    where
110        I: IntoIterator<Item = K>,
111        FI: FnMut(&Vec<K>) -> T,
112        F1: FnMut(&mut T, &Vec<K>),
113        F2: FnMut(&mut T, &Vec<K>),
114    {
115        let mut s = s.into_iter();
116        assert!(!self.root.is_null());
117        unsafe { &mut *self.root }.add(&mut s, init, proc, &mut rproc, &mut vec![]);
118    }
119}
120
121#[cfg(test)]
122mod tests {
123    use super::*;
124    use std::fmt::{Debug, Display};
125
126    fn dfs<T, K>(node: &mut TrieNode<T, K>, prefix: &mut Vec<K>)
127    where
128        T: Default + Display,
129        K: Copy + Hash + Eq + Debug,
130    {
131        let depth = prefix.len();
132        let key = prefix.last();
133        println!("{:->depth$} {:?} {}", "", key, node.value);
134        for (key, ch) in node.children_nodes_mut() {
135            prefix.push(key);
136            dfs(ch, prefix);
137            prefix.pop();
138        }
139    }
140
141    #[test]
142    fn test() {
143        let mut trie = Trie::<u32, char>::new(0);
144
145        let init = |prefix: &Vec<char>| -> u32 { prefix.len() as u32 };
146
147        let proc = |value: &mut u32, prefix: &Vec<char>| {
148            println!("{:?}", prefix);
149            *value += 1;
150        };
151
152        let rproc = |_: &mut u32, prefix: &Vec<char>| {
153            println!("{:?}", prefix);
154        };
155
156        trie.add("abc".chars(), init, proc, rproc);
157        trie.add("abra".chars(), init, proc, rproc);
158        trie.add("baa".chars(), init, proc, rproc);
159
160        let root = trie.root_node_mut();
161        dfs(root, &mut vec![]);
162    }
163}