haar_lib/ds/
dynamic_segtree.rs

//! 動的セグメント木 (dynamic segment tree)
//!
//! # Problems
//! - <https://judge.yosupo.jp/problem/point_set_range_composite_large_array>
5use crate::algebra::traits::Monoid;
6use crate::misc::nullable_usize::NullableUsize;
7use std::ops::Range;
8
9#[derive(Debug)]
10struct Node<T> {
11    value: T,
12    left: NullableUsize,
13    right: NullableUsize,
14}
15
16impl<T> Node<T> {
17    fn new(value: T) -> Self {
18        Self {
19            value,
20            left: NullableUsize::NULL,
21            right: NullableUsize::NULL,
22        }
23    }
24}
25
26/// 動的セグメント木
27#[derive(Debug)]
28pub struct DynamicSegtree<M: Monoid> {
29    data: Vec<Node<M>>,
30    root: NullableUsize,
31    to: usize,
32}
33
34impl<M: Monoid + Clone> Default for DynamicSegtree<M> {
35    fn default() -> Self {
36        Self::new()
37    }
38}
39
40impl<M: Monoid + Clone> DynamicSegtree<M> {
41    /// [`DynamicSegtree<M>`]を生成する。
42    pub fn new() -> Self {
43        Self {
44            data: vec![Node::new(M::id())],
45            root: NullableUsize(0),
46            to: 1,
47        }
48    }
49
50    fn assign_dfs(
51        &mut self,
52        cur_id: NullableUsize,
53        cur_from: usize,
54        cur_to: usize,
55        i: usize,
56        value: M,
57    ) {
58        if cur_to - cur_from == 1 {
59            self.data[cur_id.0].value = value;
60        } else {
61            let mid = (cur_from + cur_to) / 2;
62            if (cur_from..mid).contains(&i) {
63                if self.data[cur_id.0].left.is_null() {
64                    let new_node = Node::new(value.clone());
65                    self.data.push(new_node);
66                    self.data[cur_id.0].left = NullableUsize(self.data.len() - 1);
67                }
68                self.assign_dfs(self.data[cur_id.0].left, cur_from, mid, i, value);
69            } else {
70                if self.data[cur_id.0].right.is_null() {
71                    let new_node = Node::new(value.clone());
72                    self.data.push(new_node);
73                    self.data[cur_id.0].right = NullableUsize(self.data.len() - 1);
74                }
75                self.assign_dfs(self.data[cur_id.0].right, mid, cur_to, i, value);
76            }
77
78            let left = self.data[cur_id.0].left;
79            let right = self.data[cur_id.0].right;
80
81            self.data[cur_id.0].value = M::op(
82                if left.is_null() {
83                    M::id()
84                } else {
85                    self.data[left.0].value.clone()
86                },
87                if right.is_null() {
88                    M::id()
89                } else {
90                    self.data[right.0].value.clone()
91                },
92            );
93        }
94    }
95
96    /// `i`番目の要素を`value`で更新する。
97    pub fn assign(&mut self, i: usize, value: M) {
98        loop {
99            if i < self.to {
100                break;
101            }
102
103            self.to *= 2;
104            let mut new_root = Node::new(self.data[self.root.0].value.clone());
105            new_root.left = self.root;
106            self.data.push(new_root);
107            self.root = NullableUsize(self.data.len() - 1);
108        }
109
110        self.assign_dfs(self.root, 0, self.to, i, value);
111    }
112
113    fn fold_dfs(
114        &self,
115        cur_id: NullableUsize,
116        cur_from: usize,
117        cur_to: usize,
118        from: usize,
119        to: usize,
120    ) -> M {
121        let cur = &self.data[cur_id.0];
122
123        if cur_to <= from || to <= cur_from {
124            M::id()
125        } else if from <= cur_from && cur_to <= to {
126            cur.value.clone()
127        } else {
128            let mid = (cur_from + cur_to) / 2;
129            let lv = if cur.left.is_null() {
130                M::id()
131            } else {
132                self.fold_dfs(cur.left, cur_from, mid, from, to)
133            };
134            let rv = if cur.right.is_null() {
135                M::id()
136            } else {
137                self.fold_dfs(cur.right, mid, cur_to, from, to)
138            };
139
140            M::op(lv, rv)
141        }
142    }
143
144    /// 範囲`start..end`で計算を集約する。
145    pub fn fold(&self, Range { start, end }: Range<usize>) -> M {
146        self.fold_dfs(self.root, 0, self.to, start, end)
147    }
148}
149
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use super::*;
    use crate::algebra::sum::*;
    use my_testtools::rand_range;
    use rand::Rng;

    /// Runs `t` rounds of random point assignments and range folds on a [`DynamicSegtree`].
    ///
    /// Every fold is compared against a `BTreeMap` reference model. `index_bound` bounds both
    /// the assigned indices and the queried ranges. A small bound forces repeated assignments
    /// to the same index; a large one exercises deep trees.
    fn check_against_model(index_bound: usize, t: usize) {
        let mut rng = rand::thread_rng();

        let mut seg = DynamicSegtree::<Sum<u64>>::new();
        let mut map = BTreeMap::<usize, Sum<u64>>::new();

        for _ in 0..t {
            let i = rng.gen_range::<usize, _>(0..index_bound);
            let x = rng.gen::<u64>() % 1000000;

            seg.assign(i, Sum(x));
            // `assign` overwrites the element, so the model must overwrite too; accumulating
            // here would diverge from the tree as soon as an index repeats.
            map.insert(i, Sum(x));

            let lr = rand_range(&mut rng, 0..index_bound);

            let res = seg.fold(lr.clone());
            let ans = map.range(lr).map(|(_, v)| v).cloned().fold_m();

            assert_eq!(res, ans);
        }
    }

    #[test]
    fn test() {
        check_against_model(usize::MAX / 2, 100);
    }

    #[test]
    fn test_reassign_small_domain() {
        check_against_model(64, 300);
    }
}