1use crate::{ds::succinct_bitvec::SuccinctBitVec, misc::range::range_bounds_to_range};
3use std::{
4 marker::PhantomData,
5 ops::{BitAnd, BitOrAssign, RangeBounds, Shl, Shr},
6};
7
8#[derive(Clone)]
10pub struct WaveletMatrix<T, const BIT_SIZE: usize> {
11 size: usize,
12 sdict: Vec<SuccinctBitVec>,
13 zero_pos: Vec<usize>,
14 _phantom: PhantomData<T>,
15}
16
17impl<T, const BIT_SIZE: usize> WaveletMatrix<T, BIT_SIZE>
18where
19 T: Shr<usize, Output = T>
20 + Shl<usize, Output = T>
21 + BitAnd<Output = T>
22 + BitOrAssign
23 + From<u8>
24 + Eq
25 + Ord
26 + Copy,
27{
28 pub fn new(mut data: Vec<T>) -> Self {
30 let size = data.len();
31
32 let mut sdict = vec![];
33 let mut zero_pos = vec![];
34
35 for k in 0..BIT_SIZE {
36 let mut left = vec![];
37 let mut right = vec![];
38 let mut s = vec![false; size];
39
40 for i in 0..size {
41 s[i] = (data[i] >> (BIT_SIZE - 1 - k)) & T::from(1) == T::from(1);
42 if s[i] {
43 right.push(data[i]);
44 } else {
45 left.push(data[i]);
46 }
47 }
48
49 sdict.push(SuccinctBitVec::new(s));
50 zero_pos.push(left.len());
51
52 data = left;
53 data.extend(right);
54 }
55
56 Self {
57 size,
58 sdict,
59 zero_pos,
60 _phantom: PhantomData,
61 }
62 }
63
64 pub fn access(&self, index: usize) -> T {
66 let mut ret = T::from(0);
67
68 let mut p = index;
69 for i in 0..BIT_SIZE {
70 let t = self.sdict[i].access(p);
71
72 ret |= T::from(t as u8) << (BIT_SIZE - 1 - i);
73 p = self.sdict[i].rank(p, t == 1) + t as usize * self.zero_pos[i];
74 }
75
76 ret
77 }
78
79 fn rank_(&self, index: usize, value: T) -> (usize, usize) {
80 let mut l = 0;
81 let mut r = index;
82
83 for i in 0..BIT_SIZE {
84 let t = (value >> (BIT_SIZE - 1 - i)) & T::from(1);
85
86 if t == T::from(1) {
87 l = self.sdict[i].rank(l, true) + self.zero_pos[i];
88 r = self.sdict[i].rank(r, true) + self.zero_pos[i];
89 } else {
90 l = self.sdict[i].rank(l, false);
91 r = self.sdict[i].rank(r, false);
92 }
93 }
94
95 (l, r)
96 }
97
98 pub fn rank(&self, index: usize, value: T) -> usize {
100 let (l, r) = self.rank_(index, value);
101 r - l
102 }
103
104 pub fn count(&self, range: impl RangeBounds<usize>, value: T) -> usize {
106 let (l, r) = range_bounds_to_range(range, 0, self.size);
107 self.rank(r, value) - self.rank(l, value)
108 }
109
110 pub fn select(&self, nth: usize, value: T) -> Option<usize> {
112 let nth = nth + 1;
113
114 let (l, r) = self.rank_(self.size, value);
115
116 if r - l < nth {
117 None
118 } else {
119 let mut p = l + nth - 1;
120
121 for i in (0..BIT_SIZE).rev() {
122 let t = (value >> (BIT_SIZE - i - 1)) & T::from(1);
123
124 if t == T::from(1) {
125 p = self.sdict[i].select(p - self.zero_pos[i], true).unwrap();
126 } else {
127 p = self.sdict[i].select(p, false).unwrap();
128 }
129 }
130
131 Some(p)
132 }
133 }
134
135 pub fn quantile(&self, range: impl RangeBounds<usize>, nth: usize) -> Option<T> {
137 let (mut l, mut r) = range_bounds_to_range(range, 0, self.size);
138 if nth >= r - l {
139 return None;
140 }
141
142 let mut nth = nth + 1;
143 let mut ret = T::from(0);
144
145 for (i, sdict) in self.sdict.iter().enumerate() {
146 let count_1 = sdict.count(l..r, true);
147 let count_0 = r - l - count_1;
148
149 let mut t = 0;
150
151 if nth > count_0 {
152 t = 1;
153 ret |= T::from(1) << (BIT_SIZE - i - 1);
154 nth -= count_0;
155 }
156
157 let zeropos = unsafe { self.zero_pos.get_unchecked(i) };
158 l = sdict.rank(l, t == 1) + t as usize * zeropos;
159 r = sdict.rank(r, t == 1) + t as usize * zeropos;
160 }
161
162 Some(ret)
163 }
164
165 pub fn maximum(&self, range: impl RangeBounds<usize>) -> Option<T> {
167 let (l, r) = range_bounds_to_range(range, 0, self.size);
168 if r > l {
169 self.quantile(l..r, r - l - 1)
170 } else {
171 None
172 }
173 }
174
175 pub fn minimum(&self, range: impl RangeBounds<usize>) -> Option<T> {
177 self.quantile(range, 0)
178 }
179
180 fn range_freq_lt(&self, range: impl RangeBounds<usize>, ub: T) -> usize {
181 let (mut l, mut r) = range_bounds_to_range(range, 0, self.size);
182 let mut ret = 0;
183 for i in 0..BIT_SIZE {
184 let t = (ub >> (BIT_SIZE - i - 1)) & T::from(1);
185 if t == T::from(1) {
186 ret += self.sdict[i].count(l..r, false);
187 l = self.sdict[i].rank(l, true) + self.zero_pos[i];
188 r = self.sdict[i].rank(r, true) + self.zero_pos[i];
189 } else {
190 l = self.sdict[i].rank(l, false);
191 r = self.sdict[i].rank(r, false);
192 }
193 }
194 ret
195 }
196
197 pub fn next_value(&self, range: impl RangeBounds<usize> + Clone, lb: T) -> Option<T> {
199 let c = self.range_freq_lt(range.clone(), lb);
200 self.quantile(range, c)
201 }
202
203 pub fn prev_value(&self, range: impl RangeBounds<usize> + Clone, ub: T) -> Option<T> {
205 let c = self.range_freq_lt(range.clone(), ub);
206 if c == 0 {
207 None
208 } else {
209 self.quantile(range, c - 1)
210 }
211 }
212
213 pub fn range_freq(&self, range: impl RangeBounds<usize> + Clone, lb: T, ub: T) -> usize {
215 if lb >= ub {
216 return 0;
217 }
218 self.range_freq_lt(range.clone(), ub) - self.range_freq_lt(range, lb)
219 }
220}
221
222pub type WM64 = WaveletMatrix<u64, 64>;
224pub type WM32 = WaveletMatrix<u32, 32>;
226
#[cfg(test)]
mod tests {
    // Some tests walk the matrix and the reference vector with the same
    // explicit position, where an index loop reads more clearly.
    #![allow(clippy::needless_range_loop)]
    use super::*;
    use crate::algo::bsearch_slice::BinarySearch;
    use my_testtools::*;
    use rand::Rng;

