haar_lib/ds/
dynamic_segtree.rs

1//! 動的セグメント木
2//!
3//! # Problems
4//! - <https://judge.yosupo.jp/problem/point_set_range_composite_large_array>
5use crate::algebra::traits::Monoid;
6use crate::misc::nullable_usize::NullableUsize;
7use std::ops::Range;
8
9#[derive(Debug)]
10struct Node<T> {
11    value: T,
12    left: NullableUsize,
13    right: NullableUsize,
14}
15
16impl<T> Node<T> {
17    fn new(value: T) -> Self {
18        Self {
19            value,
20            left: NullableUsize::NULL,
21            right: NullableUsize::NULL,
22        }
23    }
24}
25
26/// 動的セグメント木
27#[derive(Debug)]
28pub struct DynamicSegtree<M: Monoid> {
29    monoid: M,
30    data: Vec<Node<M::Element>>,
31    root: NullableUsize,
32    to: usize,
33}
34
35impl<M: Monoid> DynamicSegtree<M>
36where
37    M::Element: Clone,
38{
39    /// [`DynamicSegtree<M>`]を生成する。
40    pub fn new(monoid: M) -> Self {
41        Self {
42            data: vec![Node::new(monoid.id())],
43            root: NullableUsize(0),
44            to: 1,
45            monoid,
46        }
47    }
48
49    fn assign_dfs(
50        &mut self,
51        cur_id: NullableUsize,
52        cur_from: usize,
53        cur_to: usize,
54        i: usize,
55        value: M::Element,
56    ) {
57        if cur_to - cur_from == 1 {
58            self.data[cur_id.0].value = value;
59        } else {
60            let mid = (cur_from + cur_to) / 2;
61            if (cur_from..mid).contains(&i) {
62                if self.data[cur_id.0].left.is_null() {
63                    let new_node = Node::new(value.clone());
64                    self.data.push(new_node);
65                    self.data[cur_id.0].left = NullableUsize(self.data.len() - 1);
66                }
67                self.assign_dfs(self.data[cur_id.0].left, cur_from, mid, i, value);
68            } else {
69                if self.data[cur_id.0].right.is_null() {
70                    let new_node = Node::new(value.clone());
71                    self.data.push(new_node);
72                    self.data[cur_id.0].right = NullableUsize(self.data.len() - 1);
73                }
74                self.assign_dfs(self.data[cur_id.0].right, mid, cur_to, i, value);
75            }
76
77            let left = self.data[cur_id.0].left;
78            let right = self.data[cur_id.0].right;
79
80            self.data[cur_id.0].value = self.monoid.op(
81                if left.is_null() {
82                    self.monoid.id()
83                } else {
84                    self.data[left.0].value.clone()
85                },
86                if right.is_null() {
87                    self.monoid.id()
88                } else {
89                    self.data[right.0].value.clone()
90                },
91            );
92        }
93    }
94
95    /// `i`番目の要素を`value`で更新する。
96    pub fn assign(&mut self, i: usize, value: M::Element) {
97        loop {
98            if i < self.to {
99                break;
100            }
101
102            self.to *= 2;
103            let mut new_root = Node::new(self.data[self.root.0].value.clone());
104            new_root.left = self.root;
105            self.data.push(new_root);
106            self.root = NullableUsize(self.data.len() - 1);
107        }
108
109        self.assign_dfs(self.root, 0, self.to, i, value);
110    }
111
112    fn fold_dfs(
113        &self,
114        cur_id: NullableUsize,
115        cur_from: usize,
116        cur_to: usize,
117        from: usize,
118        to: usize,
119    ) -> M::Element {
120        let cur = &self.data[cur_id.0];
121
122        if cur_to <= from || to <= cur_from {
123            self.monoid.id()
124        } else if from <= cur_from && cur_to <= to {
125            cur.value.clone()
126        } else {
127            let mid = (cur_from + cur_to) / 2;
128            let lv = if cur.left.is_null() {
129                self.monoid.id()
130            } else {
131                self.fold_dfs(cur.left, cur_from, mid, from, to)
132            };
133            let rv = if cur.right.is_null() {
134                self.monoid.id()
135            } else {
136                self.fold_dfs(cur.right, mid, cur_to, from, to)
137            };
138
139            self.monoid.op(lv, rv)
140        }
141    }
142
143    /// 範囲`start..end`で計算を集約する。
144    pub fn fold(&self, Range { start, end }: Range<usize>) -> M::Element {
145        self.fold_dfs(self.root, 0, self.to, start, end)
146    }
147}
148
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use super::*;
    use crate::algebra::sum::*;
    use my_testtools::rand_range;
    use rand::Rng;

    /// Runs `t` random assign/fold rounds against a `BTreeMap` reference model.
    ///
    /// Indices and query ranges are drawn from `0..n`.
    fn check_against_map(n: usize, t: usize) {
        let mut rng = rand::thread_rng();

        let m = Sum::<u64>::new();

        let mut seg = DynamicSegtree::new(m);
        let mut map = BTreeMap::<usize, u64>::new();

        for _ in 0..t {
            let i = rng.gen_range::<usize, _>(0..n);
            let x = rng.gen::<u64>() % 1000000;

            seg.assign(i, x);
            // `assign` overwrites the element, so the model must overwrite too rather than
            // accumulate into the previous value.
            map.insert(i, x);

            let lr = rand_range(&mut rng, 0..n);

            let res = seg.fold(lr.clone());
            let ans = map.range(lr).map(|(_, v)| v).cloned().fold_m(&m);

            assert_eq!(res, ans);
        }
    }

    #[test]
    fn test() {
        // Huge index space: exercises root growth and sparse node allocation.
        check_against_map(usize::MAX / 2, 100);
    }

    #[test]
    fn test_small_indices() {
        // Small index space: the same positions are reassigned many times.
        check_against_map(64, 1000);
    }
}
184}