haar_lib/ds/
persistent_segtree.rs

1//! 永続セグメントツリー
2
use std::ops::RangeBounds;
use std::ptr;
use std::rc::Rc;

use crate::algebra::traits::Monoid;
use crate::misc::range::range_bounds_to_range;
8
9#[derive(Clone, Debug)]
10struct Node<T> {
11    value: T,
12    left: *mut Self,
13    right: *mut Self,
14}
15
16impl<T> Node<T> {
17    fn new(value: T) -> Self {
18        Self {
19            value,
20            left: ptr::null_mut(),
21            right: ptr::null_mut(),
22        }
23    }
24}
25
26/// 永続セグメントツリー
27#[derive(Clone, Debug)]
28pub struct PersistentSegtree<M: Monoid> {
29    monoid: M,
30    root: *mut Node<M::Element>,
31    to: usize,
32    original_size: usize,
33}
34
35impl<M: Monoid + Clone> PersistentSegtree<M>
36where
37    M::Element: Clone,
38{
39    /// 長さ`n`の[`PersistentSegtree`]を生成する。
40    pub fn new(monoid: M, n: usize) -> Self {
41        let seq = vec![monoid.id(); n];
42        Self::from_vec(monoid, seq)
43    }
44
45    /// [`Vec`]から[`PersistentSegtree`]を構築する。
46    pub fn from_vec(monoid: M, a: Vec<M::Element>) -> Self {
47        let n = a.len();
48        let to = n.next_power_of_two();
49        let root = Self::__init(&monoid, 0, to, &a);
50        Self {
51            monoid,
52            root,
53            to,
54            original_size: n,
55        }
56    }
57
58    fn __init(monoid: &M, from: usize, to: usize, seq: &[M::Element]) -> *mut Node<M::Element> {
59        if to - from == 1 {
60            Box::into_raw(Box::new(Node::new(seq[from].clone())))
61        } else {
62            let mid = (from + to) / 2;
63            let mut node = Node::new(monoid.id());
64
65            let lv = if seq.len() > from {
66                let left = Self::__init(monoid, from, mid, seq);
67                node.left = left;
68                assert!(!left.is_null());
69                unsafe { (*left).value.clone() }
70            } else {
71                monoid.id()
72            };
73
74            let rv = if seq.len() > mid {
75                let right = Self::__init(monoid, mid, to, seq);
76                node.right = right;
77                assert!(!right.is_null());
78                unsafe { (*right).value.clone() }
79            } else {
80                monoid.id()
81            };
82
83            node.value = monoid.op(lv, rv);
84
85            Box::into_raw(Box::new(node))
86        }
87    }
88
89    fn __set(
90        monoid: &M,
91        node: *mut Node<M::Element>,
92        from: usize,
93        to: usize,
94        pos: usize,
95        value: &M::Element,
96    ) -> *mut Node<M::Element> {
97        assert!(!node.is_null());
98
99        if to <= pos || pos < from {
100            node
101        } else if pos <= from && to <= pos + 1 {
102            Box::into_raw(Box::new(Node::new(value.clone())))
103        } else {
104            let mid = (from + to) / 2;
105
106            let left = unsafe { (*node).left };
107            let right = unsafe { (*node).right };
108
109            let lp = if !left.is_null() {
110                Self::__set(monoid, left, from, mid, pos, value)
111            } else {
112                left
113            };
114
115            let rp = if !right.is_null() {
116                Self::__set(monoid, right, mid, to, pos, value)
117            } else {
118                right
119            };
120
121            let mut value = monoid.id();
122            if !lp.is_null() {
123                value = monoid.op(value, unsafe { (*lp).value.clone() });
124            }
125            if !rp.is_null() {
126                value = monoid.op(value, unsafe { (*rp).value.clone() });
127            }
128
129            let mut s = Node::new(value);
130            s.left = lp;
131            s.right = rp;
132
133            Box::into_raw(Box::new(s))
134        }
135    }
136
137    /// `i`番目の要素を`value`にする。
138    pub fn assign(&self, i: usize, value: M::Element) -> Self {
139        let new_root = Self::__set(&self.monoid, self.root, 0, self.to, i, &value);
140
141        Self {
142            monoid: self.monoid.clone(),
143            root: new_root,
144            to: self.to,
145            original_size: self.original_size,
146        }
147    }
148
149    fn __fold(
150        monoid: &M,
151        node: *mut Node<M::Element>,
152        from: usize,
153        to: usize,
154        l: usize,
155        r: usize,
156    ) -> M::Element {
157        assert!(!node.is_null());
158
159        if l <= from && to <= r {
160            unsafe { (*node).value.clone() }
161        } else if to <= l || r <= from {
162            monoid.id()
163        } else {
164            let mid = (from + to) / 2;
165
166            let left = unsafe { (*node).left };
167            let right = unsafe { (*node).right };
168
169            let lv = if left.is_null() {
170                monoid.id()
171            } else {
172                Self::__fold(monoid, left, from, mid, l, r)
173            };
174
175            let rv = if right.is_null() {
176                monoid.id()
177            } else {
178                Self::__fold(monoid, right, mid, to, l, r)
179            };
180
181            monoid.op(lv, rv)
182        }
183    }
184
185    /// 範囲`range`で計算を集約して返す。
186    pub fn fold(&self, range: impl RangeBounds<usize>) -> M::Element {
187        let (start, end) = range_bounds_to_range(range, 0, self.original_size);
188        Self::__fold(&self.monoid, self.root, 0, self.to, start, end)
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::sum::*;

    /// Each `assign` produces a new version; earlier versions keep their
    /// original contents.
    #[test]
    fn test() {
        let m = Sum::<u64>::new();

        let a = vec![0, 1, 3, 9, 4, 8, 2];
        let seg = PersistentSegtree::from_vec(m, a);

        // 0 + 1 + 3 + 9 + 4
        assert_eq!(seg.fold(0..5), 17);

        let s1 = seg.assign(0, 10);

        // 10 + 1 + 3 + 9 + 4 in the new version, the base is unchanged.
        assert_eq!(s1.fold(0..5), 27);
        assert_eq!(seg.fold(0..5), 17);

        let s2 = seg.assign(2, 6);

        // `s2` branches off `seg`, so it must not see the update made in `s1`.
        assert_eq!(s1.fold(0..5), 27);
        assert_eq!(s2.fold(0..5), 20);
        assert_eq!(seg.fold(0..5), 17);
    }
}