// haar_lib/ds/segtree_linear_add.rs

//! 区間一次関数加算セグメントツリー
//! (Segment tree that adds a linear function over a range of positions and answers point queries.)
//!
//! # Problems
//!
//! - [HUPC 2020 B 三角形足し算](https://onlinejudge.u-aizu.ac.jp/challenges/sources/VPC/HUPC/3165?year=2020)

use crate::math::linear::*;
use crate::misc::range::range_bounds_to_range;
use crate::num::one_zero::Zero;
use std::{
    cell::Cell,
    mem::size_of,
    ops::{Add, Mul, RangeBounds},
};

/// Segment tree supporting "add a linear function over a range" updates and point queries.
pub struct SegtreeLinearAdd<T> {
    /// Number of positions requested by the caller (`n` in `new`).
    original_size: usize,
    /// Number of leaves (a power of two); leaves live at indices `hsize..2 * hsize`.
    hsize: usize,
    /// Per-node `(value, slope)` tags in 1-indexed heap order, behind `Cell` so that queries
    /// taking `&self` can still push tags downward.
    data: Vec<Cell<(T, T)>>,
    /// `from[i]` is the first position covered by heap node `i` (index 0 is unused).
    from: Vec<usize>,
}

/// Component-wise sum of two pairs: `(x0 + y0, x1 + y1)`.
fn add<T: Add<Output = T>>(lhs: (T, T), rhs: (T, T)) -> (T, T) {
    let (x0, x1) = lhs;
    let (y0, y1) = rhs;
    let first = x0 + y0;
    let second = x1 + y1;
    (first, second)
}

28impl<T> SegtreeLinearAdd<T>
29where
30    T: Copy + Add<Output = T> + Mul<Output = T> + Zero + From<u32>,
31{
32    /// **Time complexity** $O(n)$
33    pub fn new(n: usize) -> Self {
34        let size = n.next_power_of_two() * 2;
35        let hsize = size / 2;
36        let mut from = vec![0; size];
37
38        let mut s = 0;
39        for (i, x) in from.iter_mut().enumerate().skip(1) {
40            *x = s;
41            let l = hsize >> (size_of::<usize>() as u32 * 8 - 1 - i.leading_zeros());
42            s += l;
43            if s == hsize {
44                s = 0;
45            }
46        }
47
48        Self {
49            hsize,
50            original_size: n,
51            data: vec![Cell::new((T::zero(), T::zero())); size],
52            from,
53        }
54    }
55
56    /// 範囲`l..r`に一次関数`ax + b`の値を加算する。(`x`の値は`l..r`の範囲)
57    ///
58    /// **Time complexity** $O(\log n)$
59    pub fn update(&mut self, range: impl RangeBounds<usize>, linear: Linear<T>) {
60        let (l, r) = range_bounds_to_range(range, 0, self.original_size);
61
62        let mut l = l + self.hsize;
63        let mut r = r + self.hsize;
64
65        while l < r {
66            if r & 1 == 1 {
67                r -= 1;
68                self.data[r].set(add(
69                    self.data[r].get(),
70                    (linear.apply(T::from(self.from[r] as u32)), linear.a),
71                ));
72            }
73            if l & 1 == 1 {
74                self.data[l].set(add(
75                    self.data[l].get(),
76                    (linear.apply(T::from(self.from[l] as u32)), linear.a),
77                ));
78                l += 1;
79            }
80
81            l >>= 1;
82            r >>= 1;
83        }
84    }
85
86    fn propagate(&self, i: usize) {
87        if i < self.hsize {
88            self.data[i << 1].set(add(self.data[i << 1].get(), self.data[i].get()));
89
90            let len = self.hsize >> (size_of::<usize>() as u32 * 8 - i.leading_zeros());
91            self.data[i].set((
92                self.data[i].get().0 + self.data[i].get().1 * T::from(len as u32),
93                self.data[i].get().1,
94            ));
95            self.data[(i << 1) | 1].set(add(self.data[(i << 1) | 1].get(), self.data[i].get()));
96
97            self.data[i].set((T::zero(), T::zero()));
98        }
99    }
100
101    fn propagate_top_down(&self, mut i: usize) {
102        let mut temp = vec![];
103        while i > 1 {
104            i >>= 1;
105            temp.push(i);
106        }
107
108        for i in temp.into_iter().rev() {
109            self.propagate(i);
110        }
111    }
112
113    /// **Time complexity** $O(\log n)$
114    pub fn get(&self, i: usize) -> T {
115        self.propagate_top_down(i + self.hsize);
116        self.data[i + self.hsize].get().0
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;
    use my_testtools::*;
    use rand::Rng;

    /// Randomized check: after every range update, every point query must agree with a naive
    /// array that adds `a * i + b` element by element.
    #[test]
    fn test() {
        let mut rng = rand::thread_rng();

        let n = 100;
        let mut seg = SegtreeLinearAdd::<u64>::new(n);
        let mut expected = vec![0; n];

        for _ in 0..300 {
            let lr = rand_range(&mut rng, 0..n);
            let (a, b) = (rng.gen_range(0..100), rng.gen_range(0..100));

            seg.update(lr.clone(), Linear { a, b });
            for i in lr {
                expected[i] += a * i as u64 + b;
            }

            let actual = (0..n).map(|i| seg.get(i)).collect::<Vec<_>>();
            assert_eq!(actual, expected);
        }
    }
}
149}