haar_lib/ds/
starry_sky_tree.rs

//! 区間加算・区間Max(Min): range addition with range maximum (or minimum) queries.
2
3use crate::misc::range::range_bounds_to_range;
4use crate::num::one_zero::Zero;
5use std::{
6    cmp::{max, min},
7    ops::{Add, RangeBounds, Sub},
8};
9
/// Chooses whether a [`StarrySkyTree`] answers range-maximum or range-minimum queries.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Mode {
    /// Range maximum.
    Max,
    /// Range minimum.
    Min,
}

impl Mode {
    /// Combines two values under this mode: `max(a, b)` for [`Mode::Max`],
    /// `min(a, b)` for [`Mode::Min`].
    fn op<T: Ord>(self, a: T, b: T) -> T {
        match self {
            Mode::Max => max(a, b),
            Mode::Min => min(a, b),
        }
    }
}

/// A segment tree supporting range addition together with range max (or min)
/// queries, also known as a "starry sky tree".
///
/// Node `1` is the root, node `i` has children `2i` and `2i + 1`, and the leaves
/// occupy indices `size / 2 .. size`; index `0` is unused. Nodes store offsets
/// rather than absolute values: the max (or min) over a node's subtree is the sum
/// of `data` along the path from the root down to and including that node.
#[derive(Clone, Debug)]
pub struct StarrySkyTree<T> {
    /// Length of `data`: twice the number of leaves, where the leaf count is
    /// `original_size` rounded up to a power of two.
    size: usize,
    /// Number of elements requested at construction; query and update ranges
    /// are resolved against this length.
    original_size: usize,
    /// Per-node offsets, indexed as described on the type.
    data: Vec<T>,
    /// Whether folds compute the max or the min.
    mode: Mode,
}
35
36impl<T> StarrySkyTree<T>
37where
38    T: Add<Output = T> + Sub<Output = T> + Ord + Copy + Zero,
39{
40    /// **Time complexity** $O(n)$
41    pub fn new(n: usize, mode: Mode) -> Self {
42        let size = n.next_power_of_two() * 2;
43        let zero = T::zero();
44        Self {
45            size,
46            original_size: n,
47            data: vec![zero; size],
48            mode,
49        }
50    }
51
52    /// [`Vec`]から[`StarrySkyTree`]を構築する。
53    ///
54    /// **Time complexity** $O(|a|)$
55    pub fn from_vec(a: Vec<T>, mode: Mode) -> Self {
56        let mut this = Self::new(a.len(), mode);
57
58        for (i, x) in a.into_iter().enumerate() {
59            this.data[i + this.size / 2] = x;
60        }
61
62        for i in (1..this.size / 2).rev() {
63            let d = mode.op(this.data[i << 1], this.data[(i << 1) | 1]);
64
65            this.data[i << 1] = this.data[i << 1] - d;
66            this.data[(i << 1) | 1] = this.data[(i << 1) | 1] - d;
67            this.data[i] = this.data[i] + d;
68        }
69
70        this
71    }
72
73    fn rec(&self, s: usize, t: usize, i: usize, l: usize, r: usize, value: T) -> Option<T> {
74        if r <= s || t <= l {
75            return None;
76        }
77        if s <= l && r <= t {
78            return Some(value + self.data[i]);
79        }
80
81        let a = self.rec(s, t, i << 1, l, (l + r) / 2, value + self.data[i]);
82        let b = self.rec(s, t, (i << 1) | 1, (l + r) / 2, r, value + self.data[i]);
83
84        match (a, b) {
85            (None, _) => b,
86            (_, None) => a,
87            (Some(a), Some(b)) => Some(self.mode.op(a, b)),
88        }
89    }
90
91    /// **Time complexity** $O(\log n)$
92    pub fn fold(&self, range: impl RangeBounds<usize>) -> Option<T> {
93        let (l, r) = range_bounds_to_range(range, 0, self.original_size);
94        self.rec(l, r, 1, 0, self.size / 2, T::zero())
95    }
96
97    fn bottom_up(&mut self, mut i: usize) {
98        if i > self.size {
99            return;
100        }
101
102        while i >= 1 {
103            if i < self.size / 2 {
104                let d = self.mode.op(self.data[i << 1], self.data[(i << 1) | 1]);
105
106                self.data[i << 1] = self.data[i << 1] - d;
107                self.data[(i << 1) | 1] = self.data[(i << 1) | 1] - d;
108                self.data[i] = self.data[i] + d;
109            }
110
111            i >>= 1;
112        }
113    }
114
115    /// **Time complexity** $O(\log n)$
116    pub fn update(&mut self, range: impl RangeBounds<usize>, value: T) {
117        let (l, r) = range_bounds_to_range(range, 0, self.original_size);
118
119        let hsize = self.size / 2;
120        let mut ll = l + hsize;
121        let mut rr = r + hsize;
122
123        while ll < rr {
124            if (rr & 1) != 0 {
125                rr -= 1;
126                self.data[rr] = self.data[rr] + value;
127            }
128            if (ll & 1) != 0 {
129                self.data[ll] = self.data[ll] + value;
130                ll += 1;
131            }
132            ll >>= 1;
133            rr >>= 1;
134        }
135
136        self.bottom_up(l + hsize);
137        self.bottom_up(r + hsize);
138    }
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use my_testtools::*;
    use rand::Rng;

    /// Sizes under test: a power of two (exercises updates whose right end is the
    /// last leaf) and a non-power of two (exercises zero padding).
    const SIZES: [usize; 2] = [64, 100];

    /// Applies 1000 random range-add / range-fold operations to `tree` and to
    /// `naive`, a plain vector holding the same elements.
    ///
    /// Every fold is checked against `expected` applied to the matching elements
    /// of `naive`; `expected` must return `None` for an empty input, as `fold` does.
    fn check_random_ops(
        mut tree: StarrySkyTree<i32>,
        mut naive: Vec<i32>,
        expected: impl Fn(Vec<i32>) -> Option<i32>,
    ) {
        let mut rng = rand::thread_rng();
        let size = naive.len();

        for _ in 0..1000 {
            let lr = rand_range(&mut rng, 0..size);

            if rng.gen_range(0..2) == 0 {
                let x = rng.gen_range(-1000..=1000);

                tree.update(lr.clone(), x);
                for i in lr {
                    naive[i] += x;
                }
            } else {
                let ans = expected(lr.clone().map(|i| naive[i]).collect());

                assert_eq!(tree.fold(lr), ans);
            }
        }
    }

    /// Random initial contents in `-1000..=1000`, used to exercise `from_vec`.
    fn random_vec(size: usize) -> Vec<i32> {
        let mut rng = rand::thread_rng();
        (0..size).map(|_| rng.gen_range(-1000..=1000)).collect()
    }

    #[test]
    fn test_max() {
        for &size in SIZES.iter() {
            check_random_ops(
                StarrySkyTree::new(size, Mode::Max),
                vec![0; size],
                |v| v.into_iter().max(),
            );
        }
    }

    #[test]
    fn test_min() {
        for &size in SIZES.iter() {
            check_random_ops(
                StarrySkyTree::new(size, Mode::Min),
                vec![0; size],
                |v| v.into_iter().min(),
            );
        }
    }

    #[test]
    fn test_max_from_vec() {
        for &size in SIZES.iter() {
            let a = random_vec(size);
            check_random_ops(
                StarrySkyTree::from_vec(a.clone(), Mode::Max),
                a,
                |v| v.into_iter().max(),
            );
        }
    }

    #[test]
    fn test_min_from_vec() {
        for &size in SIZES.iter() {
            let a = random_vec(size);
            check_random_ops(
                StarrySkyTree::from_vec(a.clone(), Mode::Min),
                a,
                |v| v.into_iter().min(),
            );
        }
    }
}