haar_lib/ds/
segtree_linear_add.rs

//! 区間一次関数加算セグメントツリー
//! (Segment tree supporting range addition of linear functions and point queries.)
//!
//! # Problems
//!
//! - [HUPC 2020 B 三角形足し算](https://onlinejudge.u-aizu.ac.jp/challenges/sources/VPC/HUPC/3165?year=2020)
6
7use crate::math::linear::*;
8use crate::misc::range::range_bounds_to_range;
9use crate::num::one_zero::Zero;
10use crate::trait_alias;
11use std::{
12    cell::Cell,
13    mem::size_of,
14    ops::{Add, Mul, RangeBounds},
15};
16
17trait_alias!(
18    /// [`SegtreeLinearAdd<T>`]が扱える型
19    Elem: Copy + Add<Output = Self> + Mul<Output = Self> + Zero + From<u32>
20);
21
/// Segment tree that adds linear functions `a * x + b` over ranges and answers point queries.
pub struct SegtreeLinearAdd<T> {
    /// Per-node lazy tags `(value at the node's left endpoint, slope)`; leaves start at `hsize`.
    data: Vec<Cell<(T, T)>>,
    /// `from[i]` is the left endpoint (position) of the range covered by node `i`.
    from: Vec<usize>,
    /// Number of leaves (a power of two, at least `original_size`).
    hsize: usize,
    /// The size requested at construction; valid positions are `0..original_size`.
    original_size: usize,
}
29
/// Component-wise sum of two pairs: `(a, b) + (c, d) = (a + c, b + d)`.
fn add<T>(lhs: (T, T), rhs: (T, T)) -> (T, T)
where
    T: Add<Output = T>,
{
    let (lhs_first, lhs_second) = lhs;
    let (rhs_first, rhs_second) = rhs;
    (lhs_first + rhs_first, lhs_second + rhs_second)
}
33
34impl<T: Elem> SegtreeLinearAdd<T> {
35    /// **Time complexity** $O(n)$
36    pub fn new(n: usize) -> Self {
37        let size = n.next_power_of_two() * 2;
38        let hsize = size / 2;
39        let mut from = vec![0; size];
40
41        let mut s = 0;
42        for (i, x) in from.iter_mut().enumerate().skip(1) {
43            *x = s;
44            let l = hsize >> (size_of::<usize>() as u32 * 8 - 1 - i.leading_zeros());
45            s += l;
46            if s == hsize {
47                s = 0;
48            }
49        }
50
51        Self {
52            hsize,
53            original_size: n,
54            data: vec![Cell::new((T::zero(), T::zero())); size],
55            from,
56        }
57    }
58
59    /// 範囲`l..r`に一次関数`ax + b`の値を加算する。(`x`の値は`l..r`の範囲)
60    ///
61    /// **Time complexity** $O(\log n)$
62    pub fn update(&mut self, range: impl RangeBounds<usize>, linear: Linear<T>) {
63        let (l, r) = range_bounds_to_range(range, 0, self.original_size);
64
65        let mut l = l + self.hsize;
66        let mut r = r + self.hsize;
67
68        while l < r {
69            if r & 1 == 1 {
70                r -= 1;
71                self.data[r].set(add(
72                    self.data[r].get(),
73                    (linear.apply(T::from(self.from[r] as u32)), linear.a),
74                ));
75            }
76            if l & 1 == 1 {
77                self.data[l].set(add(
78                    self.data[l].get(),
79                    (linear.apply(T::from(self.from[l] as u32)), linear.a),
80                ));
81                l += 1;
82            }
83
84            l >>= 1;
85            r >>= 1;
86        }
87    }
88
89    fn propagate(&self, i: usize) {
90        if i < self.hsize {
91            self.data[i << 1].set(add(self.data[i << 1].get(), self.data[i].get()));
92
93            let len = self.hsize >> (size_of::<usize>() as u32 * 8 - i.leading_zeros());
94            self.data[i].set((
95                self.data[i].get().0 + self.data[i].get().1 * T::from(len as u32),
96                self.data[i].get().1,
97            ));
98            self.data[(i << 1) | 1].set(add(self.data[(i << 1) | 1].get(), self.data[i].get()));
99
100            self.data[i].set((T::zero(), T::zero()));
101        }
102    }
103
104    fn propagate_top_down(&self, mut i: usize) {
105        let mut temp = vec![];
106        while i > 1 {
107            i >>= 1;
108            temp.push(i);
109        }
110
111        for i in temp.into_iter().rev() {
112            self.propagate(i);
113        }
114    }
115
116    /// **Time complexity** $O(\log n)$
117    pub fn get(&self, i: usize) -> T {
118        self.propagate_top_down(i + self.hsize);
119        self.data[i + self.hsize].get().0
120    }
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use my_testtools::*;
    use rand::Rng;

    /// Randomised check against a brute-force array: after every range update, all point
    /// queries must agree with the naive result.
    #[test]
    fn test() {
        const N: usize = 100;
        const ROUNDS: usize = 300;

        let mut rng = rand::thread_rng();
        let mut seg = SegtreeLinearAdd::<u64>::new(N);
        let mut naive: Vec<u64> = vec![0; N];

        for _ in 0..ROUNDS {
            let range = rand_range(&mut rng, 0..N);
            let a: u64 = rng.gen_range(0..100);
            let b: u64 = rng.gen_range(0..100);

            range
                .clone()
                .into_iter()
                .for_each(|x| naive[x] += a * x as u64 + b);
            seg.update(range, Linear { a, b });

            let actual: Vec<u64> = (0..N).map(|i| seg.get(i)).collect();
            assert_eq!(actual, naive);
        }
    }
}
152}