1use crate::{ds::succinct_bitvec::SuccinctBitVec, misc::range::range_bounds_to_range};
3use std::{
4 marker::PhantomData,
5 ops::{BitAnd, BitOrAssign, RangeBounds, Shl, Shr},
6};
7
8#[derive(Clone)]
10pub struct WaveletMatrix<T, const BIT_SIZE: usize> {
11 size: usize,
12 sdict: Vec<SuccinctBitVec>,
13 zero_pos: Vec<usize>,
14 _phantom: PhantomData<T>,
15}
16
17impl<T, const BIT_SIZE: usize> WaveletMatrix<T, BIT_SIZE>
18where
19 T: Shr<usize, Output = T>
20 + Shl<usize, Output = T>
21 + BitAnd<Output = T>
22 + BitOrAssign
23 + From<u8>
24 + Eq
25 + Ord
26 + Copy,
27{
28 pub fn new(mut data: Vec<T>) -> Self {
30 let size = data.len();
31
32 let mut sdict = vec![];
33 let mut zero_pos = vec![];
34
35 for k in 0..BIT_SIZE {
36 let mut left = vec![];
37 let mut right = vec![];
38 let mut s = vec![false; size];
39
40 for i in 0..size {
41 s[i] = (data[i] >> (BIT_SIZE - 1 - k)) & T::from(1) == T::from(1);
42 if s[i] {
43 right.push(data[i]);
44 } else {
45 left.push(data[i]);
46 }
47 }
48
49 sdict.push(SuccinctBitVec::new(s));
50 zero_pos.push(left.len());
51
52 data = left;
53 data.extend(right);
54 }
55
56 Self {
57 size,
58 sdict,
59 zero_pos,
60 _phantom: PhantomData,
61 }
62 }
63
64 pub fn access(&self, index: usize) -> T {
66 let mut ret = T::from(0);
67
68 let mut p = index;
69 for i in 0..BIT_SIZE {
70 let t = self.sdict[i].access(p);
71
72 ret |= T::from(t as u8) << (BIT_SIZE - 1 - i);
73 p = self.sdict[i].rank(p, t == 1) + t as usize * self.zero_pos[i];
74 }
75
76 ret
77 }
78
79 fn rank_(&self, index: usize, value: T) -> (usize, usize) {
80 let mut l = 0;
81 let mut r = index;
82
83 for i in 0..BIT_SIZE {
84 let t = (value >> (BIT_SIZE - 1 - i)) & T::from(1);
85
86 if t == T::from(1) {
87 l = self.sdict[i].rank(l, true) + self.zero_pos[i];
88 r = self.sdict[i].rank(r, true) + self.zero_pos[i];
89 } else {
90 l = self.sdict[i].rank(l, false);
91 r = self.sdict[i].rank(r, false);
92 }
93 }
94
95 (l, r)
96 }
97
98 pub fn rank(&self, index: usize, value: T) -> usize {
100 let (l, r) = self.rank_(index, value);
101 r - l
102 }
103
104 pub fn count(&self, range: impl RangeBounds<usize>, value: T) -> usize {
106 let (l, r) = range_bounds_to_range(range, 0, self.size);
107 self.rank(r, value) - self.rank(l, value)
108 }
109
110 pub fn select(&self, nth: usize, value: T) -> Option<usize> {
112 let nth = nth + 1;
113
114 let (l, r) = self.rank_(self.size, value);
115
116 if r - l < nth {
117 None
118 } else {
119 let mut p = l + nth - 1;
120
121 for i in (0..BIT_SIZE).rev() {
122 let t = (value >> (BIT_SIZE - i - 1)) & T::from(1);
123
124 if t == T::from(1) {
125 p = self.sdict[i].select(p - self.zero_pos[i], true).unwrap();
126 } else {
127 p = self.sdict[i].select(p, false).unwrap();
128 }
129 }
130
131 Some(p)
132 }
133 }
134
135 pub fn quantile(&self, range: impl RangeBounds<usize>, nth: usize) -> Option<T> {
137 let (mut l, mut r) = range_bounds_to_range(range, 0, self.size);
138 if nth >= r - l {
139 return None;
140 }
141
142 let mut nth = nth + 1;
143 let mut ret = T::from(0);
144
145 for (i, sdict) in self.sdict.iter().enumerate() {
146 let count_1 = sdict.count(l..r, true);
147 let count_0 = r - l - count_1;
148
149 let mut t = 0;
150
151 if nth > count_0 {
152 t = 1;
153 ret |= T::from(1) << (BIT_SIZE - i - 1);
154 nth -= count_0;
155 }
156
157 let zeropos = unsafe { self.zero_pos.get_unchecked(i) };
158 l = sdict.rank(l, t == 1) + t as usize * zeropos;
159 r = sdict.rank(r, t == 1) + t as usize * zeropos;
160 }
161
162 Some(ret)
163 }
164
165 pub fn maximum(&self, range: impl RangeBounds<usize>) -> Option<T> {
167 let (l, r) = range_bounds_to_range(range, 0, self.size);
168 if r > l {
169 self.quantile(l..r, r - l - 1)
170 } else {
171 None
172 }
173 }
174
175 pub fn minimum(&self, range: impl RangeBounds<usize>) -> Option<T> {
177 self.quantile(range, 0)
178 }
179
180 fn range_freq_lt(&self, range: impl RangeBounds<usize>, ub: T) -> usize {
181 let (mut l, mut r) = range_bounds_to_range(range, 0, self.size);
182 let mut ret = 0;
183 for i in 0..BIT_SIZE {
184 let t = (ub >> (BIT_SIZE - i - 1)) & T::from(1);
185 if t == T::from(1) {
186 ret += self.sdict[i].count(l..r, false);
187 l = self.sdict[i].rank(l, true) + self.zero_pos[i];
188 r = self.sdict[i].rank(r, true) + self.zero_pos[i];
189 } else {
190 l = self.sdict[i].rank(l, false);
191 r = self.sdict[i].rank(r, false);
192 }
193 }
194 ret
195 }
196
197 pub fn next_value(&self, range: impl RangeBounds<usize> + Clone, lb: T) -> Option<T> {
199 let c = self.range_freq_lt(range.clone(), lb);
200 self.quantile(range, c)
201 }
202
203 pub fn prev_value(&self, range: impl RangeBounds<usize> + Clone, ub: T) -> Option<T> {
205 let c = self.range_freq_lt(range.clone(), ub);
206 if c == 0 {
207 None
208 } else {
209 self.quantile(range, c - 1)
210 }
211 }
212
213 pub fn range_freq(&self, range: impl RangeBounds<usize> + Clone, lb: T, ub: T) -> usize {
215 if lb >= ub {
216 return 0;
217 }
218 self.range_freq_lt(range.clone(), ub) - self.range_freq_lt(range, lb)
219 }
220}
221
222pub type WM64 = WaveletMatrix<u64, 64>;
224pub type WM32 = WaveletMatrix<u32, 32>;
226
#[cfg(test)]
mod tests {
    // Several checks compare against a reference slice by index on purpose.
    #![allow(clippy::needless_range_loop)]
    use super::*;
    use crate::algo::bsearch_slice::BinarySearch;
    use my_testtools::*;
    use rand::Rng;

