haar_lib/ds/
wavelet_matrix.rs

1//! Wavelet matrix
2use crate::{ds::succinct_bitvec::SuccinctBitVec, misc::range::range_bounds_to_range};
3use std::{
4    marker::PhantomData,
5    ops::{BitAnd, BitOrAssign, RangeBounds, Shl, Shr},
6};
7
8/// Wavelet matrix
9#[derive(Clone)]
10pub struct WaveletMatrix<T, const BIT_SIZE: usize> {
11    size: usize,
12    sdict: Vec<SuccinctBitVec>,
13    zero_pos: Vec<usize>,
14    _phantom: PhantomData<T>,
15}
16
17impl<T, const BIT_SIZE: usize> WaveletMatrix<T, BIT_SIZE>
18where
19    T: Shr<usize, Output = T>
20        + Shl<usize, Output = T>
21        + BitAnd<Output = T>
22        + BitOrAssign
23        + From<u8>
24        + Eq
25        + Ord
26        + Copy,
27{
28    /// `T`の列から[`WaveletMatrix`]を作る。
29    pub fn new(mut data: Vec<T>) -> Self {
30        let size = data.len();
31
32        let mut sdict = vec![];
33        let mut zero_pos = vec![];
34
35        for k in 0..BIT_SIZE {
36            let mut left = vec![];
37            let mut right = vec![];
38            let mut s = vec![false; size];
39
40            for i in 0..size {
41                s[i] = (data[i] >> (BIT_SIZE - 1 - k)) & T::from(1) == T::from(1);
42                if s[i] {
43                    right.push(data[i]);
44                } else {
45                    left.push(data[i]);
46                }
47            }
48
49            sdict.push(SuccinctBitVec::new(s));
50            zero_pos.push(left.len());
51
52            data = left;
53            data.extend(right);
54        }
55
56        Self {
57            size,
58            sdict,
59            zero_pos,
60            _phantom: PhantomData,
61        }
62    }
63
64    /// `index`番目の値を得る。
65    pub fn access(&self, index: usize) -> T {
66        let mut ret = T::from(0);
67
68        let mut p = index;
69        for i in 0..BIT_SIZE {
70            let t = self.sdict[i].access(p);
71
72            ret |= T::from(t as u8) << (BIT_SIZE - 1 - i);
73            p = self.sdict[i].rank(p, t == 1) + t as usize * self.zero_pos[i];
74        }
75
76        ret
77    }
78
79    fn rank_(&self, index: usize, value: T) -> (usize, usize) {
80        let mut l = 0;
81        let mut r = index;
82
83        for i in 0..BIT_SIZE {
84            let t = (value >> (BIT_SIZE - 1 - i)) & T::from(1);
85
86            if t == T::from(1) {
87                l = self.sdict[i].rank(l, true) + self.zero_pos[i];
88                r = self.sdict[i].rank(r, true) + self.zero_pos[i];
89            } else {
90                l = self.sdict[i].rank(l, false);
91                r = self.sdict[i].rank(r, false);
92            }
93        }
94
95        (l, r)
96    }
97
98    /// [0, index)に含まれる`value`の個数。
99    pub fn rank(&self, index: usize, value: T) -> usize {
100        let (l, r) = self.rank_(index, value);
101        r - l
102    }
103
104    /// `range`に含まれる`value`の個数。
105    pub fn count(&self, range: impl RangeBounds<usize>, value: T) -> usize {
106        let (l, r) = range_bounds_to_range(range, 0, self.size);
107        self.rank(r, value) - self.rank(l, value)
108    }
109
110    /// `nth`(0-indexed)番目の`value`の位置。
111    pub fn select(&self, nth: usize, value: T) -> Option<usize> {
112        let nth = nth + 1;
113
114        let (l, r) = self.rank_(self.size, value);
115
116        if r - l < nth {
117            None
118        } else {
119            let mut p = l + nth - 1;
120
121            for i in (0..BIT_SIZE).rev() {
122                let t = (value >> (BIT_SIZE - i - 1)) & T::from(1);
123
124                if t == T::from(1) {
125                    p = self.sdict[i].select(p - self.zero_pos[i], true).unwrap();
126                } else {
127                    p = self.sdict[i].select(p, false).unwrap();
128                }
129            }
130
131            Some(p)
132        }
133    }
134
135    /// `range`で`nth`(0-indexed)番目に小さい値。
136    pub fn quantile(&self, range: impl RangeBounds<usize>, nth: usize) -> Option<T> {
137        let (mut l, mut r) = range_bounds_to_range(range, 0, self.size);
138        if nth >= r - l {
139            return None;
140        }
141
142        let mut nth = nth + 1;
143        let mut ret = T::from(0);
144
145        for (i, sdict) in self.sdict.iter().enumerate() {
146            let count_1 = sdict.count(l..r, true);
147            let count_0 = r - l - count_1;
148
149            let mut t = 0;
150
151            if nth > count_0 {
152                t = 1;
153                ret |= T::from(1) << (BIT_SIZE - i - 1);
154                nth -= count_0;
155            }
156
157            let zeropos = unsafe { self.zero_pos.get_unchecked(i) };
158            l = sdict.rank(l, t == 1) + t as usize * zeropos;
159            r = sdict.rank(r, t == 1) + t as usize * zeropos;
160        }
161
162        Some(ret)
163    }
164
165    /// `range`での最大値
166    pub fn maximum(&self, range: impl RangeBounds<usize>) -> Option<T> {
167        let (l, r) = range_bounds_to_range(range, 0, self.size);
168        if r > l {
169            self.quantile(l..r, r - l - 1)
170        } else {
171            None
172        }
173    }
174
175    /// `range`での最小値
176    pub fn minimum(&self, range: impl RangeBounds<usize>) -> Option<T> {
177        self.quantile(range, 0)
178    }
179
180    fn range_freq_lt(&self, range: impl RangeBounds<usize>, ub: T) -> usize {
181        let (mut l, mut r) = range_bounds_to_range(range, 0, self.size);
182        let mut ret = 0;
183        for i in 0..BIT_SIZE {
184            let t = (ub >> (BIT_SIZE - i - 1)) & T::from(1);
185            if t == T::from(1) {
186                ret += self.sdict[i].count(l..r, false);
187                l = self.sdict[i].rank(l, true) + self.zero_pos[i];
188                r = self.sdict[i].rank(r, true) + self.zero_pos[i];
189            } else {
190                l = self.sdict[i].rank(l, false);
191                r = self.sdict[i].rank(r, false);
192            }
193        }
194        ret
195    }
196
197    /// `range`で`lb`以上の最小値
198    pub fn next_value(&self, range: impl RangeBounds<usize> + Clone, lb: T) -> Option<T> {
199        let c = self.range_freq_lt(range.clone(), lb);
200        self.quantile(range, c)
201    }
202
203    /// `range`で`ub`未満の最大値
204    pub fn prev_value(&self, range: impl RangeBounds<usize> + Clone, ub: T) -> Option<T> {
205        let c = self.range_freq_lt(range.clone(), ub);
206        if c == 0 {
207            None
208        } else {
209            self.quantile(range, c - 1)
210        }
211    }
212
213    /// `range`で`lb`以上`ub`未満の値の個数
214    pub fn range_freq(&self, range: impl RangeBounds<usize> + Clone, lb: T, ub: T) -> usize {
215        if lb >= ub {
216            return 0;
217        }
218        self.range_freq_lt(range.clone(), ub) - self.range_freq_lt(range, lb)
219    }
220}
221
222/// [`u64`]の列を管理できる[`WaveletMatrix`]
223pub type WM64 = WaveletMatrix<u64, 64>;
224/// [`u32`]の列を管理できる[`WaveletMatrix`]
225pub type WM32 = WaveletMatrix<u32, 32>;
226
#[cfg(test)]
mod tests {
    #![allow(clippy::needless_range_loop)]
    use super::*;
    use crate::algo::bsearch_slice::BinarySearch;
    use my_testtools::*;
    use rand::Rng;

