haar_lib/ds/
disjoint_sparse_table.rs

//! Disjoint sparse table: folds any contiguous range of a fixed sequence of semigroup elements in $O(1)$ per query.
2
3pub use crate::algebra::traits::Semigroup;
4use crate::misc::range::range_bounds_to_range;
5use std::{iter::repeat_n, ops::RangeBounds};
6
7/// 半群の列の区間取得($O(1)$)ができる。
8pub struct DisjointSparseTable<S: Semigroup> {
9    semigroup: S,
10    data: Vec<Vec<Option<S::Element>>>,
11    seq: Vec<Option<S::Element>>,
12    size: usize,
13}
14
15impl<S: Semigroup> DisjointSparseTable<S>
16where
17    S::Element: Clone,
18{
19    /// 列`seq`から`DisjointSparseTable<S>`を構築する。
20    pub fn new(semigroup: S, seq: Vec<S::Element>) -> Self {
21        assert!(!seq.is_empty());
22
23        let size = seq.len();
24        let log_size = usize::BITS as usize - (size - 1).leading_zeros() as usize;
25        let mut data = vec![vec![None; 1 << log_size]; log_size];
26
27        let seq = seq
28            .into_iter()
29            .map(Some)
30            .chain(repeat_n(None, (1 << log_size) - size))
31            .collect::<Vec<_>>();
32
33        for (i, x) in seq.iter().enumerate() {
34            data[0][i] = x.clone();
35        }
36
37        let mut this = Self {
38            semigroup,
39            data,
40            seq,
41            size,
42        };
43        this.build(0, 1 << log_size, log_size - 1);
44
45        this
46    }
47
48    fn build(&mut self, l: usize, r: usize, d: usize) {
49        let m = (l + r) / 2;
50
51        self.data[d][m] = self.seq[m].clone();
52        for i in m + 1..r {
53            self.data[d][i] = match (self.data[d][i - 1].clone(), self.seq[i].clone()) {
54                (Some(x), Some(y)) => Some(self.semigroup.op(x, y)),
55                (a, None) => a,
56                (None, a) => a,
57            }
58        }
59
60        self.data[d][m - 1] = self.seq[m - 1].clone();
61        for i in (l..m - 1).rev() {
62            self.data[d][i] = match (self.seq[i].clone(), self.data[d][i + 1].clone()) {
63                (Some(x), Some(y)) => Some(self.semigroup.op(x, y)),
64                (a, None) => a,
65                (None, a) => a,
66            }
67        }
68
69        if d > 0 {
70            self.build(l, m, d - 1);
71            self.build(m, r, d - 1);
72        }
73    }
74
75    /// **Time complexity** $O(1)$
76    pub fn fold(&self, range: impl RangeBounds<usize>) -> Option<S::Element> {
77        let (l, r) = range_bounds_to_range(range, 0, self.size);
78
79        if l == r {
80            None
81        } else {
82            let r = r - 1;
83
84            if l == r {
85                self.seq[l].clone()
86            } else {
87                let k = usize::BITS as usize - 1 - (l ^ r).leading_zeros() as usize;
88                match (self.data[k][l].clone(), self.data[k][r].clone()) {
89                    (Some(x), Some(y)) => Some(self.semigroup.op(x, y)),
90                    (a, None) => a,
91                    (None, a) => a,
92                }
93            }
94        }
95    }
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::sum::*;
    use my_testtools::*;
    use rand::Rng;

    /// Compares `fold` on random ranges with a direct fold over the same slice.
    #[test]
    fn test() {
        let mut rng = rand::thread_rng();

        let n = 100;
        let values: Vec<_> = std::iter::repeat_with(|| rng.gen::<u32>() % 10000)
            .take(n)
            .collect();

        let monoid = Sum::<u32>::new();
        let table = DisjointSparseTable::new(monoid, values.clone());

        for _ in 0..100 {
            let lr = rand_range(&mut rng, 0..n);

            // Reference answer computed straight from the slice.
            let expected = if lr.is_empty() {
                None
            } else {
                Some(values[lr.clone()].iter().cloned().fold_m(&monoid))
            };

            assert_eq!(table.fold(lr), expected);
        }
    }
}