// haar_lib/ds/qword_tree.rs

//! 64-ary tree (64分木): a set of integers in `0..=MAX` stored as four levels of 64-bit words.
/// Largest value a [`QwordTree`] can store: 2²⁴ − 1 = 16777215.
pub const MAX: u32 = 0x00FF_FFFF;
/// A set of integers in `0..=MAX` (0 to 16777215, i.e. 2²⁴ − 1), stored as a 64-ary tree of
/// bitsets.
///
/// Bit `j` of `v3[i]` records whether `64 * i + j` is in the set. Each higher level is a summary
/// of the one below it: a bit is set exactly when the corresponding word one level down is
/// nonzero. Because every field is determined by `v3`, the derived `PartialEq` is set equality.
#[derive(Clone, PartialEq, Eq)]
pub struct QwordTree {
    /// Root summary: bit `i` is set iff `v1[i] != 0`.
    v0: u64,
    /// 64 words; bit `j` of `v1[i]` is set iff `v2[64 * i + j] != 0`.
    v1: Vec<u64>,
    /// 4096 words; bit `j` of `v2[i]` is set iff `v3[64 * i + j] != 0`.
    v2: Vec<u64>,
    /// 262144 leaf words; bit `j` of `v3[i]` is set iff `64 * i + j` is in the set.
    v3: Vec<u64>,
    /// Number of elements currently in the set.
    count: usize,
}
15impl Default for QwordTree {
16    fn default() -> Self {
17        Self::new()
18    }
19}
20
21impl QwordTree {
22    /// 64分木を生成
23    pub fn new() -> Self {
24        Self {
25            v0: 0,
26            v1: vec![0; 1 << 6],
27            v2: vec![0; 1 << 12],
28            v3: vec![0; 1 << 18],
29            count: 0,
30        }
31    }
32
33    /// # Safety
34    ///
35    /// `x`は`MAX`以下でなければならない。
36    ///
37    /// `x`はQwordTreeに含まれていない。
38    pub unsafe fn insert_unchecked(&mut self, x: u32) {
39        self.count += 1;
40
41        let x = x as usize;
42
43        *self.v3.get_unchecked_mut(x >> 6) |= 1 << (x & 0x3f);
44        *self.v2.get_unchecked_mut(x >> 12) |= 1 << ((x >> 6) & 0x3f);
45        *self.v1.get_unchecked_mut(x >> 18) |= 1 << ((x >> 12) & 0x3f);
46        self.v0 |= 1 << (x >> 18);
47    }
48
49    /// xを集合に加える
50    pub fn insert(&mut self, x: u32) -> bool {
51        if x > MAX || self.v3[x as usize >> 6] & (1 << (x & 0x3f)) != 0 {
52            false
53        } else {
54            unsafe {
55                self.insert_unchecked(x);
56            }
57
58            true
59        }
60    }
61
62    /// # Safety
63    ///
64    /// `x`は`MAX`以下でなければならない。
65    ///
66    /// `x`はQwordTreeに含まれている。
67    pub unsafe fn erase_unchecked(&mut self, x: u32) {
68        self.count -= 1;
69
70        let x = x as usize;
71
72        *self.v3.get_unchecked_mut(x >> 6) &= !(1 << (x & 0x3f));
73        if *self.v3.get_unchecked(x >> 6) == 0 {
74            *self.v2.get_unchecked_mut(x >> 12) &= !(1 << ((x >> 6) & 0x3f));
75        }
76        if *self.v2.get_unchecked(x >> 12) == 0 {
77            *self.v1.get_unchecked_mut(x >> 18) &= !(1 << ((x >> 12) & 0x3f));
78        }
79        if *self.v1.get_unchecked(x >> 18) == 0 {
80            self.v0 &= !(1 << (x >> 18));
81        }
82    }
83
84    /// xを集合から削除する
85    pub fn erase(&mut self, x: u32) -> bool {
86        if x > MAX || self.v3[x as usize >> 6] & (1 << (x & 0x3f)) == 0 {
87            false
88        } else {
89            unsafe {
90                self.erase_unchecked(x);
91            }
92
93            true
94        }
95    }
96
97    /// xを含むかどうかを判定する
98    pub fn contains(&self, x: u32) -> bool {
99        if x > MAX {
100            false
101        } else {
102            unsafe { self.v3.get_unchecked(x as usize >> 6) & (1 << (x & 0x3f)) != 0 }
103        }
104    }
105
106    /// 集合が空かどうかを判断する
107    pub fn is_empty(&self) -> bool {
108        self.count == 0
109    }
110
111    /// 集合に含まれている要素数を返す
112    pub fn len(&self) -> usize {
113        self.count
114    }
115
116    /// 最小値を返す
117    pub fn min(&self) -> Option<u32> {
118        if self.v0 == 0 {
119            None
120        } else {
121            let mut ret = self.v0.trailing_zeros();
122            unsafe {
123                ret = (ret << 6) | self.v1.get_unchecked(ret as usize).trailing_zeros();
124                ret = (ret << 6) | self.v2.get_unchecked(ret as usize).trailing_zeros();
125                ret = (ret << 6) | self.v3.get_unchecked(ret as usize).trailing_zeros();
126            }
127            Some(ret)
128        }
129    }
130
131    /// 最大値を返す
132    pub fn max(&self) -> Option<u32> {
133        if self.v0 == 0 {
134            None
135        } else {
136            let mut ret = 63 - self.v0.leading_zeros();
137            unsafe {
138                ret = (ret << 6) | (63 - self.v1.get_unchecked(ret as usize).leading_zeros());
139                ret = (ret << 6) | (63 - self.v2.get_unchecked(ret as usize).leading_zeros());
140                ret = (ret << 6) | (63 - self.v3.get_unchecked(ret as usize).leading_zeros());
141            }
142            Some(ret)
143        }
144    }
145
146    /// x以上で最小の値を返す
147    pub fn min_ge(&self, mut x: u32) -> Option<u32> {
148        if x > MAX {
149            return None;
150        }
151
152        let mask = !((1 << (x & 0x3f)) - 1);
153        let t = unsafe { (self.v3.get_unchecked(x as usize >> 6) & mask).trailing_zeros() };
154        if t != 64 {
155            return Some((x & !0x3f) | t);
156        }
157
158        x >>= 6;
159        let mask = (!0 << (x & 0x3f)) << 1;
160        let t = unsafe { (self.v2.get_unchecked(x as usize >> 6) & mask).trailing_zeros() };
161        if t != 64 {
162            let mut ret = (x & !0x3f) | t;
163            unsafe {
164                ret = (ret << 6) | self.v3.get_unchecked(ret as usize).trailing_zeros();
165            }
166            return Some(ret);
167        }
168
169        x >>= 6;
170        let mask = (!0 << (x & 0x3f)) << 1;
171        let t = unsafe { (self.v1.get_unchecked(x as usize >> 6) & mask).trailing_zeros() };
172        if t != 64 {
173            let mut ret = (x & !0x3f) | t;
174            unsafe {
175                ret = (ret << 6) | self.v2.get_unchecked(ret as usize).trailing_zeros();
176                ret = (ret << 6) | self.v3.get_unchecked(ret as usize).trailing_zeros();
177            }
178            return Some(ret);
179        }
180
181        x >>= 6;
182        let mask = (!0 << (x & 0x3f)) << 1;
183        let t = (self.v0 & mask).trailing_zeros();
184        if t != 64 {
185            let mut ret = t;
186            unsafe {
187                ret = (ret << 6) | self.v1.get_unchecked(ret as usize).trailing_zeros();
188                ret = (ret << 6) | self.v2.get_unchecked(ret as usize).trailing_zeros();
189                ret = (ret << 6) | self.v3.get_unchecked(ret as usize).trailing_zeros();
190            }
191            return Some(ret);
192        }
193
194        None
195    }
196
197    /// x以下で最大の値を返す
198    pub fn max_le(&self, mut x: u32) -> Option<u32> {
199        if x > MAX {
200            return None;
201        }
202
203        let mask = !((!0 << (x & 0x3f)) << 1);
204        let t = unsafe { (self.v3.get_unchecked(x as usize >> 6) & mask).leading_zeros() };
205        if t != 64 {
206            return Some((x & !0x3f) | (63 - t));
207        }
208
209        x >>= 6;
210        let mask = (1 << (x & 0x3f)) - 1;
211        let t = unsafe { (self.v2.get_unchecked(x as usize >> 6) & mask).leading_zeros() };
212        if t != 64 {
213            let mut ret = (x & !0x3f) | (63 - t);
214            unsafe {
215                ret = (ret << 6) | (63 - self.v3.get_unchecked(ret as usize).leading_zeros());
216            }
217            return Some(ret);
218        }
219
220        x >>= 6;
221        let mask = (1 << (x & 0x3f)) - 1;
222        let t = unsafe { (self.v1.get_unchecked(x as usize >> 6) & mask).leading_zeros() };
223        if t != 64 {
224            let mut ret = (x & !0x3f) | (63 - t);
225            unsafe {
226                ret = (ret << 6) | (63 - self.v2.get_unchecked(ret as usize).leading_zeros());
227                ret = (ret << 6) | (63 - self.v3.get_unchecked(ret as usize).leading_zeros());
228            }
229            return Some(ret);
230        }
231
232        x >>= 6;
233        let mask = (1 << (x & 0x3f)) - 1;
234        let t = (self.v0 & mask).leading_zeros();
235        if t != 64 {
236            let mut ret = 63 - t;
237            unsafe {
238                ret = (ret << 6) | (63 - self.v1.get_unchecked(ret as usize).leading_zeros());
239                ret = (ret << 6) | (63 - self.v2.get_unchecked(ret as usize).leading_zeros());
240                ret = (ret << 6) | (63 - self.v3.get_unchecked(ret as usize).leading_zeros());
241            }
242            return Some(ret);
243        }
244
245        None
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;
    use std::collections::BTreeSet;

    /// Random insert/erase/contains over `0..4096`, cross-checked against `BTreeSet`.
    #[test]
    fn test() {
        let mut rng = rand::thread_rng();

        let mut set = BTreeSet::new();
        let mut qt = QwordTree::new();

        for _ in 0..5000 {
            let x: u32 = rng.gen_range(0..1 << 12);
            assert_eq!(set.insert(x), qt.insert(x));
            assert_eq!(set.len(), qt.len());
            assert_eq!(set.is_empty(), qt.is_empty());

            assert_eq!(set.iter().next(), qt.min().as_ref());
            assert_eq!(set.iter().next_back(), qt.max().as_ref());

            let x: u32 = rng.gen_range(0..1 << 12);
            assert_eq!(set.remove(&x), qt.erase(x));
            assert_eq!(set.len(), qt.len());
            assert_eq!(set.is_empty(), qt.is_empty());

            assert_eq!(set.iter().next(), qt.min().as_ref());
            assert_eq!(set.iter().next_back(), qt.max().as_ref());

            let x: u32 = rng.gen_range(0..1 << 12);
            assert_eq!(set.contains(&x), qt.contains(x));
        }
    }

    /// `min_ge` on a growing set of random values over the full range.
    #[test]
    fn test_min_ge() {
        let mut rng = rand::thread_rng();

        let mut set = BTreeSet::new();
        let mut qt = QwordTree::new();

        for _ in 0..1000 {
            let x: u32 = rng.gen_range(0..1 << 24);
            set.insert(x);
            qt.insert(x);

            let x: u32 = rng.gen_range(0..1 << 24);
            assert_eq!(set.range(x..).next(), qt.min_ge(x).as_ref());
        }
    }

    /// `max_le` on a growing set of random values over the full range.
    #[test]
    fn test_max_le() {
        let mut rng = rand::thread_rng();

        let mut set = BTreeSet::new();
        let mut qt = QwordTree::new();

        for _ in 0..1000 {
            let x: u32 = rng.gen_range(0..1 << 24);
            set.insert(x);
            qt.insert(x);

            let x: u32 = rng.gen_range(0..1 << 24);
            assert_eq!(set.range(..=x).next_back(), qt.max_le(x).as_ref());
        }
    }

    /// Interleaves insertions and deletions with `min_ge`/`max_le` queries.
    ///
    /// Values come from a small pool spread over the whole range (plus word/level boundaries),
    /// so the same values are hit repeatedly and erasures regularly empty whole words, which
    /// exercises the clearing of summary bits that the successor/predecessor searches rely on.
    #[test]
    fn test_min_ge_max_le_with_erase() {
        let mut rng = rand::thread_rng();

        let mut pool: Vec<u32> = (0..200).map(|_| rng.gen_range(0..=MAX)).collect();
        pool.extend_from_slice(&[0, 63, 64, 4095, 4096, (1 << 18) - 1, 1 << 18, MAX]);

        let mut set = BTreeSet::new();
        let mut qt = QwordTree::new();

        for _ in 0..5000 {
            let x = pool[rng.gen_range(0..pool.len())];
            if rng.gen_bool(0.5) {
                assert_eq!(set.insert(x), qt.insert(x));
            } else {
                assert_eq!(set.remove(&x), qt.erase(x));
            }
            assert_eq!(set.len(), qt.len());

            let q = pool[rng.gen_range(0..pool.len())];
            assert_eq!(set.range(q..).next(), qt.min_ge(q).as_ref());
            assert_eq!(set.range(..=q).next_back(), qt.max_le(q).as_ref());
        }
    }

    /// Values above `MAX` are rejected and never reported as present.
    #[test]
    fn test_out_of_range() {
        let mut qt = QwordTree::new();
        assert!(!qt.insert(MAX + 1));
        assert!(!qt.contains(MAX + 1));
        assert!(!qt.erase(MAX + 1));
        assert!(qt.is_empty());

        assert!(qt.insert(MAX));
        assert!(qt.contains(MAX));
        assert_eq!(qt.min_ge(MAX + 1), None);
        assert_eq!(qt.min_ge(MAX), Some(MAX));
        assert_eq!(qt.max_le(MAX), Some(MAX));
    }
}