    /// Draws `m` random `u64` values (the "table") and a length-`n` sequence whose
    /// elements are sampled from that table, so values repeat in the sequence.
    fn gen_table_and_seq(rng: &mut impl Rng, m: usize, n: usize) -> (Vec<u64>, Vec<u64>) {
        let table: Vec<u64> = std::iter::repeat_with(|| rng.gen::<u64>())
            .take(m)
            .collect();
        let seq: Vec<u64> = std::iter::repeat_with(|| table[rng.gen::<usize>() % m])
            .take(n)
            .collect();
        (table, seq)
    }

    /// Returns either an element of `table` or a fresh random value, so queries
    /// hit stored values exactly (boundary cases) as well as arbitrary ones.
    fn pick_probe(rng: &mut impl Rng, table: &[u64]) -> u64 {
        if rng.gen::<bool>() {
            table[rng.gen::<usize>() % table.len()]
        } else {
            rng.gen::<u64>()
        }
    }

    #[test]
    fn test_access() {
        let mut rng = rand::thread_rng();
        let n = 10000;
        let b = std::iter::repeat_with(|| rng.gen::<u64>())
            .take(n)
            .collect::<Vec<_>>();

        let wm = WM64::new(b.clone());

        for (i, &x) in b.iter().enumerate() {
            assert_eq!(wm.access(i), x);
        }
    }

