haar_lib/ds/
segtree_beats.rs

1//! Segment Tree Beats
2//!
3//! # Problems
4//! - <https://judge.yosupo.jp/problem/range_chmin_chmax_add_range_sum>
5
6use crate::misc::range::range_bounds_to_range;
7use std::cmp::{max, min, Ordering};
8use std::ops::RangeBounds;
9
/// Index of the left child of node `i` in the implicit heap layout (root = 1).
#[inline]
fn lc(i: usize) -> usize {
    i << 1
}

/// Index of the right child of node `i`; it always sits right after the left child.
#[inline]
fn rc(i: usize) -> usize {
    lc(i) + 1
}
19
/// Position of the most significant set bit of `i`, i.e. `floor(log2(i))`.
///
/// In the heap layout this is the depth of node `i` (the root, node 1, has depth 0).
///
/// # Panics
/// Panics if `i == 0`.
#[inline]
fn highest_one(i: u64) -> u32 {
    assert!(i > 0);
    let width = u64::BITS;
    width - 1 - i.leading_zeros()
}
25
/// Segment Tree Beats
///
/// A data structure supporting range add, range chmin (replace each element with
/// `min(element, x)`), and range chmax (replace each element with `max(element, x)`),
/// together with range sum, range minimum, and range maximum queries.
#[derive(Clone, Debug)]
pub struct SegtreeBeats {
    /// Number of leaves (a power of two); element `k` lives at index `hsize + k`.
    hsize: usize,
    /// Logical length requested by the caller; leaves at or past it are padding.
    original_size: usize,

    /// Largest value in each node's range.
    fst_max: Vec<i64>,
    /// Strictly second-largest value in each node's range (`i64::MIN` when absent).
    snd_max: Vec<i64>,
    /// Number of elements in each node's range equal to `fst_max`.
    max_count: Vec<usize>,

    /// Smallest value in each node's range.
    fst_min: Vec<i64>,
    /// Strictly second-smallest value in each node's range (`i64::MAX` when absent).
    snd_min: Vec<i64>,
    /// Number of elements in each node's range equal to `fst_min`.
    min_count: Vec<usize>,

