// haar_lib/ds/wavelet_matrix.rs

//! Wavelet matrix
2use crate::{ds::succinct_bitvec::SuccinctBitVec, misc::range::range_bounds_to_range};
3use std::{
4    marker::PhantomData,
5    ops::{BitAnd, BitOrAssign, RangeBounds, Shl, Shr},
6};
7
8/// Wavelet matrix
9#[derive(Clone)]
10pub struct WaveletMatrix<T, const BIT_SIZE: usize> {
11    size: usize,
12    sdict: Vec<SuccinctBitVec>,
13    zero_pos: Vec<usize>,
14    _phantom: PhantomData<T>,
15}
16
17impl<T, const BIT_SIZE: usize> WaveletMatrix<T, BIT_SIZE>
18where
19    T: Shr<usize, Output = T>
20        + Shl<usize, Output = T>
21        + BitAnd<Output = T>
22        + BitOrAssign
23        + From<u8>
24        + Eq
25        + Ord
26        + Copy,
27{
28    /// `T`の列から[`WaveletMatrix`]を作る。
29    pub fn new(mut data: Vec<T>) -> Self {
30        let size = data.len();
31
32        let mut sdict = vec![];
33        let mut zero_pos = vec![];
34
35        for k in 0..BIT_SIZE {
36            let mut left = vec![];
37            let mut right = vec![];
38            let mut s = vec![false; size];
39
40            for i in 0..size {
41                s[i] = (data[i] >> (BIT_SIZE - 1 - k)) & T::from(1) == T::from(1);
42                if s[i] {
43                    right.push(data[i]);
44                } else {
45                    left.push(data[i]);
46                }
47            }
48
49            sdict.push(SuccinctBitVec::new(s));
50            zero_pos.push(left.len());
51
52            data = left;
53            data.extend(right);
54        }
55
56        Self {
57            size,
58            sdict,
59            zero_pos,
60            _phantom: PhantomData,
61        }
62    }
63
64    /// `index`番目の値を得る。
65    pub fn access(&self, index: usize) -> T {
66        let mut ret = T::from(0);
67
68        let mut p = index;
69        for i in 0..BIT_SIZE {
70            let t = self.sdict[i].access(p);
71
72            ret |= T::from(t as u8) << (BIT_SIZE - 1 - i);
73            p = self.sdict[i].rank(p, t == 1) + t as usize * self.zero_pos[i];
74        }
75
76        ret
77    }
78
79    fn rank_(&self, index: usize, value: T) -> (usize, usize) {
80        let mut l = 0;
81        let mut r = index;
82
83        for i in 0..BIT_SIZE {
84            let t = (value >> (BIT_SIZE - 1 - i)) & T::from(1);
85
86            if t == T::from(1) {
87                l = self.sdict[i].rank(l, true) + self.zero_pos[i];
88                r = self.sdict[i].rank(r, true) + self.zero_pos[i];
89            } else {
90                l = self.sdict[i].rank(l, false);
91                r = self.sdict[i].rank(r, false);
92            }
93        }
94
95        (l, r)
96    }
97
98    /// [0, index)に含まれる`value`の個数。
99    pub fn rank(&self, index: usize, value: T) -> usize {
100        let (l, r) = self.rank_(index, value);
101        r - l
102    }
103
104    /// `range`に含まれる`value`の個数。
105    pub fn count(&self, range: impl RangeBounds<usize>, value: T) -> usize {
106        let (l, r) = range_bounds_to_range(range, 0, self.size);
107        self.rank(r, value) - self.rank(l, value)
108    }
109
110    /// `nth`(0-indexed)番目の`value`の位置。
111    pub fn select(&self, nth: usize, value: T) -> Option<usize> {
112        let nth = nth + 1;
113
114        let (l, r) = self.rank_(self.size, value);
115
116        if r - l < nth {
117            None
118        } else {
119            let mut p = l + nth - 1;
120
121            for i in (0..BIT_SIZE).rev() {
122                let t = (value >> (BIT_SIZE - i - 1)) & T::from(1);
123
124                if t == T::from(1) {
125                    p = self.sdict[i].select(p - self.zero_pos[i], true).unwrap();
126                } else {
127                    p = self.sdict[i].select(p, false).unwrap();
128                }
129            }
130
131            Some(p)
132        }
133    }
134
135    /// `range`で`nth`(0-indexed)番目に小さい値。
136    pub fn quantile(&self, range: impl RangeBounds<usize>, nth: usize) -> Option<T> {
137        let (mut l, mut r) = range_bounds_to_range(range, 0, self.size);
138        if nth >= r - l {
139            return None;
140        }
141
142        let mut nth = nth + 1;
143        let mut ret = T::from(0);
144
145        for (i, sdict) in self.sdict.iter().enumerate() {
146            let count_1 = sdict.count(l..r, true);
147            let count_0 = r - l - count_1;
148
149            let mut t = 0;
150
151            if nth > count_0 {
152                t = 1;
153                ret |= T::from(1) << (BIT_SIZE - i - 1);
154                nth -= count_0;
155            }
156
157            let zeropos = unsafe { self.zero_pos.get_unchecked(i) };
158            l = sdict.rank(l, t == 1) + t as usize * zeropos;
159            r = sdict.rank(r, t == 1) + t as usize * zeropos;
160        }
161
162        Some(ret)
163    }
164
165    /// `range`での最大値
166    pub fn maximum(&self, range: impl RangeBounds<usize>) -> Option<T> {
167        let (l, r) = range_bounds_to_range(range, 0, self.size);
168        if r > l {
169            self.quantile(l..r, r - l - 1)
170        } else {
171            None
172        }
173    }
174
175    /// `range`での最小値
176    pub fn minimum(&self, range: impl RangeBounds<usize>) -> Option<T> {
177        self.quantile(range, 0)
178    }
179
180    fn range_freq_lt(&self, range: impl RangeBounds<usize>, ub: T) -> usize {
181        let (mut l, mut r) = range_bounds_to_range(range, 0, self.size);
182        let mut ret = 0;
183        for i in 0..BIT_SIZE {
184            let t = (ub >> (BIT_SIZE - i - 1)) & T::from(1);
185            if t == T::from(1) {
186                ret += self.sdict[i].count(l..r, false);
187                l = self.sdict[i].rank(l, true) + self.zero_pos[i];
188                r = self.sdict[i].rank(r, true) + self.zero_pos[i];
189            } else {
190                l = self.sdict[i].rank(l, false);
191                r = self.sdict[i].rank(r, false);
192            }
193        }
194        ret
195    }
196
197    /// `range`で`lb`以上の最小値
198    pub fn next_value(&self, range: impl RangeBounds<usize> + Clone, lb: T) -> Option<T> {
199        let c = self.range_freq_lt(range.clone(), lb);
200        self.quantile(range, c)
201    }
202
203    /// `range`で`ub`未満の最大値
204    pub fn prev_value(&self, range: impl RangeBounds<usize> + Clone, ub: T) -> Option<T> {
205        let c = self.range_freq_lt(range.clone(), ub);
206        if c == 0 {
207            None
208        } else {
209            self.quantile(range, c - 1)
210        }
211    }
212
213    /// `range`で`lb`以上`ub`未満の値の個数
214    pub fn range_freq(&self, range: impl RangeBounds<usize> + Clone, lb: T, ub: T) -> usize {
215        if lb >= ub {
216            return 0;
217        }
218        self.range_freq_lt(range.clone(), ub) - self.range_freq_lt(range, lb)
219    }
220}
221
222/// [`u64`]の列を管理できる[`WaveletMatrix`]
223pub type WM64 = WaveletMatrix<u64, 64>;
224/// [`u32`]の列を管理できる[`WaveletMatrix`]
225pub type WM32 = WaveletMatrix<u32, 32>;
226
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algo::bsearch_slice::BinarySearch;
    use my_testtools::*;
    use rand::Rng;

    /// Number of distinct candidate values the test sequences draw from.
    const TABLE_LEN: usize = 50;
    /// Length of the test sequences built from the candidate table.
    const SEQ_LEN: usize = 300;

    /// Returns `(table, seq)`: `TABLE_LEN` random values and a sequence of
    /// `SEQ_LEN` elements drawn from that table, so that values repeat.
    fn random_input(rng: &mut impl Rng) -> (Vec<u64>, Vec<u64>) {
        let table = (0..TABLE_LEN)
            .map(|_| rng.gen::<u64>())
            .collect::<Vec<_>>();
        let seq = (0..SEQ_LEN)
            .map(|_| table[rng.gen::<usize>() % TABLE_LEN])
            .collect::<Vec<_>>();
        (table, seq)
    }

    #[test]
    fn test_access() {
        let mut rng = rand::thread_rng();
        let n = 10000;
        let b = (0..n).map(|_| rng.gen::<u64>()).collect::<Vec<_>>();

        let wm = WM64::new(b.clone());

        for (i, &x) in b.iter().enumerate() {
            assert_eq!(wm.access(i), x);
        }
    }

    #[test]
    fn test_rank() {
        let mut rng = rand::thread_rng();
        let (table, b) = random_input(&mut rng);

        let wm = WM64::new(b.clone());

        for &x in &table {
            // Running count of `x` in `b[..i]`; `i == SEQ_LEN` checks the
            // full-prefix case, where `b.get(i)` is `None`.
            let mut count = 0;
            for i in 0..=SEQ_LEN {
                assert_eq!(wm.rank(i, x), count);
                if b.get(i) == Some(&x) {
                    count += 1;
                }
            }
        }
    }

    #[test]
    fn test_count() {
        let mut rng = rand::thread_rng();
        let (table, b) = random_input(&mut rng);

        let wm = WM64::new(b.clone());

        for _ in 0..1000 {
            let lr = rand_range(&mut rng, 0..SEQ_LEN);
            let x = table[rng.gen::<usize>() % TABLE_LEN];

            let count = b[lr.clone()].iter().filter(|&&y| x == y).count();

            assert_eq!(wm.count(lr, x), count);
        }
    }

    #[test]
    fn test_select() {
        let mut rng = rand::thread_rng();
        let (table, b) = random_input(&mut rng);

        let wm = WM64::new(b.clone());

        for x in table {
            let count = wm.count(.., x);

            // Selecting every occurrence in order must list exactly the
            // positions holding `x`.
            assert_eq!(
                (0..count)
                    .map(|i| wm.select(i, x).unwrap())
                    .collect::<Vec<_>>(),
                (0..SEQ_LEN).filter(|&i| b[i] == x).collect::<Vec<_>>()
            );
        }
    }

    #[test]
    fn test_quantile() {
        let mut rng = rand::thread_rng();
        let (_, b) = random_input(&mut rng);

        let wm = WM64::new(b.clone());

        for _ in 0..300 {
            let lr = rand_range(&mut rng, 0..SEQ_LEN);

            let mut a = b[lr.clone()].to_vec();
            a.sort();

            assert_eq!(
                (0..lr.end - lr.start)
                    .map(|i| wm.quantile(lr.clone(), i).unwrap())
                    .collect::<Vec<_>>(),
                a
            );

            assert_eq!(wm.maximum(lr.clone()), a.last().copied());
            assert_eq!(wm.minimum(lr.clone()), a.first().copied());
        }
    }

    #[test]
    fn test_prev_next_value() {
        let mut rng = rand::thread_rng();
        let (_, b) = random_input(&mut rng);

        let wm = WM64::new(b.clone());

        for _ in 0..1000 {
            let lr = rand_range(&mut rng, 0..SEQ_LEN);

            let mut a = b[lr.clone()].to_vec();
            a.sort();

            let x = rng.gen::<u64>();
            // Index of the first element `>= x`; the element just before it
            // (if any) is the largest element `< x`.
            let i = a.lower_bound(&x);

            assert_eq!(wm.next_value(lr.clone(), x), a.get(i).copied());
            assert_eq!(
                wm.prev_value(lr, x),
                i.checked_sub(1).and_then(|j| a.get(j).copied())
            );
        }
    }

    #[test]
    fn test_range_freq() {
        let mut rng = rand::thread_rng();
        let (_, b) = random_input(&mut rng);

        let wm = WM64::new(b.clone());

        for _ in 0..1000 {
            let lr = rand_range(&mut rng, 0..SEQ_LEN);
            let lb = rng.gen::<u64>();
            let ub = rng.gen::<u64>();

            assert_eq!(
                wm.range_freq(lr.clone(), lb, ub),
                b[lr].iter().filter(|&&x| lb <= x && x < ub).count()
            );
        }
    }
}