    /// Number of distinct candidate values drawn for the shared fixture.
    const TABLE_SIZE: usize = 50;
    /// Length of the sequence indexed by the shared fixture.
    const SEQ_LEN: usize = 300;

    /// Draws `TABLE_SIZE` random values and a sequence of `SEQ_LEN` elements
    /// picked from them, so that values repeat within the sequence.
    ///
    /// Returns `(table, sequence)`.
    fn random_fixture(rng: &mut impl Rng) -> (Vec<u64>, Vec<u64>) {
        let table: Vec<u64> = std::iter::repeat_with(|| rng.gen::<u64>())
            .take(TABLE_SIZE)
            .collect();
        let sequence: Vec<u64> =
            std::iter::repeat_with(|| table[rng.gen::<usize>() % TABLE_SIZE])
                .take(SEQ_LEN)
                .collect();
        (table, sequence)
    }

    /// `access` must reproduce every element of the input.
    #[test]
    fn test_access() {
        let mut rng = rand::thread_rng();
        let n = 10000;
        let b = std::iter::repeat_with(|| rng.gen::<u64>())
            .take(n)
            .collect::<Vec<_>>();

        let wm = WM64::new(b.clone());

        for (i, &expected) in b.iter().enumerate() {
            assert_eq!(wm.access(i), expected);
        }
    }