    /// Sum of the elements in each node's range.
    sum: Vec<i64>,
    /// Addition already applied to a node but not yet pushed down to its children.
    lazy_add: Vec<i64>,
}
45
46impl SegtreeBeats {
47    /// 長さ`n`の[`SegtreeBeats`]を生成する。
48    pub fn new(n: usize) -> Self {
49        let size = n.next_power_of_two() * 2;
50
51        Self {
52            hsize: size / 2,
53            original_size: n,
54            fst_max: vec![i64::MIN; size],
55            snd_max: vec![i64::MIN; size],
56            max_count: vec![0; size],
57            fst_min: vec![i64::MAX; size],
58            snd_min: vec![i64::MAX; size],
59            min_count: vec![0; size],
60            sum: vec![0; size],
61            lazy_add: vec![0; size],
62        }
63    }
64
65    fn update_node_max(&mut self, i: usize, x: i64) {
66        self.sum[i] += (x - self.fst_max[i]) * (self.max_count[i] as i64);
67
68        if self.fst_max[i] == self.fst_min[i] {
69            self.fst_min[i] = x;
70        } else if self.fst_max[i] == self.snd_min[i] {
71            self.snd_min[i] = x;
72        }
73
74        self.fst_max[i] = x;
75    }
76
77    fn update_node_min(&mut self, i: usize, x: i64) {
78        self.sum[i] += (x - self.fst_min[i]) * (self.min_count[i] as i64);
79
80        if self.fst_max[i] == self.fst_min[i] {
81            self.fst_max[i] = x;
82        } else if self.snd_max[i] == self.fst_min[i] {
83            self.snd_max[i] = x;
84        }
85
86        self.fst_min[i] = x;
87    }
88
89    fn update_node_add(&mut self, i: usize, x: i64) {
90        let len = self.hsize >> highest_one(i as u64);
91
92        self.sum[i] += x * len as i64;
93
94        self.fst_max[i] += x;
95        if self.snd_max[i] != i64::MIN {
96            self.snd_max[i] += x;
97        }
98
99        self.fst_min[i] += x;
100        if self.snd_min[i] != i64::MAX {
101            self.snd_min[i] += x;
102        }
103
104        self.lazy_add[i] += x;
105    }
106
107    fn propagate(&mut self, i: usize) {
108        if i >= self.hsize {
109            return;
110        }
111
112        if self.lazy_add[i] != 0 {
113            self.update_node_add(lc(i), self.lazy_add[i]);
114            self.update_node_add(rc(i), self.lazy_add[i]);
115            self.lazy_add[i] = 0;
116        }
117
118        if self.fst_max[i] < self.fst_max[lc(i)] {
119            self.update_node_max(lc(i), self.fst_max[i]);
120        }
121        if self.fst_min[i] > self.fst_min[lc(i)] {
122            self.update_node_min(lc(i), self.fst_min[i]);
123        }
124
125        if self.fst_max[i] < self.fst_max[rc(i)] {
126            self.update_node_max(rc(i), self.fst_max[i]);
127        }
128        if self.fst_min[i] > self.fst_min[rc(i)] {
129            self.update_node_min(rc(i), self.fst_min[i]);
130        }
131    }
132
133    fn bottom_up(&mut self, i: usize) {
134        let l = lc(i);
135        let r = rc(i);
136
137        self.sum[i] = self.sum[l] + self.sum[r];
138
139        self.fst_max[i] = max(self.fst_max[l], self.fst_max[r]);
140
141        match self.fst_max[l].cmp(&self.fst_max[r]) {
142            Ordering::Less => {
143                self.max_count[i] = self.max_count[r];
144                self.snd_max[i] = max(self.fst_max[l], self.snd_max[r]);
145            }
146            Ordering::Greater => {
147                self.max_count[i] = self.max_count[l];
148                self.snd_max[i] = max(self.snd_max[l], self.fst_max[r]);
149            }
150            Ordering::Equal => {
151                self.max_count[i] = self.max_count[l] + self.max_count[r];
152                self.snd_max[i] = max(self.snd_max[l], self.snd_max[r]);
153            }
154        }
155
156        self.fst_min[i] = min(self.fst_min[l], self.fst_min[r]);
157
158        match self.fst_min[l].cmp(&self.fst_min[r]) {
159            Ordering::Less => {
160                self.min_count[i] = self.min_count[l];
161                self.snd_min[i] = min(self.snd_min[l], self.fst_min[r]);
162            }
163            Ordering::Greater => {
164                self.min_count[i] = self.min_count[r];
165                self.snd_min[i] = min(self.fst_min[l], self.snd_min[r]);
166            }
167            Ordering::Equal => {
168                self.min_count[i] = self.min_count[l] + self.min_count[r];
169                self.snd_min[i] = min(self.snd_min[l], self.snd_min[r]);
170            }
171        }
172    }
173
174    fn chmin_(&mut self, i: usize, l: usize, r: usize, s: usize, t: usize, x: i64) {
175        if r <= s || t <= l || self.fst_max[i] <= x {
176            return;
177        }
178        if s <= l && r <= t && self.snd_max[i] < x {
179            self.update_node_max(i, x);
180            return;
181        }
182        self.propagate(i);
183        self.chmin_(lc(i), l, (l + r) / 2, s, t, x);
184        self.chmin_(rc(i), (l + r) / 2, r, s, t, x);
185        self.bottom_up(i);
186    }
187
188    /// 区間`range`を値`x`との最小値をとって更新する。
189    pub fn chmin(&mut self, range: impl RangeBounds<usize>, x: i64) {
190        let (start, end) = range_bounds_to_range(range, 0, self.original_size);
191        self.chmin_(1, 0, self.hsize, start, end, x);
192    }
193
194    fn chmax_(&mut self, i: usize, l: usize, r: usize, s: usize, t: usize, x: i64) {
195        if r <= s || t <= l || self.fst_min[i] >= x {
196            return;
197        }
198        if s <= l && r <= t && self.snd_min[i] > x {
199            self.update_node_min(i, x);
200            return;
201        }
202        self.propagate(i);
203        self.chmax_(lc(i), l, (l + r) / 2, s, t, x);
204        self.chmax_(rc(i), (l + r) / 2, r, s, t, x);
205        self.bottom_up(i);
206    }
207
208    /// 区間`range`を値`x`との最大値をとって更新する。
209    pub fn chmax(&mut self, range: impl RangeBounds<usize>, x: i64) {
210        let (start, end) = range_bounds_to_range(range, 0, self.original_size);
211        self.chmax_(1, 0, self.hsize, start, end, x);
212    }
213
214    fn add_(&mut self, i: usize, l: usize, r: usize, s: usize, t: usize, x: i64) {
215        if r <= s || t <= l {
216            return;
217        }
218        if s <= l && r <= t {
219            self.update_node_add(i, x);
220            return;
221        }
222        self.propagate(i);
223        self.add_(lc(i), l, (l + r) / 2, s, t, x);
224        self.add_(rc(i), (l + r) / 2, r, s, t, x);
225        self.bottom_up(i);
226    }
227
228    /// 区間`range`に値`x`を加算する。
229    pub fn add(&mut self, range: impl RangeBounds<usize>, x: i64) {
230        let (start, end) = range_bounds_to_range(range, 0, self.original_size);
231        self.add_(1, 0, self.hsize, start, end, x);
232    }
233
234    fn get_sum_(&mut self, i: usize, l: usize, r: usize, s: usize, t: usize) -> i64 {
235        if r <= s || t <= l {
236            return 0;
237        }
238        if s <= l && r <= t {
239            return self.sum[i];
240        }
241
242        self.propagate(i);
243        self.get_sum_(lc(i), l, (l + r) / 2, s, t) + self.get_sum_(rc(i), (l + r) / 2, r, s, t)
244    }
245
246    /// 区間`range`の総和を返す。
247    pub fn sum(&mut self, range: impl RangeBounds<usize>) -> i64 {
248        let (start, end) = range_bounds_to_range(range, 0, self.original_size);
249        self.get_sum_(1, 0, self.hsize, start, end)
250    }
251
252    fn get_max_(&mut self, i: usize, l: usize, r: usize, s: usize, t: usize) -> i64 {
253        if r <= s || t <= l {
254            return i64::MIN;
255        }
256        if s <= l && r <= t {
257            return self.fst_max[i];
258        }
259        self.propagate(i);
260        max(
261            self.get_max_(lc(i), l, (l + r) / 2, s, t),
262            self.get_max_(rc(i), (l + r) / 2, r, s, t),
263        )
264    }
265
266    /// 区間`range`の最大値を返す。
267    pub fn max(&mut self, range: impl RangeBounds<usize>) -> i64 {
268        let (start, end) = range_bounds_to_range(range, 0, self.original_size);
269        self.get_max_(1, 0, self.hsize, start, end)
270    }
271
272    fn get_min_(&mut self, i: usize, l: usize, r: usize, s: usize, t: usize) -> i64 {
273        if r <= s || t <= l {
274            return i64::MAX;
275        }
276        if s <= l && r <= t {
277            return self.fst_min[i];
278        }
279        self.propagate(i);
280        min(
281            self.get_min_(lc(i), l, (l + r) / 2, s, t),
282            self.get_min_(rc(i), (l + r) / 2, r, s, t),
283        )
284    }
285
286    /// 区間`range`の最小値を返す。
287    pub fn min(&mut self, range: impl RangeBounds<usize>) -> i64 {
288        let (start, end) = range_bounds_to_range(range, 0, self.original_size);
289        self.get_min_(1, 0, self.hsize, start, end)
290    }
291}
292
293impl From<Vec<i64>> for SegtreeBeats {
294    fn from(value: Vec<i64>) -> Self {
295        let mut ret = Self::new(value.len());
296        let hsize = ret.hsize;
297
298        for (i, x) in value.into_iter().enumerate() {
299            ret.fst_max[hsize + i] = x;
300            ret.max_count[hsize + i] = 1;
301            ret.fst_min[hsize + i] = x;
302            ret.min_count[hsize + i] = 1;
303            ret.sum[hsize + i] = x;
304        }
305
306        for i in (1..hsize).rev() {
307            ret.bottom_up(i);
308        }
309
310        ret
311    }
312}
313
#[cfg(test)]
mod test {
    use super::*;
    use my_testtools::*;
    use rand::Rng;

