1use std::cmp::Ordering;
4use std::ops::Index;
5use std::ptr;
6
7enum Value<T> {
8 Value(T),
9 Ptr(*const Node<T>),
10}
11
12struct Node<T> {
13 value: Value<T>,
14 index: usize,
15 left: *const Self,
16 right: *const Self,
17}
18
19impl<T> Node<T> {
20 fn new(value: Value<T>, index: usize, left: *const Self, right: *const Self) -> Self {
21 Self {
22 value,
23 index,
24 left,
25 right,
26 }
27 }
28
29 fn ref_value<'a>(node: *const Self) -> &'a T {
30 assert!(!node.is_null());
31 let node = unsafe { &*node };
32
33 match &node.value {
34 Value::Value(x) => x,
35 &Value::Ptr(p) => {
36 assert!(!p.is_null());
37 match unsafe { &(*p).value } {
38 Value::Value(x) => x,
39 _ => unreachable!(),
40 }
41 }
42 }
43 }
44}
45
46#[derive(Clone)]
48pub struct PersistentArray<T> {
49 size: usize,
50 root: *const Node<T>,
51}
52
53impl<T> PersistentArray<T>
54where
55 T: Clone,
56{
57 pub fn new(n: usize, value: T) -> Self {
61 Self::from_vec(vec![value; n])
62 }
63
64 fn _traverse(node: *const Node<T>, ret: &mut Vec<T>) {
65 if !node.is_null() {
66 let node = unsafe { &*node };
67 Self::_traverse(node.left, ret);
68 ret.push(Node::ref_value(node).clone());
69 Self::_traverse(node.right, ret);
70 }
71 }
72
73 pub fn into_vec(&self) -> Vec<T> {
77 let mut ret = vec![];
78 Self::_traverse(self.root, &mut ret);
79 ret
80 }
81}
82
83impl<T> PersistentArray<T> {
84 pub fn from_vec(v: Vec<T>) -> Self {
88 if v.is_empty() {
89 Self {
90 size: 0,
91 root: ptr::null_mut(),
92 }
93 } else {
94 let size = v.len();
95
96 let mut a = v
97 .into_iter()
98 .enumerate()
99 .map(|(i, x)| {
100 Box::into_raw(Box::new(Node::new(
101 Value::Value(x),
102 i,
103 ptr::null(),
104 ptr::null(),
105 )))
106 })
107 .collect::<Vec<_>>();
108
109 let max = (size + 1).next_power_of_two();
110
111 let get_par = |i: usize| {
112 let lowest = 1 << i.trailing_zeros();
113 i ^ lowest | (lowest << 1)
114 };
115
116 for i in 0..size {
117 let i = i + 1;
118
119 let mut par = get_par(i);
120 while par <= max {
121 if par <= size {
122 let p = unsafe { &mut *a[par - 1] };
123 if par < i {
124 p.right = a[i - 1];
125 } else {
126 p.left = a[i - 1];
127 }
128
129 break;
130 }
131 par = get_par(par);
132 }
133 }
134
135 let root = a[(max >> 1) - 1];
136 Self { size, root }
137 }
138 }
139
140 fn _set(prev: *const Node<T>, i: usize, val: T) -> *const Node<T> {
141 assert!(!prev.is_null());
142 let prev = unsafe { &*prev };
143
144 let (left, right, value);
145 match i.cmp(&prev.index) {
146 Ordering::Less => {
147 left = Self::_set(prev.left, i, val);
148 right = prev.right;
149 value = match prev.value {
150 Value::Value(_) => Value::Ptr(prev as *const _),
151 Value::Ptr(p) => Value::Ptr(p),
152 };
153 }
154 Ordering::Greater => {
155 left = prev.left;
156 right = Self::_set(prev.right, i, val);
157 value = match prev.value {
158 Value::Value(_) => Value::Ptr(prev as *const _),
159 Value::Ptr(p) => Value::Ptr(p),
160 };
161 }
162 Ordering::Equal => {
163 left = prev.left;
164 right = prev.right;
165 value = Value::Value(val);
166 }
167 }
168
169 let node = Box::new(Node::new(value, prev.index, left, right));
170 Box::into_raw(node)
171 }
172
173 pub fn set(&self, i: usize, value: T) -> Self {
177 assert!(
178 i < self.size,
179 "index out of bounds: the len is {} but the index is {i}",
180 self.size,
181 );
182 Self {
183 size: self.size,
184 root: Self::_set(self.root, i, value),
185 }
186 }
187
188 fn _get<'a>(node: *const Node<T>, i: usize) -> &'a T {
189 assert!(!node.is_null());
190 let node = unsafe { &*node };
191
192 match i.cmp(&node.index) {
193 Ordering::Less => Self::_get(node.left, i),
194 Ordering::Greater => Self::_get(node.right, i),
195 Ordering::Equal => Node::ref_value(node),
196 }
197 }
198
199 pub fn get(&self, i: usize) -> &T {
203 assert!(
204 i < self.size,
205 "index out of bounds: the len is {} but the index is {i}",
206 self.size
207 );
208 Self::_get(self.root, i)
209 }
210}
211
212impl<T: Clone> Index<usize> for PersistentArray<T> {
213 type Output = T;
214 fn index(&self, index: usize) -> &Self::Output {
215 self.get(index)
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// Each `set` must yield a new version while every earlier version keeps
    /// its contents.
    #[test]
    fn test() {
        let base = PersistentArray::<i32>::from_vec(vec![1, 2, 3, 4, 5]);
        assert_eq!(base.into_vec(), [1, 2, 3, 4, 5]);

        let first = base.set(0, 4);
        assert_eq!(first.into_vec(), [4, 2, 3, 4, 5]);

        // Two independent updates forked from the same version.
        let with_six = first.set(2, 6);
        assert_eq!(with_six.into_vec(), [4, 2, 6, 4, 5]);

        let with_nine = first.set(2, 9);
        assert_eq!(with_nine.into_vec(), [4, 2, 9, 4, 5]);

        let last = with_six.set(4, -3);

        // Re-check every version once all updates have been made.
        let expectations = [
            (&base, [1, 2, 3, 4, 5]),
            (&first, [4, 2, 3, 4, 5]),
            (&with_six, [4, 2, 6, 4, 5]),
            (&with_nine, [4, 2, 9, 4, 5]),
            (&last, [4, 2, 6, 4, -3]),
        ];
        for (version, expected) in expectations {
            assert_eq!(version.into_vec(), expected);
        }
    }
}