haar_lib/ds/
binary_trie.rs

1//! 非負整数を2進数として管理する。
2use crate::misc::nullable_usize::NullableUsize;
3
4#[derive(Debug, Clone)]
5struct Node {
6    ch: [NullableUsize; 2],
7    count: usize,
8}
9
10impl Default for Node {
11    fn default() -> Self {
12        Self {
13            ch: [NullableUsize::NULL, NullableUsize::NULL],
14            count: 0,
15        }
16    }
17}
18
19/// 非負整数を2進数として管理する。
20#[derive(Debug, Clone)]
21pub struct BinaryTrie {
22    data: Vec<Node>,
23    bitlen: usize,
24}
25
26impl BinaryTrie {
27    /// `bitlen`ビットの数を扱える[`BinaryTrie`]を生成する。
28    pub fn new(bitlen: usize) -> Self {
29        assert!(bitlen <= 64);
30        let data = vec![Node::default()];
31        Self { data, bitlen }
32    }
33
34    /// 要素数を返す。
35    pub fn len(&self) -> usize {
36        self.data[0].count
37    }
38
39    /// 要素数が0ならば`true`を返す。
40    pub fn is_empty(&self) -> bool {
41        self.data[0].count == 0
42    }
43
44    /// 値`value`の個数を返す。
45    pub fn count(&self, value: u64) -> usize {
46        let mut node = 0;
47        let mut depth = self.bitlen;
48
49        while depth > 0 {
50            depth -= 1;
51            let b = (value >> depth) & 1;
52
53            let t = self.data[node].ch[b as usize];
54            if t.is_null() {
55                return 0;
56            }
57            node = t.0;
58        }
59
60        self.data[node].count
61    }
62
63    /// 値`value`を挿入して、`value`の個数を返す。
64    pub fn insert(&mut self, value: u64) -> usize {
65        let mut node = 0;
66        let mut depth = self.bitlen;
67
68        while depth > 0 {
69            self.data[node].count += 1;
70            depth -= 1;
71
72            let b = (value >> depth) & 1;
73
74            let ch = self.data[node].ch[b as usize];
75            if !ch.is_null() {
76                node = ch.0;
77            } else {
78                self.data.push(Node::default());
79                let ch = self.data.len() - 1;
80                self.data[node].ch[b as usize] = NullableUsize(ch);
81                node = ch;
82            }
83        }
84
85        self.data[node].count += 1;
86        self.data[node].count
87    }
88
89    /// 値`value`が存在すれば、一つ削除して、削除後の`value`の個数を`Some`に包んで返す。
90    /// 存在しなければ、`None`を返す。
91    pub fn erase(&mut self, value: u64) -> Option<usize> {
92        let mut node = 0;
93        let mut depth = self.bitlen;
94        let mut path = vec![];
95
96        while depth > 0 {
97            depth -= 1;
98            let b = (value >> depth) & 1;
99
100            path.push(node);
101
102            let ch = self.data[node].ch[b as usize];
103            if !ch.is_null() {
104                node = ch.0;
105            } else {
106                self.data.push(Node::default());
107                let ch = self.data.len() - 1;
108                self.data[node].ch[b as usize] = NullableUsize(ch);
109                node = ch;
110            }
111        }
112
113        (self.data[node].count > 0).then(|| {
114            path.push(node);
115            for a in path {
116                self.data[a].count -= 1;
117            }
118            self.data[node].count
119        })
120    }
121
122    /// $\min_{a \in S} a \oplus xor$を求める。
123    pub fn min(&mut self, xor: u64) -> Option<u64> {
124        if self.data[0].count == 0 {
125            None
126        } else {
127            let mut node = 0;
128            let mut depth = self.bitlen;
129            let mut ret = 0;
130
131            while depth > 0 {
132                depth -= 1;
133
134                let mut b = (xor >> depth) & 1;
135
136                let t = self.data[node].ch[b as usize];
137                if t.is_null() || self.data[t.0].count == 0 {
138                    b ^= 1;
139                }
140
141                node = self.data[node].ch[b as usize].0;
142                ret |= b << depth;
143            }
144
145            Some(ret)
146        }
147    }
148
149    /// $\max_{a \in S} a \oplus xor$を求める。
150    pub fn max(&mut self, xor: u64) -> Option<u64> {
151        if self.data[0].count == 0 {
152            None
153        } else {
154            let mut node = 0;
155            let mut depth = self.bitlen;
156            let mut ret = 0;
157
158            while depth > 0 {
159                depth -= 1;
160
161                let mut b = ((xor >> depth) & 1) ^ 1;
162
163                let t = self.data[node].ch[b as usize];
164                if t.is_null() || self.data[t.0].count == 0 {
165                    b ^= 1;
166                }
167
168                node = self.data[node].ch[b as usize].0;
169                ret |= b << depth;
170            }
171
172            Some(ret)
173        }
174    }
175}
176
177#[cfg(test)]
178mod tests {
179    use super::*;
180    use rand::Rng;
181    use std::collections::BTreeMap;
182
183    #[test]
184    fn test() {
185        let mut rng = rand::thread_rng();
186
187        let mut bt = BinaryTrie::new(64);
188        let mut m = BTreeMap::new();
189
190        for _ in 0..1000 {
191            let x = rng.gen_range(0..100);
192
193            bt.insert(x);
194            *m.entry(x).or_insert(0) += 1;
195
196            let y = rng.gen::<u64>();
197
198            assert_eq!(
199                bt.min(y),
200                m.iter().map(|(&a, _)| a).min_by_key(|&a| (a ^ y))
201            );
202
203            assert_eq!(
204                bt.max(y),
205                m.iter().map(|(&a, _)| a).max_by_key(|&a| (a ^ y))
206            );
207
208            assert_eq!(
209                (0..100).map(|i| bt.count(i)).collect::<Vec<_>>(),
210                (0..100)
211                    .map(|i| *m.get(&i).unwrap_or(&0))
212                    .collect::<Vec<_>>()
213            );
214
215            let x = rng.gen_range(0..100);
216
217            assert_eq!(bt.erase(x).unwrap_or(0), bt.count(x));
218            match m.get_mut(&x) {
219                Some(y) if *y >= 1 => {
220                    *y -= 1;
221                    if *y == 0 {
222                        m.remove(&x);
223                    }
224                }
225                _ => {}
226            }
227        }
228    }
229}