haar_lib/ds/persistent_segtree.rs

//! Persistent segment tree (永続セグメントツリー).
2
3use std::cell::RefCell;
4use std::ops::RangeBounds;
5use std::rc::Rc;
6
7use crate::algebra::traits::Monoid;
8use crate::misc::range::range_bounds_to_range;
9
/// One node of a [`PersistentSegtree`].
///
/// Nodes are reference-counted so that unchanged subtrees can be shared between
/// versions. A `None` child denotes a subtree that was never allocated.
#[derive(Clone, Debug)]
struct Node<T> {
    /// Value aggregated over the segment this node covers.
    value: T,
    /// Child covering the left half of the segment, if allocated.
    left: Option<Rc<RefCell<Node<T>>>>,
    /// Child covering the right half of the segment, if allocated.
    right: Option<Rc<RefCell<Node<T>>>>,
}

impl<T> Node<T> {
    /// Creates a node holding `value` with no children.
    fn new(value: T) -> Self {
        Node {
            left: None,
            right: None,
            value,
        }
    }
}

27/// 永続セグメントツリー
28#[derive(Clone, Debug)]
29pub struct PersistentSegtree<M: Monoid> {
30    root: Option<Rc<RefCell<Node<M>>>>,
31    to: usize,
32    original_size: usize,
33}
34
35impl<M: Monoid + Clone> PersistentSegtree<M> {
36    /// 長さ`n`の[`PersistentSegtree`]を生成する。
37    pub fn new(n: usize) -> Self {
38        let seq = vec![M::id(); n];
39        Self::from_vec(seq)
40    }
41
42    /// [`Vec`]から[`PersistentSegtree`]を構築する。
43    pub fn from_vec(a: Vec<M>) -> Self {
44        let n = a.len();
45        let to = n.next_power_of_two();
46        let root = Some(Self::__init(0, to, &a));
47        Self {
48            root,
49            to,
50            original_size: n,
51        }
52    }
53
54    fn __init(from: usize, to: usize, seq: &[M]) -> Rc<RefCell<Node<M>>> {
55        if to - from == 1 {
56            Rc::new(RefCell::new(Node::new(seq[from].clone())))
57        } else {
58            let mid = (from + to) / 2;
59            let mut node = Node::new(M::id());
60
61            let lv = if seq.len() > from {
62                let left = Self::__init(from, mid, seq);
63                let lv = left.borrow().value.clone();
64                node.left = Some(left);
65                lv
66            } else {
67                M::id()
68            };
69
70            let rv = if seq.len() > mid {
71                let right = Self::__init(mid, to, seq);
72                let rv = right.borrow().value.clone();
73                node.right = Some(right);
74                rv
75            } else {
76                M::id()
77            };
78
79            node.value = M::op(lv, rv);
80
81            Rc::new(RefCell::new(node))
82        }
83    }
84
85    fn __set(
86        node: Rc<RefCell<Node<M>>>,
87        from: usize,
88        to: usize,
89        pos: usize,
90        value: &M,
91    ) -> Rc<RefCell<Node<M>>> {
92        if to <= pos || pos < from {
93            node
94        } else if pos <= from && to <= pos + 1 {
95            Rc::new(RefCell::new(Node::new(value.clone())))
96        } else {
97            let mid = (from + to) / 2;
98
99            let lp = node
100                .borrow()
101                .left
102                .clone()
103                .map(|left| Self::__set(left, from, mid, pos, value));
104            let rp = node
105                .borrow()
106                .right
107                .clone()
108                .map(|right| Self::__set(right, mid, to, pos, value));
109
110            let mut s = Node::new(M::op(
111                lp.as_ref().map_or(M::id(), |l| l.borrow().value.clone()),
112                rp.as_ref().map_or(M::id(), |r| r.borrow().value.clone()),
113            ));
114
115            s.left = lp;
116            s.right = rp;
117
118            Rc::new(RefCell::new(s))
119        }
120    }
121
122    /// `i`番目の要素を`value`にする。
123    pub fn assign(&self, i: usize, value: M) -> Self {
124        let new_root = Self::__set(self.root.clone().unwrap(), 0, self.to, i, &value);
125
126        Self {
127            root: Some(new_root),
128            to: self.to,
129            original_size: self.original_size,
130        }
131    }
132
133    fn __fold(node: Rc<RefCell<Node<M>>>, from: usize, to: usize, l: usize, r: usize) -> M {
134        if l <= from && to <= r {
135            node.borrow().value.clone()
136        } else if to <= l || r <= from {
137            M::id()
138        } else {
139            let mid = (from + to) / 2;
140
141            let lv = node
142                .borrow()
143                .left
144                .clone()
145                .map_or(M::id(), |left| Self::__fold(left, from, mid, l, r));
146
147            let rv = node
148                .borrow()
149                .right
150                .clone()
151                .map_or(M::id(), |right| Self::__fold(right, mid, to, l, r));
152
153            M::op(lv, rv)
154        }
155    }
156
157    /// 範囲`range`で計算を集約して返す。
158    pub fn fold(&self, range: impl RangeBounds<usize>) -> M {
159        let (start, end) = range_bounds_to_range(range, 0, self.original_size);
160        Self::__fold(self.root.clone().unwrap(), 0, self.to, start, end)
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{algebra::sum::*, misc::vec_map::VecMap};

    #[test]
    fn test() {
        let a = vec![0, 1, 3, 9, 4, 8, 2];
        let a = a.map(Sum);
        let seg = PersistentSegtree::<Sum<u64>>::from_vec(a);

        // 0 + 1 + 3 + 9 + 4
        assert_eq!(seg.fold(0..5).0, 17);

        // `assign` returns a new version; the base version must stay unchanged.
        let s1 = seg.assign(0, Sum(10));
        assert_eq!(s1.fold(0..5).0, 27);
        assert_eq!(seg.fold(0..5).0, 17);

        // A second branch from the same base version is independent of `s1`.
        let s2 = seg.assign(2, Sum(6));
        assert_eq!(s1.fold(0..5).0, 27);
        assert_eq!(s2.fold(0..5).0, 20);
        assert_eq!(seg.fold(0..5).0, 17);

        // Whole-range and single-element folds on each version.
        assert_eq!(seg.fold(..).0, 27);
        assert_eq!(s1.fold(..).0, 37);
        assert_eq!(s2.fold(..).0, 30);
        assert_eq!(s1.fold(0..1).0, 10);
        assert_eq!(s2.fold(2..3).0, 6);
        assert_eq!(seg.fold(2..3).0, 3);
    }
}