// haar_lib/ds/lazy_segtree_coeff.rs

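//! 係数乗算付き区間加算区間総和遅延セグ木
//!
//! Lazy segment tree supporting range additions weighted by per-position
//! coefficients (`a[i] += value * coeff[i]`) and range-sum queries.
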
use std::cell::Cell;
use std::ops::{Add, Mul, RangeBounds};

use crate::misc::range::range_bounds_to_range;
use crate::num::one_zero::Zero;

/// Lazy segment tree for "add `value * coeff[i]` over a range" updates and
/// range-sum queries.
///
/// `T` is the value / sum type; `U` is the per-position coefficient type
/// (defaults to `T`).
pub struct LazySegtreeCoeff<T, U = T> {
    /// Node count of the implicit binary tree. Node `1` is the root and leaf `i`
    /// lives at index `size / 2 + i`.
    size: usize,
    /// The logical length `n` the tree was built for; user ranges are resolved
    /// against `0..original_size`.
    original_size: usize,
    /// Per-node coefficient sums: a leaf holds its position's coefficient, an
    /// internal node the sum over its two children.
    coeff: Vec<U>,
    /// Per-node subtree sums. Wrapped in `Cell` so that `fold(&self)` can push
    /// pending additions down without a mutable borrow.
    data: Vec<Cell<T>>,
    /// Per-node pending additions that still have to be multiplied by `coeff`
    /// and applied to `data` (and handed to the children).
    lazy: Vec<Cell<T>>,
}

17impl<T, U> LazySegtreeCoeff<T, U>
18where
19    T: Copy + Zero + Add<Output = T> + Mul<U, Output = T> + PartialEq,
20    U: Copy + Default + Add<Output = U>,
21{
22    /// ‍係数`coefficients`を設定した[`LazySegtreeCoeff`]を構築する。
23    pub fn new(n: usize, coefficients: Vec<U>) -> Self {
24        let size = n.next_power_of_two() * 2;
25
26        let mut coeff = vec![U::default(); size];
27
28        for i in 0..coefficients.len() {
29            coeff[i + size / 2] = coefficients[i];
30        }
31        for i in (1..size / 2).rev() {
32            coeff[i] = coeff[i << 1] + coeff[(i << 1) | 1];
33        }
34
35        Self {
36            size,
37            original_size: n,
38            data: vec![Cell::new(T::zero()); size],
39            lazy: vec![Cell::new(T::zero()); size],
40            coeff,
41        }
42    }
43
44    /// `self.fold(i..i+1) = value[i]`となるように割り当てる。
45    pub fn set_vec(&mut self, value: Vec<T>) {
46        self.data = vec![Cell::new(T::zero()); self.size];
47        self.lazy = vec![Cell::new(T::zero()); self.size];
48
49        for (i, x) in value.into_iter().enumerate() {
50            self.data[self.size / 2 + i].set(x);
51        }
52        for i in (1..self.size / 2).rev() {
53            self.data[i].set(self.data[i << 1].get() + self.data[(i << 1) | 1].get());
54        }
55    }
56
57    fn propagate(&self, i: usize) {
58        if self.lazy[i].get() != T::zero() {
59            if i < self.size / 2 {
60                self.lazy[i << 1].set(self.lazy[i].get() + self.lazy[i << 1].get());
61                self.lazy[(i << 1) | 1].set(self.lazy[i].get() + self.lazy[(i << 1) | 1].get());
62            }
63            self.data[i].set(self.data[i].get() + self.lazy[i].get() * self.coeff[i]);
64            self.lazy[i].set(T::zero());
65        }
66    }
67
68    fn _update(&mut self, i: usize, l: usize, r: usize, s: usize, t: usize, value: T) -> T {
69        self.propagate(i);
70        if r <= s || t <= l {
71            return self.data[i].get();
72        }
73        if s <= l && r <= t {
74            self.lazy[i].set(self.lazy[i].get() + value);
75            self.propagate(i);
76            return self.data[i].get();
77        }
78
79        let m = (l + r) / 2;
80        let t =
81            self._update(i << 1, l, m, s, t, value) + self._update((i << 1) | 1, m, r, s, t, value);
82
83        self.data[i].set(t);
84        t
85    }
86
87    fn _fold(&self, i: usize, l: usize, r: usize, x: usize, y: usize) -> T {
88        self.propagate(i);
89        if r <= x || y <= l {
90            return T::zero();
91        }
92        if x <= l && r <= y {
93            return self.data[i].get();
94        }
95
96        let m = (l + r) / 2;
97        self._fold(i << 1, l, m, x, y) + self._fold((i << 1) | 1, m, r, x, y)
98    }
99
100    /// 範囲`range`に値`value`を加算する。
101    pub fn update(&mut self, range: impl RangeBounds<usize>, value: T) {
102        let (start, end) = range_bounds_to_range(range, 0, self.original_size);
103        self._update(1, 0, self.size / 2, start, end, value);
104    }
105
106    /// 範囲`range`で総和を取る。
107    pub fn fold(&self, range: impl RangeBounds<usize>) -> T {
108        let (start, end) = range_bounds_to_range(range, 0, self.original_size);
109        self._fold(1, 0, self.size / 2, start, end)
110    }
111}
112
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use crate::{iter::collect::CollectVec, math::prime_mod::Prime, num::const_modint::*};

    use super::*;
    use my_testtools::rand_range;
    use rand::Rng;

    /// Randomized cross-check of `update` / `fold` against a plain array model
    /// that applies `a[i] += c[i] * value` element by element.
    #[test]
    fn test() {
        const LEN: usize = 100;
        const QUERIES: usize = 1000;

        let modulo = ConstModIntBuilder::<Prime<998244353>>::new();
        let mut rng = rand::thread_rng();

        // Initial values first, then coefficients; both drawn from 0..10.
        let mut model = repeat_with(|| modulo.from_u64(rng.gen_range(0..10)))
            .take(LEN)
            .collect_vec();
        let coeffs = repeat_with(|| modulo.from_u64(rng.gen_range(0..10)))
            .take(LEN)
            .collect_vec();

        let mut seg = LazySegtreeCoeff::new(LEN, coeffs.clone());
        seg.set_vec(model.clone());

        for _ in 0..QUERIES {
            // Apply one random range update to both the tree and the model.
            let update_range = rand_range(&mut rng, 0..LEN);
            let delta = modulo.from_u64(rng.gen_range(0..10));
            seg.update(update_range.clone(), delta);
            for idx in update_range {
                model[idx] += coeffs[idx] * delta;
            }

            // Compare one random range sum against the model.
            let query_range = rand_range(&mut rng, 0..LEN);
            let expected = query_range
                .clone()
                .fold(modulo.from_u64(0), |acc, idx| acc + model[idx]);
            assert_eq!(seg.fold(query_range), expected);
        }
    }
}