haar_lib/ds/
segtree.rs

1//! モノイド列の点更新・区間取得($O(\log n)$, $O(\log n)$)ができる。
2pub use crate::algebra::traits::Monoid;
3use crate::misc::range::range_bounds_to_range;
4use std::ops::{Index, RangeBounds};
5
6/// モノイド列の点更新・区間取得($O(\log n)$, $O(\log n)$)ができる。
7#[derive(Clone)]
8pub struct Segtree<M: Monoid> {
9    monoid: M,
10    original_size: usize,
11    size: usize,
12    data: Vec<M::Element>,
13}
14
15impl<M: Monoid> Segtree<M>
16where
17    M::Element: Clone,
18{
19    /// **Time complexity** $O(n)$
20    pub fn new(monoid: M, n: usize) -> Self {
21        let size = n.next_power_of_two() * 2;
22        Self {
23            original_size: n,
24            size,
25            data: vec![monoid.id(); size],
26            monoid,
27        }
28    }
29
30    /// モノイド列から`Segtree`を構築する。
31    ///
32    /// **Time complexity** $O(|s|)$
33    pub fn from_vec(monoid: M, s: Vec<M::Element>) -> Self {
34        let mut this = Self::new(monoid, s.len());
35
36        for (i, x) in s.iter().enumerate() {
37            this.data[i + this.size / 2] = x.clone();
38        }
39
40        for i in (1..this.size / 2).rev() {
41            this.data[i] = this
42                .monoid
43                .op(this.data[i << 1].clone(), this.data[(i << 1) | 1].clone());
44        }
45
46        this
47    }
48
49    /// モノイド列をスライスで返す。
50    pub fn to_slice(&self) -> &[M::Element] {
51        &self.data[self.size / 2..self.size / 2 + self.original_size]
52    }
53
54    /// **Time complexity** $O(\log n)$
55    pub fn fold<R: RangeBounds<usize>>(&self, range: R) -> M::Element {
56        let (l, r) = range_bounds_to_range(range, 0, self.size / 2);
57
58        let mut ret_l = self.monoid.id();
59        let mut ret_r = self.monoid.id();
60
61        let mut l = l + self.size / 2;
62        let mut r = r + self.size / 2;
63
64        while l < r {
65            if r & 1 == 1 {
66                r -= 1;
67                ret_r = self.monoid.op(self.data[r].clone(), ret_r);
68            }
69            if l & 1 == 1 {
70                ret_l = self.monoid.op(ret_l, self.data[l].clone());
71                l += 1;
72            }
73            r >>= 1;
74            l >>= 1;
75        }
76
77        self.monoid.op(ret_l, ret_r)
78    }
79
80    /// **Time complexity** $O(\log n)$
81    pub fn assign(&mut self, i: usize, value: M::Element) {
82        let mut i = i + self.size / 2;
83        self.data[i] = value;
84
85        while i > 1 {
86            i >>= 1;
87            self.data[i] = self
88                .monoid
89                .op(self.data[i << 1].clone(), self.data[(i << 1) | 1].clone());
90        }
91    }
92
93    /// **Time complexity** $O(\log n)$
94    pub fn update(&mut self, i: usize, value: M::Element) {
95        self.assign(
96            i,
97            self.monoid.op(self.data[i + self.size / 2].clone(), value),
98        );
99    }
100}
101
102impl<M: Monoid> From<&Segtree<M>> for Vec<M::Element>
103where
104    M::Element: Clone,
105{
106    fn from(from: &Segtree<M>) -> Self {
107        from.data[from.size / 2..from.size / 2 + from.original_size].to_vec()
108    }
109}
110
111impl<M: Monoid> Index<usize> for Segtree<M> {
112    type Output = M::Element;
113
114    fn index(&self, i: usize) -> &Self::Output {
115        &self.data[self.size / 2 + i]
116    }
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::bit::BitXor;
    use crate::algebra::matrix::ProdMatrix;
    use crate::algebra::min_max::{Max, Min};
    use crate::algebra::semiring::add_mul_mod::AddMulMod;
    use crate::algebra::sum::Sum;
    use crate::algebra::traits::*;
    use crate::linalg::matrix::MatrixOnSemiring;
    use crate::num::{ff::*, modint::ModIntBuilder};

    use my_testtools::*;
    use rand::Rng;

    /// Randomised comparison of `Segtree` against a plain `Vec` kept in lockstep.
    ///
    /// Runs 1000 rounds. Each round either applies a point `update` to both
    /// structures or compares a range `fold` with a fold over the matching slice
    /// of the `Vec`; afterwards one random index is compared. At the end the whole
    /// sequence is compared.
    fn random_test_helper<M, F>(monoid: M, size: usize, mut gen_value: F)
    where
        M: Monoid + Clone,
        M::Element: Clone + Eq + std::fmt::Debug,
        F: FnMut() -> M::Element,
    {
        let mut rng = rand::thread_rng();

        let mut naive = vec![monoid.id(); size];
        let mut seg = Segtree::new(monoid.clone(), size);

        for _ in 0..1000 {
            match rng.gen_range(0..2) {
                0 => {
                    // Point update applied to both the reference vector and the tree.
                    let pos = rng.gen_range(0..size);
                    let delta = gen_value();

                    naive[pos] = monoid.op(naive[pos].clone(), delta.clone());
                    seg.update(pos, delta);
                }
                _ => {
                    // Range fold checked against a fold over the reference slice.
                    let range = rand_range(&mut rng, 0..size);

                    let expected = naive[range.clone()].iter().cloned().fold_m(&monoid);

                    assert_eq!(seg.fold(range), expected);
                }
            }

            let probe = rng.gen_range(0..size);
            assert_eq!(seg[probe], naive[probe]);
        }

        assert_eq!(Vec::from(&seg), naive);
    }

    #[test]
    fn test_sum() {
        let mut rng = rand::thread_rng();
        let gen = || rng.gen::<i32>() % 10000;
        random_test_helper(Sum::<i32>::new(), 10, gen);
    }

    #[test]
    fn test_xor() {
        let mut rng = rand::thread_rng();
        let gen = || rng.gen::<u32>() % 10000;
        random_test_helper(BitXor::<u32>::new(), 10, gen);
    }

    #[test]
    fn test_min() {
        let mut rng = rand::thread_rng();
        let gen = || rng.gen::<i32>() % 10000;
        random_test_helper(Min::<i32>::new(), 10, gen);
    }

    #[test]
    fn test_max() {
        let mut rng = rand::thread_rng();
        let gen = || rng.gen::<i32>() % 10000;
        random_test_helper(Max::<i32>::new(), 10, gen);
    }

    #[test]
    fn test_matrix_prod() {
        let mut rng = rand::thread_rng();

        let n = 10;

        let modulo = ModIntBuilder::new(10_u32.pow(9) + 7);
        let ring = AddMulMod(modulo);
        let monoid = ProdMatrix::new(ring, n);

        // Each generated value is an `n` x `n` matrix filled entry by entry, row-major.
        let random_matrix = || {
            let mut mat = MatrixOnSemiring::zero(ring, n, n);
            for row in 0..n {
                for col in 0..n {
                    *mat.get_mut(row, col).unwrap() = modulo.from_u64(rng.gen());
                }
            }
            mat
        };

        random_test_helper(monoid, 100, random_matrix);
    }
}