    /// Number of distinct candidate values a random sequence draws from.
    const TABLE_SIZE: usize = 50;
    /// Length of the random sequence used by most tests.
    const SEQ_LEN: usize = 300;

    /// Draws `TABLE_SIZE` random values, then a sequence of `SEQ_LEN` elements
    /// picked from them, so that values repeat within the sequence.
    ///
    /// Returns `(table, sequence)`.
    fn random_sequence(rng: &mut impl Rng) -> (Vec<u64>, Vec<u64>) {
        let table = (0..TABLE_SIZE)
            .map(|_| rng.gen::<u64>())
            .collect::<Vec<_>>();
        let seq = (0..SEQ_LEN)
            .map(|_| table[rng.gen::<usize>() % TABLE_SIZE])
            .collect::<Vec<_>>();
        (table, seq)
    }

    /// `access` must return every element of the input unchanged.
    #[test]
    fn test_access() {
        let mut rng = rand::thread_rng();
        let n = 10000;
        let b = (0..n).map(|_| rng.gen::<u64>()).collect::<Vec<_>>();

        let wm = WM64::new(b.clone());

        for (i, &expected) in b.iter().enumerate() {
            assert_eq!(wm.access(i), expected);
        }
    }

    /// `rank(i, x)` must equal the number of `x` in the first `i` elements.
    #[test]
    fn test_rank() {
        let mut rng = rand::thread_rng();
        let (table, b) = random_sequence(&mut rng);

        let wm = WM64::new(b.clone());

        for &x in &table {
            let mut count = 0;
            for i in 0..=SEQ_LEN {
                assert_eq!(wm.rank(i, x), count);
                if b.get(i) == Some(&x) {
                    count += 1;
                }
            }
        }
    }

