haar_lib/ds/
persistent_array.rs

//! Persistent array.
2
3use std::cmp::Ordering;
4use std::ops::Index;
5use std::ptr;
6
/// Payload of a `Node`: either the element itself or a pointer to the node
/// that holds it.
enum Value<T> {
    /// The element is stored directly in this node.
    Value(T),
    /// The element is stored in the pointed-to node, which `Node::ref_value`
    /// requires to be a `Value::Value`.
    Ptr(*const Node<T>),
}

/// One node of the binary search tree (keyed by `index`) behind the array.
struct Node<T> {
    /// Element stored here, or a pointer to the node storing it.
    value: Value<T>,
    /// Array position represented by this node; used as the search key.
    index: usize,
    /// Child on the smaller-index side (null when absent).
    left: *const Node<T>,
    /// Child on the larger-index side (null when absent).
    right: *const Node<T>,
}

impl<T> Node<T> {
    /// Assembles a node from its payload, key and child pointers.
    fn new(value: Value<T>, index: usize, left: *const Self, right: *const Self) -> Self {
        Node {
            value,
            index,
            left,
            right,
        }
    }

    /// Returns a reference to the element represented by `node`, following a
    /// `Value::Ptr` indirection once if necessary.
    ///
    /// The lifetime `'a` is picked by the caller and is not tied to any
    /// borrow, so the caller must keep the node(s) alive for that long.
    ///
    /// # Panics
    ///
    /// Panics if `node` is null, if a `Value::Ptr` payload is null, or if a
    /// `Value::Ptr` leads to a node that is itself a `Value::Ptr`.
    fn ref_value<'a>(node: *const Self) -> &'a T {
        assert!(!node.is_null());
        // SAFETY: non-null was checked just above; the caller is trusted to
        // pass a pointer to a live, initialised node.
        let this = unsafe { &*node };

        // Resolve which node actually holds the element.
        let holder = match this.value {
            Value::Value(_) => this,
            Value::Ptr(target) => {
                assert!(!target.is_null());
                // SAFETY: non-null was checked just above; the pointer is
                // assumed to refer to a live node, like `node` itself.
                unsafe { &*target }
            }
        };

        match &holder.value {
            Value::Value(element) => element,
            // Only one level of indirection is expected.
            Value::Ptr(_) => unreachable!(),
        }
    }
}
45
46/// 永続配列
47#[derive(Clone)]
48pub struct PersistentArray<T> {
49    size: usize,
50    root: *const Node<T>,
51}
52
53impl<T> PersistentArray<T>
54where
55    T: Clone,
56{
57    /// `n`個の`value`からなる永続配列を作る。
58    ///
59    /// **Time complexity** $O(n)$
60    pub fn new(n: usize, value: T) -> Self {
61        Self::from_vec(vec![value; n])
62    }
63
64    fn _traverse(node: *const Node<T>, ret: &mut Vec<T>) {
65        if !node.is_null() {
66            let node = unsafe { &*node };
67            Self::_traverse(node.left, ret);
68            ret.push(Node::ref_value(node).clone());
69            Self::_traverse(node.right, ret);
70        }
71    }
72
73    /// `Vec`へ変換する。
74    ///
75    /// **Time complexity** $O(n)$
76    pub fn into_vec(&self) -> Vec<T> {
77        let mut ret = vec![];
78        Self::_traverse(self.root, &mut ret);
79        ret
80    }
81}
82
83impl<T> PersistentArray<T> {
84    /// `Vec`から永続配列を作る。
85    ///
86    /// **Time complexity** $O(n)$
87    pub fn from_vec(v: Vec<T>) -> Self {
88        if v.is_empty() {
89            Self {
90                size: 0,
91                root: ptr::null_mut(),
92            }
93        } else {
94            let size = v.len();
95
96            let mut a = v
97                .into_iter()
98                .enumerate()
99                .map(|(i, x)| {
100                    Box::into_raw(Box::new(Node::new(
101                        Value::Value(x),
102                        i,
103                        ptr::null(),
104                        ptr::null(),
105                    )))
106                })
107                .collect::<Vec<_>>();
108
109            let max = (size + 1).next_power_of_two();
110
111            let get_par = |i: usize| {
112                let lowest = 1 << i.trailing_zeros();
113                i ^ lowest | (lowest << 1)
114            };
115
116            for i in 0..size {
117                let i = i + 1;
118
119                let mut par = get_par(i);
120                while par <= max {
121                    if par <= size {
122                        let p = unsafe { &mut *a[par - 1] };
123                        if par < i {
124                            p.right = a[i - 1];
125                        } else {
126                            p.left = a[i - 1];
127                        }
128
129                        break;
130                    }
131                    par = get_par(par);
132                }
133            }
134
135            let root = a[(max >> 1) - 1];
136            Self { size, root }
137        }
138    }
139
140    fn _set(prev: *const Node<T>, i: usize, val: T) -> *const Node<T> {
141        assert!(!prev.is_null());
142        let prev = unsafe { &*prev };
143
144        let (left, right, value);
145        match i.cmp(&prev.index) {
146            Ordering::Less => {
147                left = Self::_set(prev.left, i, val);
148                right = prev.right;
149                value = match prev.value {
150                    Value::Value(_) => Value::Ptr(prev as *const _),
151                    Value::Ptr(p) => Value::Ptr(p),
152                };
153            }
154            Ordering::Greater => {
155                left = prev.left;
156                right = Self::_set(prev.right, i, val);
157                value = match prev.value {
158                    Value::Value(_) => Value::Ptr(prev as *const _),
159                    Value::Ptr(p) => Value::Ptr(p),
160                };
161            }
162            Ordering::Equal => {
163                left = prev.left;
164                right = prev.right;
165                value = Value::Value(val);
166            }
167        }
168
169        let node = Box::new(Node::new(value, prev.index, left, right));
170        Box::into_raw(node)
171    }
172
173    /// `i`番目の要素を`value`に変更した永続配列を返す。
174    ///
175    /// **Time complexity** $O(\log n)$
176    pub fn set(&self, i: usize, value: T) -> Self {
177        assert!(
178            i < self.size,
179            "index out of bounds: the len is {} but the index is {i}",
180            self.size,
181        );
182        Self {
183            size: self.size,
184            root: Self::_set(self.root, i, value),
185        }
186    }
187
188    fn _get<'a>(node: *const Node<T>, i: usize) -> &'a T {
189        assert!(!node.is_null());
190        let node = unsafe { &*node };
191
192        match i.cmp(&node.index) {
193            Ordering::Less => Self::_get(node.left, i),
194            Ordering::Greater => Self::_get(node.right, i),
195            Ordering::Equal => Node::ref_value(node),
196        }
197    }
198
199    /// `i`番目の要素を返す。
200    ///
201    /// **Time complexity** $O(\log n)$
202    pub fn get(&self, i: usize) -> &T {
203        assert!(
204            i < self.size,
205            "index out of bounds: the len is {} but the index is {i}",
206            self.size
207        );
208        Self::_get(self.root, i)
209    }
210}
211
212impl<T: Clone> Index<usize> for PersistentArray<T> {
213    type Output = T;
214    fn index(&self, index: usize) -> &Self::Output {
215        self.get(index)
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// Each `set` must produce a new version with the update applied, while
    /// every earlier version keeps its contents.
    #[test]
    fn test() {
        let a = PersistentArray::<i32>::from_vec(vec![1, 2, 3, 4, 5]);
        assert_eq!(a.into_vec(), [1, 2, 3, 4, 5]);

        let b = a.set(0, 4);
        assert_eq!(b.into_vec(), [4, 2, 3, 4, 5]);

        let c = b.set(2, 6);
        assert_eq!(c.into_vec(), [4, 2, 6, 4, 5]);

        // A second branch from `b`, independent of `c`.
        let d = b.set(2, 9);
        assert_eq!(d.into_vec(), [4, 2, 9, 4, 5]);

        let e = c.set(4, -3);

        // Re-check all versions after all updates have been made.
        let versions = [
            (&a, [1, 2, 3, 4, 5]),
            (&b, [4, 2, 3, 4, 5]),
            (&c, [4, 2, 6, 4, 5]),
            (&d, [4, 2, 9, 4, 5]),
            (&e, [4, 2, 6, 4, -3]),
        ];
        for (version, expected) in versions.iter() {
            assert_eq!(version.into_vec(), *expected);
        }
    }
}