haar_lib/ds/
segtree_2d.rs

//! 二次元のセグメント木 — two-dimensional segment tree.
2use crate::algebra::traits::*;
3use std::ops::Range;
4
5/// 二次元のセグメント木
6pub struct Segtree2D<M: Monoid + Commutative> {
7    data: Vec<Vec<M>>,
8    w: usize,
9    h: usize,
10}
11
12impl<M: Monoid + Commutative + Clone> Segtree2D<M> {
13    /// **Time complexity** $O(wh)$
14    ///
15    /// **Space complexity** $O(wh)$
16    pub fn new(w: usize, h: usize) -> Self {
17        let w = w.next_power_of_two() * 2;
18        let h = h.next_power_of_two() * 2;
19        let data = vec![vec![M::id(); h]; w];
20        Self { data, w, h }
21    }
22
23    fn __fold(&self, l: usize, r: usize, x: usize) -> M {
24        let mut l = l + self.h / 2;
25        let mut r = r + self.h / 2;
26
27        let mut ret = M::id();
28        let a = &self.data[x];
29
30        while l < r {
31            if r & 1 == 1 {
32                r -= 1;
33                ret = M::op(ret, a[r].clone());
34            }
35            if l & 1 == 1 {
36                ret = M::op(ret, a[l].clone());
37                l += 1;
38            }
39            l >>= 1;
40            r >>= 1;
41        }
42
43        ret
44    }
45
46    /// **Time complexity** $O(\log w \log h)$
47    pub fn fold_2d(
48        &self,
49        Range { start: x1, end: x2 }: Range<usize>,
50        Range { start: y1, end: y2 }: Range<usize>,
51    ) -> M {
52        let mut l = x1 + self.w / 2;
53        let mut r = x2 + self.w / 2;
54
55        let mut ret = M::id();
56
57        while l < r {
58            if r & 1 == 1 {
59                r -= 1;
60                ret = M::op(ret, self.__fold(y1, y2, r));
61            }
62            if l & 1 == 1 {
63                ret = M::op(ret, self.__fold(y1, y2, l));
64                l += 1;
65            }
66            l >>= 1;
67            r >>= 1;
68        }
69
70        ret
71    }
72
73    /// **Time complexity** $O(1)$
74    pub fn get(&self, i: usize, j: usize) -> M {
75        self.data[i + self.w / 2][j + self.h / 2].clone()
76    }
77
78    /// **Time complexity** $O(\log w \log h)$
79    pub fn assign(&mut self, i: usize, j: usize, value: M) {
80        let i = i + self.w / 2;
81        let j = j + self.h / 2;
82
83        self.data[i][j] = value;
84
85        let mut x = i >> 1;
86        while x > 0 {
87            self.data[x][j] = M::op(
88                self.data[x << 1][j].clone(),
89                self.data[(x << 1) | 1][j].clone(),
90            );
91            x >>= 1;
92        }
93
94        let mut y = j >> 1;
95        while y > 0 {
96            let mut x = i;
97            while x > 0 {
98                self.data[x][y] = M::op(
99                    self.data[x][y << 1].clone(),
100                    self.data[x][(y << 1) | 1].clone(),
101                );
102                x >>= 1;
103            }
104            y >>= 1;
105        }
106    }
107
108    /// **Time complexity** $O(\log w \log h)$
109    pub fn update(&mut self, i: usize, j: usize, value: M) {
110        let value = M::op(value, self.get(i, j));
111        self.assign(i, j, value);
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::sum::*;
    use my_testtools::*;
    use rand::Rng;

    /// Randomized cross-check of `Segtree2D` against a brute-force grid.
    ///
    /// Fills a `w` x `h` grid with random values and reads every cell back.
    /// It then repeatedly applies either `assign` or `update` to a random cell,
    /// checking `get` on that cell and a random `fold_2d` query against the
    /// naive answer.
    #[test]
    fn test() {
        // Index loops are intentional: the same `(i, j)` addresses both `a` and `seg`.
        #![allow(clippy::needless_range_loop)]
        let w = 300;
        let h = 100;

        let mut rng = rand::thread_rng();

        let mut seg = Segtree2D::<Sum<u64>>::new(w, h);
        let mut a = vec![vec![Sum::id(); h]; w];

        for i in 0..w {
            for j in 0..h {
                let x = rng.gen::<u64>() % 10000;

                a[i][j] = Sum(x);
                seg.assign(i, j, Sum(x));
            }
        }

        // Every cell must read back exactly as written.
        for i in 0..w {
            for j in 0..h {
                assert_eq!(seg.get(i, j), a[i][j]);
            }
        }

        for _ in 0..100 {
            let i = rng.gen::<usize>() % w;
            let j = rng.gen::<usize>() % h;
            let x = rng.gen::<u64>() % 10000;

            // Exercise both point writes: overwrite (`assign`) and combine-with-old (`update`).
            if rng.gen::<bool>() {
                seg.assign(i, j, Sum(x));
                a[i][j] = Sum(x);
            } else {
                seg.update(i, j, Sum(x));
                a[i][j] = Sum(a[i][j].0 + x);
            }
            assert_eq!(seg.get(i, j), a[i][j]);

            let wr = rand_range(&mut rng, 0..w);
            let hr = rand_range(&mut rng, 0..h);

            let res = seg.fold_2d(wr.clone(), hr.clone());

            let ans = a[wr]
                .iter()
                .map(|a| a[hr.clone()].iter().cloned().fold_m())
                .fold_m();

            assert_eq!(res, ans);
        }
    }
}