haar_lib/ds/
binary_trie.rs1use crate::misc::nullable_usize::NullableUsize;
3
4#[derive(Debug, Clone)]
5struct Node {
6 ch: [NullableUsize; 2],
7 count: usize,
8}
9
10impl Default for Node {
11 fn default() -> Self {
12 Self {
13 ch: [NullableUsize::NULL, NullableUsize::NULL],
14 count: 0,
15 }
16 }
17}
18
19#[derive(Debug, Clone)]
21pub struct BinaryTrie {
22 data: Vec<Node>,
23 bitlen: usize,
24}
25
26impl BinaryTrie {
27 pub fn new(bitlen: usize) -> Self {
29 assert!(bitlen <= 64);
30 let data = vec![Node::default()];
31 Self { data, bitlen }
32 }
33
34 pub fn len(&self) -> usize {
36 self.data[0].count
37 }
38
39 pub fn is_empty(&self) -> bool {
41 self.data[0].count == 0
42 }
43
44 pub fn count(&self, value: u64) -> usize {
46 let mut node = 0;
47 let mut depth = self.bitlen;
48
49 while depth > 0 {
50 depth -= 1;
51 let b = (value >> depth) & 1;
52
53 let t = self.data[node].ch[b as usize];
54 if t.is_null() {
55 return 0;
56 }
57 node = t.0;
58 }
59
60 self.data[node].count
61 }
62
63 pub fn insert(&mut self, value: u64) -> usize {
65 let mut node = 0;
66 let mut depth = self.bitlen;
67
68 while depth > 0 {
69 self.data[node].count += 1;
70 depth -= 1;
71
72 let b = (value >> depth) & 1;
73
74 let ch = self.data[node].ch[b as usize];
75 if !ch.is_null() {
76 node = ch.0;
77 } else {
78 self.data.push(Node::default());
79 let ch = self.data.len() - 1;
80 self.data[node].ch[b as usize] = NullableUsize(ch);
81 node = ch;
82 }
83 }
84
85 self.data[node].count += 1;
86 self.data[node].count
87 }
88
89 pub fn erase(&mut self, value: u64) -> Option<usize> {
92 let mut node = 0;
93 let mut depth = self.bitlen;
94 let mut path = vec![];
95
96 while depth > 0 {
97 depth -= 1;
98 let b = (value >> depth) & 1;
99
100 path.push(node);
101
102 let ch = self.data[node].ch[b as usize];
103 if !ch.is_null() {
104 node = ch.0;
105 } else {
106 self.data.push(Node::default());
107 let ch = self.data.len() - 1;
108 self.data[node].ch[b as usize] = NullableUsize(ch);
109 node = ch;
110 }
111 }
112
113 (self.data[node].count > 0).then(|| {
114 path.push(node);
115 for a in path {
116 self.data[a].count -= 1;
117 }
118 self.data[node].count
119 })
120 }
121
122 pub fn min(&mut self, xor: u64) -> Option<u64> {
124 if self.data[0].count == 0 {
125 None
126 } else {
127 let mut node = 0;
128 let mut depth = self.bitlen;
129 let mut ret = 0;
130
131 while depth > 0 {
132 depth -= 1;
133
134 let mut b = (xor >> depth) & 1;
135
136 let t = self.data[node].ch[b as usize];
137 if t.is_null() || self.data[t.0].count == 0 {
138 b ^= 1;
139 }
140
141 node = self.data[node].ch[b as usize].0;
142 ret |= b << depth;
143 }
144
145 Some(ret)
146 }
147 }
148
149 pub fn max(&mut self, xor: u64) -> Option<u64> {
151 if self.data[0].count == 0 {
152 None
153 } else {
154 let mut node = 0;
155 let mut depth = self.bitlen;
156 let mut ret = 0;
157
158 while depth > 0 {
159 depth -= 1;
160
161 let mut b = ((xor >> depth) & 1) ^ 1;
162
163 let t = self.data[node].ch[b as usize];
164 if t.is_null() || self.data[t.0].count == 0 {
165 b ^= 1;
166 }
167
168 node = self.data[node].ch[b as usize].0;
169 ret |= b << depth;
170 }
171
172 Some(ret)
173 }
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;
    use std::collections::BTreeMap;

    /// Randomized differential test against a `BTreeMap<u64, usize>` multiset model.
    #[test]
    fn test() {
        let mut rng = rand::thread_rng();

        let mut bt = BinaryTrie::new(64);
        let mut m = BTreeMap::new();

        for _ in 0..1000 {
            let x = rng.gen_range(0..100);

            bt.insert(x);
            *m.entry(x).or_insert(0) += 1;

            let y = rng.gen::<u64>();

            assert_eq!(bt.min(y), m.keys().copied().min_by_key(|&a| a ^ y));
            assert_eq!(bt.max(y), m.keys().copied().max_by_key(|&a| a ^ y));

            assert_eq!(
                (0..100).map(|i| bt.count(i)).collect::<Vec<_>>(),
                (0..100)
                    .map(|i| m.get(&i).copied().unwrap_or(0))
                    .collect::<Vec<_>>()
            );

            // `len` counts multiplicity, so it must match the model's total.
            assert_eq!(bt.len(), m.values().sum::<usize>());
            assert_eq!(bt.is_empty(), m.is_empty());

            let x = rng.gen_range(0..100);

            assert_eq!(bt.erase(x).unwrap_or(0), bt.count(x));
            // Model entries are removed once they reach zero, so any present
            // entry is at least 1.
            if let Some(c) = m.get_mut(&x) {
                *c -= 1;
                if *c == 0 {
                    m.remove(&x);
                }
            }
        }
    }

    /// Deterministic checks for the empty trie, a narrow bit width, and erasing
    /// a value down to zero copies.
    #[test]
    fn empty_and_small_width() {
        let mut bt = BinaryTrie::new(3);
        assert!(bt.is_empty());
        assert_eq!(bt.len(), 0);
        assert_eq!(bt.min(0), None);
        assert_eq!(bt.max(0), None);
        assert_eq!(bt.erase(5), None);
        assert_eq!(bt.count(5), 0);

        assert_eq!(bt.insert(5), 1);
        assert_eq!(bt.insert(5), 2);
        assert_eq!(bt.insert(2), 1);
        assert_eq!(bt.len(), 3);

        // 5 = 0b101, 2 = 0b010: xor 0 -> min 2, max 5; xor 7 -> min 5, max 2.
        assert_eq!(bt.min(0), Some(2));
        assert_eq!(bt.max(0), Some(5));
        assert_eq!(bt.min(7), Some(5));
        assert_eq!(bt.max(7), Some(2));

        assert_eq!(bt.erase(5), Some(1));
        assert_eq!(bt.erase(5), Some(0));
        assert_eq!(bt.erase(5), None);
        assert_eq!(bt.min(0), Some(2));
        assert_eq!(bt.max(0), Some(2));
        assert_eq!(bt.len(), 1);
        assert!(!bt.is_empty());
    }

    /// Values spanning the full 64-bit width.
    #[test]
    fn full_width_extremes() {
        let mut bt = BinaryTrie::new(64);
        bt.insert(0);
        bt.insert(u64::MAX);
        assert_eq!(bt.max(0), Some(u64::MAX));
        assert_eq!(bt.min(0), Some(0));
        assert_eq!(bt.min(u64::MAX), Some(u64::MAX));
        assert_eq!(bt.max(u64::MAX), Some(0));
    }

    #[test]
    #[should_panic]
    fn new_rejects_bitlen_over_64() {
        let _ = BinaryTrie::new(65);
    }
}