haar_lib/ds/
lazy_segtree.rs

1//! モノイド列の区間更新・区間取得($O(\log n)$, $O(\log n)$)ができる。
2use crate::algebra::action::Action;
3use crate::misc::range::range_bounds_to_range;
4use std::ops::RangeBounds;
5
6/// モノイド列の区間更新・区間取得($O(\log n)$, $O(\log n)$)ができる。
7pub struct LazySegtree<A: Action> {
8    size: usize,
9    original_size: usize,
10    data: Vec<A::Output>,
11    lazy: Vec<A::Lazy>,
12}
13
14impl<A: Action> LazySegtree<A>
15where
16    A::Output: Clone + PartialEq,
17    A::Lazy: Clone + PartialEq,
18{
19    /// 長さ`n`の[`LazySegtree`]を生成する。
20    pub fn new(n: usize) -> Self {
21        let size = n.next_power_of_two() * 2;
22        Self {
23            size,
24            original_size: n,
25            data: vec![A::fold_id(); size],
26            lazy: vec![A::update_id(); size],
27        }
28    }
29
30    /// [`Vec`]から[`LazySegtree`]を構築する。
31    ///
32    /// **Time complexity** $O(|s|)$
33    pub fn from_vec(s: Vec<A::Output>) -> Self {
34        let n = s.len();
35        let size = n.next_power_of_two() * 2;
36        let mut this = Self {
37            size,
38            original_size: n,
39            data: vec![A::fold_id(); size],
40            lazy: vec![A::update_id(); size],
41        };
42
43        for (i, x) in s.into_iter().enumerate() {
44            this.data[size / 2 + i] = x;
45        }
46
47        for i in (1..size / 2).rev() {
48            this.data[i] = A::fold(this.data[i << 1].clone(), this.data[(i << 1) | 1].clone());
49        }
50
51        this
52    }
53
54    /// 遅延操作を完了させたモノイド列をスライスで返す。
55    ///
56    /// **Time complexity** $O(n)$
57    pub fn to_slice(&mut self) -> &[A::Output] {
58        for i in 1..self.size {
59            self.propagate(i);
60        }
61
62        &self.data[self.size / 2..self.size / 2 + self.original_size]
63    }
64
65    fn propagate(&mut self, i: usize) {
66        if self.lazy[i] == A::update_id() {
67            return;
68        }
69        if i < self.size / 2 {
70            let l = i << 1;
71            let r = (i << 1) | 1;
72
73            self.lazy[l] = A::update(self.lazy[l].clone(), self.lazy[i].clone());
74            self.lazy[r] = A::update(self.lazy[r].clone(), self.lazy[i].clone());
75        }
76        let len = (self.size / 2) >> (31 - (i as u32).leading_zeros());
77        self.data[i] = A::convert(self.data[i].clone(), self.lazy[i].clone(), len);
78        self.lazy[i] = A::update_id();
79    }
80
81    fn propagate_top_down(&mut self, mut i: usize) {
82        let mut temp = vec![i];
83        while i > 1 {
84            i >>= 1;
85            temp.push(i);
86        }
87
88        for i in temp.into_iter().rev() {
89            self.propagate(i);
90        }
91    }
92
93    fn bottom_up(&mut self, mut i: usize) {
94        while i > 1 {
95            i >>= 1;
96            self.propagate(i << 1);
97            self.propagate((i << 1) | 1);
98            self.data[i] = A::fold(self.data[i << 1].clone(), self.data[(i << 1) | 1].clone());
99        }
100    }
101
102    /// `i`番目の値を返す。
103    pub fn get(&mut self, i: usize) -> A::Output {
104        self.propagate_top_down(i + self.size / 2);
105        self.data[i + self.size / 2].clone()
106    }
107
108    /// 区間`range`で計算を集約して返す。
109    pub fn fold(&mut self, range: impl RangeBounds<usize>) -> A::Output {
110        let (l, r) = range_bounds_to_range(range, 0, self.original_size);
111
112        self.propagate_top_down(l + self.size / 2);
113        if r < self.size / 2 {
114            self.propagate_top_down(r + self.size / 2);
115        }
116
117        let mut ret_l = A::fold_id();
118        let mut ret_r = A::fold_id();
119
120        let mut l = l + self.size / 2;
121        let mut r = r + self.size / 2;
122
123        while l < r {
124            if r & 1 == 1 {
125                r -= 1;
126                self.propagate(r);
127                ret_r = A::fold(self.data[r].clone(), ret_r.clone());
128            }
129            if l & 1 == 1 {
130                self.propagate(l);
131                ret_l = A::fold(ret_l.clone(), self.data[l].clone());
132                l += 1;
133            }
134            r >>= 1;
135            l >>= 1;
136        }
137
138        A::fold(ret_l, ret_r)
139    }
140
141    /// `i`番目の値を`value`で置き換える。
142    pub fn assign(&mut self, i: usize, value: A::Output) {
143        self.propagate_top_down(i + self.size / 2);
144        self.data[i + self.size / 2] = value;
145        self.bottom_up(i + self.size / 2);
146    }
147
148    /// 区間`range`を値`x`で更新する。
149    pub fn update(&mut self, range: impl RangeBounds<usize>, x: A::Lazy) {
150        let (l, r) = range_bounds_to_range(range, 0, self.original_size);
151
152        self.propagate_top_down(l + self.size / 2);
153        if r < self.size / 2 {
154            self.propagate_top_down(r + self.size / 2);
155        }
156
157        {
158            let mut l = l + self.size / 2;
159            let mut r = r + self.size / 2;
160
161            while l < r {
162                if r & 1 == 1 {
163                    r -= 1;
164                    self.lazy[r] = A::update(self.lazy[r].clone(), x.clone());
165                }
166                if l & 1 == 1 {
167                    self.lazy[l] = A::update(self.lazy[l].clone(), x.clone());
168                    l += 1;
169                }
170                r >>= 1;
171                l >>= 1;
172            }
173        }
174
175        self.bottom_up(l + self.size / 2);
176        if r < self.size / 2 {
177            self.bottom_up(r + self.size / 2);
178        }
179    }
180}
181
182#[cfg(test)]
183mod tests {
184    use super::*;
185    use crate::algebra::add_sum::*;
186    use crate::algebra::sum::*;
187    use my_testtools::*;
188    use rand::Rng;
189
190    #[test]
191    fn add_sum() {
192        let n = 100;
193        let q = 100;
194        let range = 1000;
195
196        let mut seg = LazySegtree::<AddSum<u64>>::new(n);
197        let mut vec = vec![Sum::id(); n];
198
199        let mut rng = rand::thread_rng();
200
201        for _ in 0..q {
202            let lr = rand_range(&mut rng, 0..n);
203
204            match rng.gen::<u32>() % 2 {
205                0 => {
206                    let x = rng.gen_range(0..range);
207
208                    seg.update(lr.clone(), Sum(x));
209                    vec[lr].iter_mut().for_each(|y| y.op_assign_r(Sum(x)));
210                }
211                1 => {
212                    assert_eq!(seg.fold(lr.clone()), vec[lr].iter().cloned().fold_m());
213                }
214                _ => unreachable!(),
215            }
216        }
217    }
218}