haar_lib/ds/
sparse_table_2d.rs

1//! 冪等性と結合性をもつ2次元列の区間取得($O(1)$)ができる。
2use crate::algebra::traits::*;
3use std::{
4    cmp::{max, min},
5    ops::Range,
6};
7
8/// 冪等性と結合性をもつ2次元列の区間取得($O(1)$)ができる。
9pub struct SparseTable2D<A: Semilattice> {
10    semilattice: A,
11    data: Vec<Vec<Vec<Vec<A::Element>>>>,
12    log_table: Vec<usize>,
13}
14
15impl<A: Semilattice> SparseTable2D<A>
16where
17    A::Element: Clone + Default,
18{
19    /// **Time complexity** $O(nm \log n \log m)$
20    ///
21    /// **Space complexity** $O(nm \log n \log m)$
22    pub fn new(semilattice: A, s: Vec<Vec<A::Element>>) -> Self {
23        let n = s.len();
24        let m = s[0].len();
25        let logn = n.next_power_of_two().trailing_zeros() as usize + 1;
26        let logm = m.next_power_of_two().trailing_zeros() as usize + 1;
27
28        let mut data = vec![vec![vec![vec![A::Element::default(); logm]; m]; logn]; n];
29
30        for i in 0..n {
31            for j in 0..m {
32                data[i][0][j][0] = s[i][j].clone();
33            }
34
35            for y in 1..logm {
36                for j in 0..m {
37                    data[i][0][j][y] = semilattice.op(
38                        data[i][0][j][y - 1].clone(),
39                        data[i][0][min(m - 1, j + (1 << (y - 1)))][y - 1].clone(),
40                    );
41                }
42            }
43        }
44
45        for x in 1..logn {
46            for i in 0..n {
47                for y in 0..logm {
48                    for j in 0..m {
49                        data[i][x][j][y] = semilattice.op(
50                            data[i][x - 1][j][y].clone(),
51                            data[min(n - 1, i + (1 << (x - 1)))][x - 1][j][y].clone(),
52                        );
53                    }
54                }
55            }
56        }
57
58        let mut log_table = vec![0; max(n, m) + 1];
59        for i in 2..=max(n, m) {
60            log_table[i] = log_table[i >> 1] + 1;
61        }
62
63        Self {
64            semilattice,
65            data,
66            log_table,
67        }
68    }
69
70    /// **Time complexity** $O(1)$
71    pub fn fold_2d(
72        &self,
73        Range { start: r1, end: r2 }: Range<usize>,
74        Range { start: c1, end: c2 }: Range<usize>,
75    ) -> Option<A::Element> {
76        if r1 == r2 || c1 == c2 {
77            return None;
78        }
79        let kr = self.log_table[r2 - r1];
80        let kc = self.log_table[c2 - c1];
81
82        let x = self.semilattice.op(
83            self.data[r1][kr][c1][kc].clone(),
84            self.data[r1][kr][c2 - (1 << kc)][kc].clone(),
85        );
86        let y = self.semilattice.op(
87            self.data[r2 - (1 << kr)][kr][c1][kc].clone(),
88            self.data[r2 - (1 << kr)][kr][c2 - (1 << kc)][kc].clone(),
89        );
90
91        Some(self.semilattice.op(x, y))
92    }
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{algebra::min_max::Max, iter::collect::CollectVec};
    use rand::Rng;
    use std::fmt::Debug;

    /// Compares `fold_2d` against a direct fold for every sub-rectangle of `s`,
    /// including the empty ones (for which the identity element is expected).
    fn test<A>(a: A, s: Vec<Vec<A::Element>>)
    where
        A: Semilattice + Identity + Clone,
        A::Element: Copy + Default + PartialEq + Debug,
    {
        let table = SparseTable2D::new(a.clone(), s.clone());
        let rows = s.len();
        let cols = s[0].len();

        for top in 0..rows {
            for bottom in top..=rows {
                for left in 0..cols {
                    for right in left..=cols {
                        // Reference answer: fold each row slice, then fold the row results.
                        let expected = s[top..bottom]
                            .iter()
                            .map(|row| row[left..right].iter().cloned().fold_m(&a))
                            .fold_m(&a);
                        let actual = table
                            .fold_2d(top..bottom, left..right)
                            .unwrap_or(a.id());

                        assert_eq!(expected, actual);
                    }
                }
            }
        }
    }

    /// Exhaustive check on a random 30 x 30 grid under `max`.
    #[test]
    fn test_max() {
        let mut rng = rand::thread_rng();
        let (n, m) = (30, 30);

        let random_row = || {
            std::iter::repeat_with(|| rng.gen::<u64>())
                .take(m)
                .collect_vec()
        };
        let s = std::iter::repeat_with(random_row).take(n).collect_vec();

        test(Max::<u64>::new(), s);
    }
}