    /// `rank(i, x)` must equal the number of `x` in the first `i` elements,
    /// for every prefix length including the full sequence.
    #[test]
    fn test_rank() {
        let mut rng = rand::thread_rng();
        let (table, b) = random_fixture(&mut rng);
        let wm = WM64::new(b.clone());

        for &x in &table {
            let mut seen = 0;
            for i in 0..=SEQ_LEN {
                assert_eq!(wm.rank(i, x), seen);
                if b.get(i) == Some(&x) {
                    seen += 1;
                }
            }
        }
    }

    /// `count` must agree with a linear scan over random sub-ranges.
    #[test]
    fn test_count() {
        let mut rng = rand::thread_rng();
        let (table, b) = random_fixture(&mut rng);
        let wm = WM64::new(b.clone());

        for _ in 0..1000 {
            let lr = rand_range(&mut rng, 0..SEQ_LEN);
            let x = table[rng.gen::<usize>() % TABLE_SIZE];

            let expected = b[lr.clone()].iter().filter(|&&y| x == y).count();

            assert_eq!(wm.count(lr, x), expected);
        }
    }

    /// `select` must enumerate the positions of each value in increasing order.
    #[test]
    fn test_select() {
        let mut rng = rand::thread_rng();
        let (table, b) = random_fixture(&mut rng);
        let wm = WM64::new(b.clone());

        for x in table {
            let occurrences = wm.count(.., x);

            let selected = (0..occurrences)
                .map(|i| wm.select(i, x).unwrap())
                .collect::<Vec<_>>();
            let expected = (0..SEQ_LEN).filter(|&i| b[i] == x).collect::<Vec<_>>();

            assert_eq!(selected, expected);
        }
    }

    /// `quantile`, `maximum` and `minimum` must agree with sorting the range.
    #[test]
    fn test_quantile() {
        let mut rng = rand::thread_rng();
        let (_table, b) = random_fixture(&mut rng);
        let wm = WM64::new(b.clone());

        for _ in 0..300 {
            let lr = rand_range(&mut rng, 0..SEQ_LEN);

            let mut sorted = b[lr.clone()].to_vec();
            sorted.sort();

            let quantiles = (0..lr.end - lr.start)
                .map(|i| wm.quantile(lr.clone(), i).unwrap())
                .collect::<Vec<_>>();
            assert_eq!(quantiles, sorted);

            assert_eq!(wm.maximum(lr.clone()), sorted.last().copied());
            assert_eq!(wm.minimum(lr.clone()), sorted.first().copied());
        }
    }

    /// `next_value` / `prev_value` must match the neighbours of the lower
    /// bound in the sorted range.
    #[test]
    fn test_prev_next_value() {
        let mut rng = rand::thread_rng();
        let (_table, b) = random_fixture(&mut rng);
        let wm = WM64::new(b.clone());

        for _ in 0..1000 {
            let lr = rand_range(&mut rng, 0..SEQ_LEN);

            let mut sorted = b[lr.clone()].to_vec();
            sorted.sort();

            let x = rng.gen::<u64>();
            // Index of the first element >= x; both expectations derive from it.
            let i = sorted.lower_bound(&x);

            assert_eq!(wm.next_value(lr.clone(), x), sorted.get(i).copied());

            let expected_prev = if i == 0 {
                None
            } else {
                sorted.get(i - 1).copied()
            };
            assert_eq!(wm.prev_value(lr, x), expected_prev);
        }
    }

    /// `range_freq` must count the elements `x` with `lb <= x < ub`.
    #[test]
    fn test_range_freq() {
        let mut rng = rand::thread_rng();
        let (_table, b) = random_fixture(&mut rng);
        let wm = WM64::new(b.clone());

        for _ in 0..1000 {
            let lr = rand_range(&mut rng, 0..SEQ_LEN);
            let lb = rng.gen::<u64>();
            let ub = rng.gen::<u64>();

            let expected = b[lr.clone()]
                .iter()
                .filter(|&&x| lb <= x && x < ub)
                .count();

            assert_eq!(wm.range_freq(lr, lb, ub), expected);
        }
    }
}