    /// Randomized differential test. Every operation is mirrored on a plain
    /// `Vec<i64>`, and every query result is compared against that naive model.
    #[test]
    fn test() {
        const N: usize = 1000;
        const LIMIT: i64 = 1_000_000_000;
        const ROUNDS: usize = 10_000;

        let mut rng = rand::thread_rng();

        let mut naive = vec![0_i64; N];
        let mut seg = SegtreeBeats::from(naive.clone());

        for _ in 0..ROUNDS {
            // Every round draws the operation first and then the range; updates
            // additionally draw a value afterwards.
            let op = rng.gen_range(0..=5);
            let lr = rand_range(&mut rng, 0..N);

            match op {
                0 => {
                    let x = rng.gen_range(-LIMIT..=LIMIT);
                    seg.chmax(lr.clone(), x);
                    for y in &mut naive[lr] {
                        *y = (*y).max(x);
                    }
                }
                1 => {
                    let x = rng.gen_range(-LIMIT..=LIMIT);
                    seg.chmin(lr.clone(), x);
                    for y in &mut naive[lr] {
                        *y = (*y).min(x);
                    }
                }
                2 => {
                    let x = rng.gen_range(-LIMIT..=LIMIT);
                    seg.add(lr.clone(), x);
                    for y in &mut naive[lr] {
                        *y += x;
                    }
                }
                3 => {
                    let got = seg.sum(lr.clone());
                    let want: i64 = naive[lr].iter().sum();
                    assert_eq!(got, want);
                }
                4 => {
                    let got = seg.max(lr.clone());
                    let want = naive[lr].iter().copied().max().unwrap_or(i64::MIN);
                    assert_eq!(got, want);
                }
                5 => {
                    let got = seg.min(lr.clone());
                    let want = naive[lr].iter().copied().min().unwrap_or(i64::MAX);
                    assert_eq!(got, want);
                }

                _ => unreachable!(),
            }
        }
    }
}