    /// `count` must agree with a linear scan over random ranges.
    #[test]
    fn test_count() {
        let mut rng = rand::thread_rng();
        let (table, b) = random_sequence(&mut rng);

        let wm = WM64::new(b.clone());

        for _ in 0..1000 {
            let lr = rand_range(&mut rng, 0..SEQ_LEN);
            let x = table[rng.gen::<usize>() % TABLE_SIZE];

            let expected = b[lr.clone()].iter().filter(|&&y| x == y).count();

            assert_eq!(wm.count(lr, x), expected);
        }
    }

    /// `select` must enumerate the positions of a value in increasing order.
    #[test]
    fn test_select() {
        let mut rng = rand::thread_rng();
        let (table, b) = random_sequence(&mut rng);

        let wm = WM64::new(b.clone());

        for x in table {
            let count = wm.count(.., x);

            assert_eq!(
                (0..count)
                    .map(|i| wm.select(i, x).unwrap())
                    .collect::<Vec<_>>(),
                (0..SEQ_LEN).filter(|&i| b[i] == x).collect::<Vec<_>>()
            );
        }
    }

    /// `quantile`, `maximum` and `minimum` must agree with a sorted copy of
    /// the queried range.
    #[test]
    fn test_quantile() {
        let mut rng = rand::thread_rng();
        let (_table, b) = random_sequence(&mut rng);

        let wm = WM64::new(b.clone());

        for _ in 0..300 {
            let lr = rand_range(&mut rng, 0..SEQ_LEN);

            let mut a = b[lr.clone()].to_vec();
            a.sort();

            assert_eq!(
                (0..lr.end - lr.start)
                    .map(|i| wm.quantile(lr.clone(), i).unwrap())
                    .collect::<Vec<_>>(),
                a
            );

            assert_eq!(wm.maximum(lr.clone()), a.last().copied());
            assert_eq!(wm.minimum(lr.clone()), a.first().copied());
        }
    }

    /// `next_value` / `prev_value` must match the neighbours of the
    /// `lower_bound` position in a sorted copy of the range.
    #[test]
    fn test_prev_next_value() {
        let mut rng = rand::thread_rng();
        let (_table, b) = random_sequence(&mut rng);

        let wm = WM64::new(b.clone());

        for _ in 0..1000 {
            let lr = rand_range(&mut rng, 0..SEQ_LEN);

            let mut a = b[lr.clone()].to_vec();
            a.sort();

            let x = rng.gen::<u64>();
            // One lookup serves both assertions.
            let i = a.lower_bound(&x);

            assert_eq!(wm.next_value(lr.clone(), x), a.get(i).copied());

            assert_eq!(
                wm.prev_value(lr, x),
                if i == 0 { None } else { a.get(i - 1).copied() }
            );
        }
    }

    /// `range_freq` must count the elements `x` with `lb <= x < ub`.
    #[test]
    fn test_range_freq() {
        let mut rng = rand::thread_rng();
        let (_table, b) = random_sequence(&mut rng);

        let wm = WM64::new(b.clone());

        for _ in 0..1000 {
            let lr = rand_range(&mut rng, 0..SEQ_LEN);
            let lb = rng.gen::<u64>();
            let ub = rng.gen::<u64>();

            assert_eq!(
                wm.range_freq(lr.clone(), lb, ub),
                b[lr].iter().filter(|&&x| lb <= x && x < ub).count()
            );
        }
    }
}