haar_lib/ds/
segtree_linear_add_range_sum.rs

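//! 区間一次関数加算区間総和セグメントツリー
//!
//! Segment tree supporting two operations: add a linear function `a * i + b` to every position `i` of a range, and query the sum of a range.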
2
3use crate::math::linear::*;
4use crate::misc::range::range_bounds_to_range;
5use crate::num::one_zero::Zero;
6use std::ops::{Add, AddAssign, Mul, RangeBounds};
7
8/// 区間一次関数加算区間総和セグメントツリー
9pub struct SegtreeLinearAddRangeSum<T> {
10    data: Vec<T>,
11    lazy: Vec<(T, T)>,
12    hsize: usize,
13    original_size: usize,
14}
15
16impl<T> SegtreeLinearAddRangeSum<T>
17where
18    T: Copy + Zero + Add<Output = T> + Mul<Output = T> + AddAssign + PartialEq + From<u32>,
19{
20    /// **Time complexity** $O(n)$
21    ///
22    /// **Space complexity** $O(n)$
23    pub fn new(n: usize) -> Self {
24        let hsize = n.next_power_of_two();
25
26        Self {
27            data: vec![T::zero(); hsize * 2],
28            lazy: vec![(T::zero(), T::zero()); hsize * 2],
29            hsize,
30            original_size: n,
31        }
32    }
33
34    fn _add(a: (T, T), b: (T, T)) -> (T, T) {
35        (a.0 + b.0, a.1 + b.1)
36    }
37
38    fn _propagate(&mut self, i: usize, l: usize, r: usize) {
39        if self.lazy[i] == (T::zero(), T::zero()) {
40            return;
41        }
42        if i < self.hsize {
43            let mut t = self.lazy[i];
44            self.lazy[i << 1] = Self::_add(t, self.lazy[i << 1]);
45            t.0 += t.1 * T::from(((r - l) / 2) as u32);
46            self.lazy[(i << 1) | 1] = Self::_add(t, self.lazy[(i << 1) | 1]);
47        }
48        let len = r - l;
49        let (s, d) = self.lazy[i];
50
51        self.data[i] += s * T::from(len as u32) + d * T::from(((len - 1) * len / 2) as u32);
52        self.lazy[i] = (T::zero(), T::zero());
53    }
54
55    fn _update(&mut self, i: usize, l: usize, r: usize, s: usize, t: usize, a: T, b: T) -> T {
56        self._propagate(i, l, r);
57        if r <= s || t <= l {
58            self.data[i]
59        } else if s <= l && r <= t {
60            self.lazy[i] = Self::_add(self.lazy[i], (a * T::from(l as u32) + b, a));
61            self._propagate(i, l, r);
62            self.data[i]
63        } else {
64            let mid = (l + r) / 2;
65            self.data[i] = self._update(i << 1, l, mid, s, t, a, b)
66                + self._update((i << 1) | 1, mid, r, s, t, a, b);
67            self.data[i]
68        }
69    }
70
71    /// **Time complexity** $O(\log n)$
72    pub fn update(&mut self, range: impl RangeBounds<usize>, linear: Linear<T>) {
73        let (start, end) = range_bounds_to_range(range, 0, self.original_size);
74        self._update(1, 0, self.hsize, start, end, linear.a, linear.b);
75    }
76
77    fn _fold(&mut self, i: usize, l: usize, r: usize, x: usize, y: usize) -> T {
78        self._propagate(i, l, r);
79        if r <= x || y <= l {
80            T::zero()
81        } else if x <= l && r <= y {
82            self.data[i]
83        } else {
84            let mid = (l + r) / 2;
85            self._fold(i << 1, l, mid, x, y) + self._fold((i << 1) | 1, mid, r, x, y)
86        }
87    }
88
89    /// **Time complexity** $O(\log n)$
90    pub fn fold(&mut self, range: impl RangeBounds<usize>) -> T {
91        let (start, end) = range_bounds_to_range(range, 0, self.original_size);
92        self._fold(1, 0, self.hsize, start, end)
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use my_testtools::*;
    use rand::Rng;
    use std::ops::Range;

    /// Randomised comparison against a plain array. Each round applies one random linear
    /// update to both structures, then compares one random range sum.
    #[test]
    fn test() {
        let mut rng = rand::thread_rng();
        let n = 100;

        let mut seg = SegtreeLinearAddRangeSum::<i64>::new(n);
        let mut naive = vec![0; n];

        for _ in 0..300 {
            let Range { start, end } = rand_range(&mut rng, 0..n);
            let a = rng.gen_range(0..100);
            let b = rng.gen_range(0..100);

            seg.update(start..end, Linear { a, b });
            naive
                .iter_mut()
                .enumerate()
                .take(end)
                .skip(start)
                .for_each(|(i, x)| *x += a * i as i64 + b);

            let query = rand_range(&mut rng, 0..n);
            let expected: i64 = naive[query.clone()].iter().sum();
            assert_eq!(seg.fold(query), expected);
        }
    }
}