1use std::ops::RangeBounds;
4use std::ptr;
5
6use crate::algebra::traits::Monoid;
7use crate::misc::range::range_bounds_to_range;
8
/// One vertex of the persistent tree.
///
/// `value` is the monoid product of everything below this vertex (for a leaf, the
/// element itself). A null child pointer marks a subtree that was never allocated
/// because it lies entirely past the end of the input.
#[derive(Clone, Debug)]
struct Node<T> {
    value: T,
    left: *mut Node<T>,
    right: *mut Node<T>,
}

impl<T> Node<T> {
    /// Creates a childless node holding `value`.
    fn new(value: T) -> Self {
        let (left, right) = (ptr::null_mut(), ptr::null_mut());
        Node { value, left, right }
    }
}
25
26#[derive(Clone, Debug)]
28pub struct PersistentSegtree<M: Monoid> {
29 monoid: M,
30 root: *mut Node<M::Element>,
31 to: usize,
32 original_size: usize,
33}
34
35impl<M: Monoid + Clone> PersistentSegtree<M>
36where
37 M::Element: Clone,
38{
39 pub fn new(monoid: M, n: usize) -> Self {
41 let seq = vec![monoid.id(); n];
42 Self::from_vec(monoid, seq)
43 }
44
45 pub fn from_vec(monoid: M, a: Vec<M::Element>) -> Self {
47 let n = a.len();
48 let to = n.next_power_of_two();
49 let root = Self::__init(&monoid, 0, to, &a);
50 Self {
51 monoid,
52 root,
53 to,
54 original_size: n,
55 }
56 }
57
58 fn __init(monoid: &M, from: usize, to: usize, seq: &[M::Element]) -> *mut Node<M::Element> {
59 if to - from == 1 {
60 Box::into_raw(Box::new(Node::new(seq[from].clone())))
61 } else {
62 let mid = (from + to) / 2;
63 let mut node = Node::new(monoid.id());
64
65 let lv = if seq.len() > from {
66 let left = Self::__init(monoid, from, mid, seq);
67 node.left = left;
68 assert!(!left.is_null());
69 unsafe { (*left).value.clone() }
70 } else {
71 monoid.id()
72 };
73
74 let rv = if seq.len() > mid {
75 let right = Self::__init(monoid, mid, to, seq);
76 node.right = right;
77 assert!(!right.is_null());
78 unsafe { (*right).value.clone() }
79 } else {
80 monoid.id()
81 };
82
83 node.value = monoid.op(lv, rv);
84
85 Box::into_raw(Box::new(node))
86 }
87 }
88
89 fn __set(
90 monoid: &M,
91 node: *mut Node<M::Element>,
92 from: usize,
93 to: usize,
94 pos: usize,
95 value: &M::Element,
96 ) -> *mut Node<M::Element> {
97 assert!(!node.is_null());
98
99 if to <= pos || pos < from {
100 node
101 } else if pos <= from && to <= pos + 1 {
102 Box::into_raw(Box::new(Node::new(value.clone())))
103 } else {
104 let mid = (from + to) / 2;
105
106 let left = unsafe { (*node).left };
107 let right = unsafe { (*node).right };
108
109 let lp = if !left.is_null() {
110 Self::__set(monoid, left, from, mid, pos, value)
111 } else {
112 left
113 };
114
115 let rp = if !right.is_null() {
116 Self::__set(monoid, right, mid, to, pos, value)
117 } else {
118 right
119 };
120
121 let mut value = monoid.id();
122 if !lp.is_null() {
123 value = monoid.op(value, unsafe { (*lp).value.clone() });
124 }
125 if !rp.is_null() {
126 value = monoid.op(value, unsafe { (*rp).value.clone() });
127 }
128
129 let mut s = Node::new(value);
130 s.left = lp;
131 s.right = rp;
132
133 Box::into_raw(Box::new(s))
134 }
135 }
136
137 pub fn assign(&self, i: usize, value: M::Element) -> Self {
139 let new_root = Self::__set(&self.monoid, self.root, 0, self.to, i, &value);
140
141 Self {
142 monoid: self.monoid.clone(),
143 root: new_root,
144 to: self.to,
145 original_size: self.original_size,
146 }
147 }
148
149 fn __fold(
150 monoid: &M,
151 node: *mut Node<M::Element>,
152 from: usize,
153 to: usize,
154 l: usize,
155 r: usize,
156 ) -> M::Element {
157 assert!(!node.is_null());
158
159 if l <= from && to <= r {
160 unsafe { (*node).value.clone() }
161 } else if to <= l || r <= from {
162 monoid.id()
163 } else {
164 let mid = (from + to) / 2;
165
166 let left = unsafe { (*node).left };
167 let right = unsafe { (*node).right };
168
169 let lv = if left.is_null() {
170 monoid.id()
171 } else {
172 Self::__fold(monoid, left, from, mid, l, r)
173 };
174
175 let rv = if right.is_null() {
176 monoid.id()
177 } else {
178 Self::__fold(monoid, right, mid, to, l, r)
179 };
180
181 monoid.op(lv, rv)
182 }
183 }
184
185 pub fn fold(&self, range: impl RangeBounds<usize>) -> M::Element {
187 let (start, end) = range_bounds_to_range(range, 0, self.original_size);
188 Self::__fold(&self.monoid, self.root, 0, self.to, start, end)
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::sum::*;

    /// Checks range folds and that every version stays unchanged after later `assign`s.
    #[test]
    fn test() {
        let m = Sum::<u64>::new();

        let a = vec![0, 1, 3, 9, 4, 8, 2];
        let seg = PersistentSegtree::from_vec(m, a);

        assert_eq!(seg.fold(0..5), 17);
        assert_eq!(seg.fold(..), 27);
        assert_eq!(seg.fold(6..7), 2);

        // Branch two versions off `seg`; neither may disturb `seg` or each other.
        let s1 = seg.assign(0, 10);
        assert_eq!(s1.fold(0..5), 27);
        assert_eq!(seg.fold(0..5), 17);

        let s2 = seg.assign(2, 6);
        assert_eq!(s1.fold(0..5), 27);
        assert_eq!(s2.fold(0..5), 20);
        assert_eq!(s2.fold(2..3), 6);
        assert_eq!(seg.fold(0..5), 17);

        // Versions can also be chained: updating `s1` leaves `s1` intact.
        let s3 = s1.assign(6, 5);
        assert_eq!(s3.fold(..), 40);
        assert_eq!(s1.fold(..), 37);
    }
}