    #[test]
    fn test_rank() {
        let mut rng = rand::thread_rng();
        let (table, b) = gen_table_and_seq(&mut rng, 50, 300);
        let wm = WM64::new(b.clone());

        for &x in &table {
            let mut count = 0;
            for i in 0..=b.len() {
                assert_eq!(wm.rank(i, x), count);
                if b.get(i) == Some(&x) {
                    count += 1;
                }
            }
        }
    }

    #[test]
    fn test_count() {
        let mut rng = rand::thread_rng();
        let m = 50;
        let n = 300;
        let (table, b) = gen_table_and_seq(&mut rng, m, n);
        let wm = WM64::new(b.clone());

        for _ in 0..1000 {
            let lr = rand_range(&mut rng, 0..n);
            let x = table[rng.gen::<usize>() % m];

            let count = b[lr.clone()].iter().filter(|&&y| x == y).count();

            assert_eq!(wm.count(lr, x), count);
        }
    }

    #[test]
    fn test_select() {
        let mut rng = rand::thread_rng();
        let (table, b) = gen_table_and_seq(&mut rng, 50, 300);
        let wm = WM64::new(b.clone());

        for x in table {
            let count = wm.count(.., x);

            let expected: Vec<usize> = b
                .iter()
                .enumerate()
                .filter(|&(_, &y)| y == x)
                .map(|(i, _)| i)
                .collect();

            assert_eq!(
                (0..count)
                    .map(|i| wm.select(i, x).unwrap())
                    .collect::<Vec<_>>(),
                expected
            );

            // One past the last occurrence must report "not found".
            assert_eq!(wm.select(count, x), None);
        }
    }

    #[test]
    fn test_quantile() {
        let mut rng = rand::thread_rng();
        let n = 300;
        let (_, b) = gen_table_and_seq(&mut rng, 50, n);
        let wm = WM64::new(b.clone());

        for _ in 0..300 {
            let lr = rand_range(&mut rng, 0..n);

            let mut a = b[lr.clone()].to_vec();
            a.sort_unstable();

            assert_eq!(
                (0..a.len())
                    .map(|i| wm.quantile(lr.clone(), i).unwrap())
                    .collect::<Vec<_>>(),
                a
            );

            // A rank equal to the range length is out of bounds.
            assert_eq!(wm.quantile(lr.clone(), a.len()), None);

            assert_eq!(wm.maximum(lr.clone()), a.last().copied());
            assert_eq!(wm.minimum(lr.clone()), a.first().copied());
        }
    }

    #[test]
    fn test_prev_next_value() {
        let mut rng = rand::thread_rng();
        let n = 300;
        let (table, b) = gen_table_and_seq(&mut rng, 50, n);
        let wm = WM64::new(b.clone());

        for _ in 0..1000 {
            let lr = rand_range(&mut rng, 0..n);

            let mut a = b[lr.clone()].to_vec();
            a.sort_unstable();

            let x = pick_probe(&mut rng, &table);
            // Index of the first element of `a` that is not less than `x`.
            let i = a.lower_bound(&x);

            assert_eq!(wm.next_value(lr.clone(), x), a.get(i).copied());
            assert_eq!(
                wm.prev_value(lr, x),
                i.checked_sub(1).and_then(|j| a.get(j).copied())
            );
        }
    }

    #[test]
    fn test_range_freq() {
        let mut rng = rand::thread_rng();
        let n = 300;
        let (table, b) = gen_table_and_seq(&mut rng, 50, n);
        let wm = WM64::new(b.clone());

        for _ in 0..1000 {
            let lr = rand_range(&mut rng, 0..n);
            let lb = pick_probe(&mut rng, &table);
            let ub = pick_probe(&mut rng, &table);

            assert_eq!(
                wm.range_freq(lr.clone(), lb, ub),
                b[lr].iter().filter(|&&x| lb <= x && x < ub).count()
            );
        }
    }
}