haar_lib/ds/
sparse_table_2d.rs1use crate::algebra::traits::*;
3use std::{
4 cmp::{max, min},
5 ops::Range,
6};
7
8pub struct SparseTable2D<A: Semilattice> {
10 semilattice: A,
11 data: Vec<Vec<Vec<Vec<A::Element>>>>,
12 log_table: Vec<usize>,
13}
14
15impl<A: Semilattice> SparseTable2D<A>
16where
17 A::Element: Clone + Default,
18{
19 pub fn new(semilattice: A, s: Vec<Vec<A::Element>>) -> Self {
23 let n = s.len();
24 let m = s[0].len();
25 let logn = n.next_power_of_two().trailing_zeros() as usize + 1;
26 let logm = m.next_power_of_two().trailing_zeros() as usize + 1;
27
28 let mut data = vec![vec![vec![vec![A::Element::default(); logm]; m]; logn]; n];
29
30 for i in 0..n {
31 for j in 0..m {
32 data[i][0][j][0] = s[i][j].clone();
33 }
34
35 for y in 1..logm {
36 for j in 0..m {
37 data[i][0][j][y] = semilattice.op(
38 data[i][0][j][y - 1].clone(),
39 data[i][0][min(m - 1, j + (1 << (y - 1)))][y - 1].clone(),
40 );
41 }
42 }
43 }
44
45 for x in 1..logn {
46 for i in 0..n {
47 for y in 0..logm {
48 for j in 0..m {
49 data[i][x][j][y] = semilattice.op(
50 data[i][x - 1][j][y].clone(),
51 data[min(n - 1, i + (1 << (x - 1)))][x - 1][j][y].clone(),
52 );
53 }
54 }
55 }
56 }
57
58 let mut log_table = vec![0; max(n, m) + 1];
59 for i in 2..=max(n, m) {
60 log_table[i] = log_table[i >> 1] + 1;
61 }
62
63 Self {
64 semilattice,
65 data,
66 log_table,
67 }
68 }
69
70 pub fn fold_2d(
72 &self,
73 Range { start: r1, end: r2 }: Range<usize>,
74 Range { start: c1, end: c2 }: Range<usize>,
75 ) -> Option<A::Element> {
76 if r1 == r2 || c1 == c2 {
77 return None;
78 }
79 let kr = self.log_table[r2 - r1];
80 let kc = self.log_table[c2 - c1];
81
82 let x = self.semilattice.op(
83 self.data[r1][kr][c1][kc].clone(),
84 self.data[r1][kr][c2 - (1 << kc)][kc].clone(),
85 );
86 let y = self.semilattice.op(
87 self.data[r2 - (1 << kr)][kr][c1][kc].clone(),
88 self.data[r2 - (1 << kr)][kr][c2 - (1 << kc)][kc].clone(),
89 );
90
91 Some(self.semilattice.op(x, y))
92 }
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{algebra::min_max::Max, iter::collect::CollectVec};
    use rand::Rng;
    use std::fmt::Debug;

    /// Exhaustively compares `SparseTable2D::fold_2d` with a brute-force fold
    /// for every row range `x1..x2` and column range `y1..y2` of `s`.
    ///
    /// Empty ranges make `fold_2d` return `None`, so the table's answer falls
    /// back to `a.id()` before the comparison.
    fn test<A>(a: A, s: Vec<Vec<A::Element>>)
    where
        A: Semilattice + Identity + Clone,
        A::Element: Copy + Default + PartialEq + Debug,
    {
        let table = SparseTable2D::new(a.clone(), s.clone());
        let rows = s.len();
        let cols = s[0].len();

        // Reference answer: fold each row slice, then fold the per-row results.
        let brute = |top: usize, bottom: usize, left: usize, right: usize| {
            s[top..bottom]
                .iter()
                .map(|row| row[left..right].iter().cloned().fold_m(&a))
                .fold_m(&a)
        };

        for x1 in 0..rows {
            for x2 in x1..=rows {
                for y1 in 0..cols {
                    for y2 in y1..=cols {
                        let expected = brute(x1, x2, y1, y2);
                        let actual = table.fold_2d(x1..x2, y1..y2).unwrap_or(a.id());
                        assert_eq!(expected, actual);
                    }
                }
            }
        }
    }

    #[test]
    fn test_max() {
        let mut rng = rand::thread_rng();
        let (n, m) = (30, 30);
        // Random n x m grid, drawn row by row.
        let s = (0..n)
            .map(|_| (0..m).map(|_| rng.gen::<u64>()).collect_vec())
            .collect_vec();
        test(Max::<u64>::new(), s);
    }
}