haar_lib/ds/
segtree_linear_add.rs1use crate::math::linear::*;
8use crate::misc::range::range_bounds_to_range;
9use crate::num::one_zero::Zero;
10use crate::trait_alias;
11use std::{
12 cell::Cell,
13 mem::size_of,
14 ops::{Add, Mul, RangeBounds},
15};
16
17trait_alias!(
18 Elem: Copy + Add<Output = Self> + Mul<Output = Self> + Zero + From<u32>
20);
21
/// Segment tree that accumulates linear functions over index ranges and
/// answers single-point queries.
///
/// Every node stores a pending pair `(value, slope)`: `value` is the amount to
/// add at the first index covered by the node, and `slope` is the extra amount
/// added for each step to the right inside the node's segment.
pub struct SegtreeLinearAdd<T> {
    /// Number of elements requested at construction time.
    original_size: usize,
    /// Number of leaves; leaf `i` lives at node index `i + hsize`.
    hsize: usize,
    /// For each node index, the first element index covered by that node.
    from: Vec<usize>,
    /// Pending `(value, slope)` pair for each node. `Cell` lets point queries
    /// push pending pairs down through a shared reference.
    data: Vec<Cell<(T, T)>>,
}
29
/// Adds two pairs component-wise: `(a, b) + (c, d) = (a + c, b + d)`.
fn add<T: Add<Output = T>>(lhs: (T, T), rhs: (T, T)) -> (T, T) {
    (lhs.0 + rhs.0, lhs.1 + rhs.1)
}
33
34impl<T: Elem> SegtreeLinearAdd<T> {
35 pub fn new(n: usize) -> Self {
37 let size = n.next_power_of_two() * 2;
38 let hsize = size / 2;
39 let mut from = vec![0; size];
40
41 let mut s = 0;
42 for (i, x) in from.iter_mut().enumerate().skip(1) {
43 *x = s;
44 let l = hsize >> (size_of::<usize>() as u32 * 8 - 1 - i.leading_zeros());
45 s += l;
46 if s == hsize {
47 s = 0;
48 }
49 }
50
51 Self {
52 hsize,
53 original_size: n,
54 data: vec![Cell::new((T::zero(), T::zero())); size],
55 from,
56 }
57 }
58
59 pub fn update(&mut self, range: impl RangeBounds<usize>, linear: Linear<T>) {
63 let (l, r) = range_bounds_to_range(range, 0, self.original_size);
64
65 let mut l = l + self.hsize;
66 let mut r = r + self.hsize;
67
68 while l < r {
69 if r & 1 == 1 {
70 r -= 1;
71 self.data[r].set(add(
72 self.data[r].get(),
73 (linear.apply(T::from(self.from[r] as u32)), linear.a),
74 ));
75 }
76 if l & 1 == 1 {
77 self.data[l].set(add(
78 self.data[l].get(),
79 (linear.apply(T::from(self.from[l] as u32)), linear.a),
80 ));
81 l += 1;
82 }
83
84 l >>= 1;
85 r >>= 1;
86 }
87 }
88
89 fn propagate(&self, i: usize) {
90 if i < self.hsize {
91 self.data[i << 1].set(add(self.data[i << 1].get(), self.data[i].get()));
92
93 let len = self.hsize >> (size_of::<usize>() as u32 * 8 - i.leading_zeros());
94 self.data[i].set((
95 self.data[i].get().0 + self.data[i].get().1 * T::from(len as u32),
96 self.data[i].get().1,
97 ));
98 self.data[(i << 1) | 1].set(add(self.data[(i << 1) | 1].get(), self.data[i].get()));
99
100 self.data[i].set((T::zero(), T::zero()));
101 }
102 }
103
104 fn propagate_top_down(&self, mut i: usize) {
105 let mut temp = vec![];
106 while i > 1 {
107 i >>= 1;
108 temp.push(i);
109 }
110
111 for i in temp.into_iter().rev() {
112 self.propagate(i);
113 }
114 }
115
116 pub fn get(&self, i: usize) -> T {
118 self.propagate_top_down(i + self.hsize);
119 self.data[i + self.hsize].get().0
120 }
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use my_testtools::*;
    use rand::Rng;

    /// Randomized comparison against a plain vector: after every range update
    /// the tree must report the same value as the naive model at every index.
    #[test]
    fn test() {
        let mut rng = rand::thread_rng();

        let n = 100;
        let mut seg = SegtreeLinearAdd::<u64>::new(n);
        let mut expected = vec![0; n];

        for _ in 0..300 {
            let range = rand_range(&mut rng, 0..n);

            let a = rng.gen_range(0..100);
            let b = rng.gen_range(0..100);

            seg.update(range.clone(), Linear { a, b });

            // Naive model: element i inside the range grows by a * i + b.
            for i in range {
                expected[i] += a * i as u64 + b;
            }

            let actual: Vec<_> = (0..n).map(|i| seg.get(i)).collect();
            assert_eq!(actual, expected);
        }
    }
}