haar_lib/ds/
persistent_segtree.rs1use std::cell::RefCell;
4use std::ops::RangeBounds;
5use std::rc::Rc;
6
7use crate::algebra::traits::Monoid;
8use crate::misc::range::range_bounds_to_range;
9
/// A node of the persistent segment tree.
///
/// Nodes are shared between tree versions through `Rc`; a child that is
/// `None` stands for a subtree that was never built because it holds no
/// elements, and it contributes the monoid identity wherever it is read.
#[derive(Debug, Clone)]
struct Node<T> {
    /// Aggregate of the elements stored below (or at) this node.
    value: T,
    /// Subtree for the lower half of this node's interval, if built.
    left: Option<Rc<RefCell<Node<T>>>>,
    /// Subtree for the upper half of this node's interval, if built.
    right: Option<Rc<RefCell<Node<T>>>>,
}

impl<T> Node<T> {
    /// Creates a childless node holding `value`.
    fn new(value: T) -> Self {
        Node {
            value,
            left: None,
            right: None,
        }
    }
}

27#[derive(Clone, Debug)]
29pub struct PersistentSegtree<M: Monoid> {
30 root: Option<Rc<RefCell<Node<M>>>>,
31 to: usize,
32 original_size: usize,
33}
34
35impl<M: Monoid + Clone> PersistentSegtree<M> {
36 pub fn new(n: usize) -> Self {
38 let seq = vec![M::id(); n];
39 Self::from_vec(seq)
40 }
41
42 pub fn from_vec(a: Vec<M>) -> Self {
44 let n = a.len();
45 let to = n.next_power_of_two();
46 let root = Some(Self::__init(0, to, &a));
47 Self {
48 root,
49 to,
50 original_size: n,
51 }
52 }
53
54 fn __init(from: usize, to: usize, seq: &[M]) -> Rc<RefCell<Node<M>>> {
55 if to - from == 1 {
56 Rc::new(RefCell::new(Node::new(seq[from].clone())))
57 } else {
58 let mid = (from + to) / 2;
59 let mut node = Node::new(M::id());
60
61 let lv = if seq.len() > from {
62 let left = Self::__init(from, mid, seq);
63 let lv = left.borrow().value.clone();
64 node.left = Some(left);
65 lv
66 } else {
67 M::id()
68 };
69
70 let rv = if seq.len() > mid {
71 let right = Self::__init(mid, to, seq);
72 let rv = right.borrow().value.clone();
73 node.right = Some(right);
74 rv
75 } else {
76 M::id()
77 };
78
79 node.value = M::op(lv, rv);
80
81 Rc::new(RefCell::new(node))
82 }
83 }
84
85 fn __set(
86 node: Rc<RefCell<Node<M>>>,
87 from: usize,
88 to: usize,
89 pos: usize,
90 value: &M,
91 ) -> Rc<RefCell<Node<M>>> {
92 if to <= pos || pos < from {
93 node
94 } else if pos <= from && to <= pos + 1 {
95 Rc::new(RefCell::new(Node::new(value.clone())))
96 } else {
97 let mid = (from + to) / 2;
98
99 let lp = node
100 .borrow()
101 .left
102 .clone()
103 .map(|left| Self::__set(left, from, mid, pos, value));
104 let rp = node
105 .borrow()
106 .right
107 .clone()
108 .map(|right| Self::__set(right, mid, to, pos, value));
109
110 let mut s = Node::new(M::op(
111 lp.as_ref().map_or(M::id(), |l| l.borrow().value.clone()),
112 rp.as_ref().map_or(M::id(), |r| r.borrow().value.clone()),
113 ));
114
115 s.left = lp;
116 s.right = rp;
117
118 Rc::new(RefCell::new(s))
119 }
120 }
121
122 pub fn assign(&self, i: usize, value: M) -> Self {
124 let new_root = Self::__set(self.root.clone().unwrap(), 0, self.to, i, &value);
125
126 Self {
127 root: Some(new_root),
128 to: self.to,
129 original_size: self.original_size,
130 }
131 }
132
133 fn __fold(node: Rc<RefCell<Node<M>>>, from: usize, to: usize, l: usize, r: usize) -> M {
134 if l <= from && to <= r {
135 node.borrow().value.clone()
136 } else if to <= l || r <= from {
137 M::id()
138 } else {
139 let mid = (from + to) / 2;
140
141 let lv = node
142 .borrow()
143 .left
144 .clone()
145 .map_or(M::id(), |left| Self::__fold(left, from, mid, l, r));
146
147 let rv = node
148 .borrow()
149 .right
150 .clone()
151 .map_or(M::id(), |right| Self::__fold(right, mid, to, l, r));
152
153 M::op(lv, rv)
154 }
155 }
156
157 pub fn fold(&self, range: impl RangeBounds<usize>) -> M {
159 let (start, end) = range_bounds_to_range(range, 0, self.original_size);
160 Self::__fold(self.root.clone().unwrap(), 0, self.to, start, end)
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{algebra::sum::*, misc::vec_map::VecMap};

    /// Checks exact range sums and that older versions survive later updates.
    #[test]
    fn test() {
        let a = vec![0, 1, 3, 9, 4, 8, 2];
        let a = a.map(Sum);
        let seg = PersistentSegtree::<Sum<u64>>::from_vec(a);

        assert_eq!(seg.fold(0..5).0, 17);

        let s1 = seg.assign(0, Sum(10));
        assert_eq!(s1.fold(0..5).0, 27);
        // The version `s1` was derived from is unaffected by the update.
        assert_eq!(seg.fold(0..5).0, 17);

        let s2 = seg.assign(2, Sum(6));
        assert_eq!(s1.fold(0..5).0, 27);
        assert_eq!(s2.fold(0..5).0, 20);
        assert_eq!(seg.fold(0..5).0, 17);
    }

    /// Builds a tree of versions (each derived from an earlier one) and checks
    /// every non-empty range of every version against a plain `Vec` model.
    #[test]
    fn matches_naive_sums() {
        let n = 10;
        let initial: Vec<u64> = (0..n as u64).map(|x| x * x % 7).collect();
        let mut naive = vec![initial.clone()];
        let mut versions = vec![PersistentSegtree::<Sum<u64>>::from_vec(
            initial.iter().map(|&x| Sum(x)).collect(),
        )];

        for step in 0..20 {
            let base = (step * 7) % versions.len();
            let i = (step * 5 + 3) % n;
            let value = (step as u64 * 11) % 13;

            let next = versions[base].assign(i, Sum(value));
            let mut expected = naive[base].clone();
            expected[i] = value;

            versions.push(next);
            naive.push(expected);
        }

        for (seg, values) in versions.iter().zip(&naive) {
            for l in 0..n {
                for r in l + 1..=n {
                    assert_eq!(seg.fold(l..r).0, values[l..r].iter().sum::<u64>());
                }
            }
        }
    }
}