1use std::ops::Index;
4use std::rc::Rc;
5
/// One node of the balanced binary tree that backs [`PersistentArray`].
///
/// Nodes are never mutated after construction. Versions of an array share them through `Rc`,
/// and an update re-allocates only the nodes on one root-to-leaf path.
#[derive(Clone)]
enum Node<T> {
    /// A leaf carrying exactly one array element.
    Terminal { value: T },
    /// An inner node. `size` caches the number of elements (terminals) beneath it.
    /// A child may be absent (`None`); an absent child contributes nothing to `size`.
    Internal {
        size: usize,
        l_ch: Option<Rc<Node<T>>>,
        r_ch: Option<Rc<Node<T>>>,
    },
}
17
18#[derive(Clone)]
20pub struct PersistentArray<T> {
21 size: usize,
22 root: Option<Rc<Node<T>>>,
23}
24
25fn get_size<T>(node: &Option<Rc<Node<T>>>) -> usize {
26 node.as_ref().map_or(0, |node| match node.as_ref() {
27 Node::Terminal { .. } => 1,
28 Node::Internal { size, .. } => *size,
29 })
30}
31
32impl<T> PersistentArray<T>
33where
34 T: Clone,
35{
36 pub fn new(size: usize, value: T) -> Self {
38 if size == 0 {
39 Self {
40 size: 0,
41 root: None,
42 }
43 } else {
44 let depth = usize::BITS - (size - 1_usize).leading_zeros() + 1;
45 let values = vec![value; size];
46 let root = Self::_init(0, size, &values, depth);
47
48 Self { size, root }
49 }
50 }
51
52 fn _init(l: usize, r: usize, values: &[T], depth: u32) -> Option<Rc<Node<T>>> {
53 if l == r {
54 return None;
55 }
56 if depth == 1 {
57 Some(Rc::new(Node::Terminal {
58 value: values[l].clone(),
59 }))
60 } else {
61 let mid = (l + r) / 2;
62 let l_ch = Self::_init(l, mid, values, depth - 1);
63 let r_ch = Self::_init(mid, r, values, depth - 1);
64
65 let t = Node::Internal {
66 size: get_size(&l_ch) + get_size(&r_ch),
67 l_ch,
68 r_ch,
69 };
70
71 Some(Rc::new(t))
72 }
73 }
74
75 fn _traverse(node: &Option<Rc<Node<T>>>, ret: &mut Vec<T>) {
76 if let Some(node) = node {
77 match node.as_ref() {
78 Node::Terminal { value } => {
79 ret.push(value.clone());
80 }
81 Node::Internal { l_ch, r_ch, .. } => {
82 Self::_traverse(l_ch, ret);
83 Self::_traverse(r_ch, ret);
84 }
85 }
86 }
87 }
88
89 fn _set(prev: &Rc<Node<T>>, i: usize, value: T) -> Rc<Node<T>> {
90 match prev.as_ref() {
91 Node::Terminal { .. } => Rc::new(Node::Terminal { value }),
92 Node::Internal { l_ch, r_ch, .. } => {
93 let (l_ch, r_ch) = {
94 let k = get_size(l_ch);
95 if i < k {
96 (
97 Some(Self::_set(l_ch.as_ref().unwrap(), i, value)),
98 r_ch.clone(),
99 )
100 } else {
101 (
102 l_ch.clone(),
103 Some(Self::_set(r_ch.as_ref().unwrap(), i - k, value)),
104 )
105 }
106 };
107
108 Rc::new(Node::Internal {
109 size: get_size(&l_ch) + get_size(&r_ch),
110 l_ch,
111 r_ch,
112 })
113 }
114 }
115 }
116
117 pub fn set(&self, i: usize, value: T) -> Self {
119 assert!(
120 i < self.size,
121 "index out of bounds: the len is {} but the index is {}",
122 self.size,
123 i
124 );
125 Self {
126 size: self.size,
127 root: Some(Self::_set(self.root.as_ref().unwrap(), i, value)),
128 }
129 }
130
131 fn _get(node: &Rc<Node<T>>, i: usize) -> &Node<T> {
132 match node.as_ref() {
133 Node::Terminal { .. } => node.as_ref(),
134 Node::Internal { l_ch, r_ch, .. } => {
135 let k = get_size(l_ch);
136 if i < k {
137 Self::_get(l_ch.as_ref().unwrap(), i)
138 } else {
139 Self::_get(r_ch.as_ref().unwrap(), i - k)
140 }
141 }
142 }
143 }
144
145 pub fn get(&self, i: usize) -> &T {
147 assert!(
148 i < self.size,
149 "index out of bounds: the len is {} but the index is {}",
150 self.size,
151 i
152 );
153 match Self::_get(self.root.as_ref().unwrap(), i) {
154 Node::Terminal { value } => value,
155 _ => unreachable!(),
156 }
157 }
158}
159
160impl<T: Clone> Index<usize> for PersistentArray<T> {
161 type Output = T;
162 fn index(&self, index: usize) -> &Self::Output {
163 self.get(index)
164 }
165}
166
167impl<T: Clone> From<&PersistentArray<T>> for Vec<T> {
168 fn from(from: &PersistentArray<T>) -> Vec<T> {
169 let mut ret = vec![];
170 PersistentArray::<T>::_traverse(&from.root, &mut ret);
171 ret
172 }
173}
174
175impl<T: Clone> From<Vec<T>> for PersistentArray<T> {
176 fn from(value: Vec<T>) -> Self {
177 let size = value.len();
178 if size == 0 {
179 Self {
180 size: 0,
181 root: None,
182 }
183 } else {
184 let depth = usize::BITS - (size - 1_usize).leading_zeros() + 1;
185 let root = Self::_init(0, size, &value, depth);
186
187 Self { size, root }
188 }
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test() {
        let a = PersistentArray::<i32>::from(vec![1, 2, 3, 4, 5]);
        assert_eq!(Vec::<i32>::from(&a), [1, 2, 3, 4, 5]);

        let b = a.set(0, 4);
        assert_eq!(Vec::<i32>::from(&b), [4, 2, 3, 4, 5]);

        let c = b.set(2, 6);
        assert_eq!(Vec::<i32>::from(&c), [4, 2, 6, 4, 5]);

        let d = b.set(2, 9);
        assert_eq!(Vec::<i32>::from(&d), [4, 2, 9, 4, 5]);

        let e = c.set(4, -3);

        assert_eq!(Vec::<i32>::from(&a), [1, 2, 3, 4, 5]);
        assert_eq!(Vec::<i32>::from(&b), [4, 2, 3, 4, 5]);
        assert_eq!(Vec::<i32>::from(&c), [4, 2, 6, 4, 5]);
        assert_eq!(Vec::<i32>::from(&d), [4, 2, 9, 4, 5]);
        assert_eq!(Vec::<i32>::from(&e), [4, 2, 6, 4, -3]);
    }

    // Sizes that are not powers of two produce internal nodes with an absent left child,
    // so sweep a range of sizes rather than a single one.
    #[test]
    fn new_fills_every_slot() {
        for n in 0..=33usize {
            let a = PersistentArray::new(n, 7_i32);
            assert_eq!(Vec::<i32>::from(&a), vec![7; n]);
        }
    }

    #[test]
    fn get_and_index_match_source_vec() {
        for n in 1..=33usize {
            let src: Vec<usize> = (0..n).collect();
            let a = PersistentArray::<usize>::from(src.clone());
            for (i, &expected) in src.iter().enumerate() {
                assert_eq!(*a.get(i), expected);
                assert_eq!(a[i], expected);
            }
        }
    }

    #[test]
    fn set_touches_only_one_slot() {
        for n in 1..=33usize {
            let src: Vec<usize> = (0..n).collect();
            let a = PersistentArray::<usize>::from(src.clone());
            for i in 0..n {
                let b = a.set(i, usize::MAX);
                let mut expected = src.clone();
                expected[i] = usize::MAX;
                assert_eq!(Vec::<usize>::from(&b), expected);
                // The source version must be unaffected.
                assert_eq!(Vec::<usize>::from(&a), src);
            }
        }
    }

    #[test]
    fn empty_array_round_trips() {
        let a = PersistentArray::<u8>::from(Vec::new());
        assert!(Vec::<u8>::from(&a).is_empty());
        let b = PersistentArray::new(0, 1_u8);
        assert!(Vec::<u8>::from(&b).is_empty());
    }

    #[test]
    #[should_panic(expected = "index out of bounds: the len is 5 but the index is 5")]
    fn get_out_of_bounds_panics() {
        let a = PersistentArray::new(5, 0_i32);
        let _ = a.get(5);
    }

    #[test]
    #[should_panic(expected = "index out of bounds: the len is 0 but the index is 0")]
    fn set_on_empty_panics() {
        let a = PersistentArray::<i32>::from(Vec::new());
        let _ = a.set(0, 1);
    }
}