// haar_lib/ds/persistent_array.rs

//! 永続配列 (persistent array)
2
3use std::ops::Index;
4use std::rc::Rc;
5
/// A tree node: either a leaf that stores one element, or a branch that
/// records how many elements are stored below it.
#[derive(Clone)]
enum Node<T> {
    /// Leaf holding exactly one element.
    Terminal { value: T },
    /// Branch node. `size` is the number of elements in this subtree; either
    /// child may be absent.
    Internal {
        size: usize,
        l_ch: Option<Rc<Node<T>>>,
        r_ch: Option<Rc<Node<T>>>,
    },
}
17
18/// 永続配列
19#[derive(Clone)]
20pub struct PersistentArray<T> {
21    size: usize,
22    root: Option<Rc<Node<T>>>,
23}
24
25fn get_size<T>(node: &Option<Rc<Node<T>>>) -> usize {
26    node.as_ref().map_or(0, |node| match node.as_ref() {
27        Node::Terminal { .. } => 1,
28        Node::Internal { size, .. } => *size,
29    })
30}
31
32impl<T> PersistentArray<T>
33where
34    T: Clone,
35{
36    /// **Time complexity** $O(n)$
37    pub fn new(size: usize, value: T) -> Self {
38        if size == 0 {
39            Self {
40                size: 0,
41                root: None,
42            }
43        } else {
44            let depth = usize::BITS - (size - 1_usize).leading_zeros() + 1;
45            let values = vec![value; size];
46            let root = Self::_init(0, size, &values, depth);
47
48            Self { size, root }
49        }
50    }
51
52    fn _init(l: usize, r: usize, values: &[T], depth: u32) -> Option<Rc<Node<T>>> {
53        if l == r {
54            return None;
55        }
56        if depth == 1 {
57            Some(Rc::new(Node::Terminal {
58                value: values[l].clone(),
59            }))
60        } else {
61            let mid = (l + r) / 2;
62            let l_ch = Self::_init(l, mid, values, depth - 1);
63            let r_ch = Self::_init(mid, r, values, depth - 1);
64
65            let t = Node::Internal {
66                size: get_size(&l_ch) + get_size(&r_ch),
67                l_ch,
68                r_ch,
69            };
70
71            Some(Rc::new(t))
72        }
73    }
74
75    fn _traverse(node: &Option<Rc<Node<T>>>, ret: &mut Vec<T>) {
76        if let Some(node) = node {
77            match node.as_ref() {
78                Node::Terminal { value } => {
79                    ret.push(value.clone());
80                }
81                Node::Internal { l_ch, r_ch, .. } => {
82                    Self::_traverse(l_ch, ret);
83                    Self::_traverse(r_ch, ret);
84                }
85            }
86        }
87    }
88
89    fn _set(prev: &Rc<Node<T>>, i: usize, value: T) -> Rc<Node<T>> {
90        match prev.as_ref() {
91            Node::Terminal { .. } => Rc::new(Node::Terminal { value }),
92            Node::Internal { l_ch, r_ch, .. } => {
93                let (l_ch, r_ch) = {
94                    let k = get_size(l_ch);
95                    if i < k {
96                        (
97                            Some(Self::_set(l_ch.as_ref().unwrap(), i, value)),
98                            r_ch.clone(),
99                        )
100                    } else {
101                        (
102                            l_ch.clone(),
103                            Some(Self::_set(r_ch.as_ref().unwrap(), i - k, value)),
104                        )
105                    }
106                };
107
108                Rc::new(Node::Internal {
109                    size: get_size(&l_ch) + get_size(&r_ch),
110                    l_ch,
111                    r_ch,
112                })
113            }
114        }
115    }
116
117    /// **Time complexity** $O(\log n)$
118    pub fn set(&self, i: usize, value: T) -> Self {
119        assert!(
120            i < self.size,
121            "index out of bounds: the len is {} but the index is {}",
122            self.size,
123            i
124        );
125        Self {
126            size: self.size,
127            root: Some(Self::_set(self.root.as_ref().unwrap(), i, value)),
128        }
129    }
130
131    fn _get(node: &Rc<Node<T>>, i: usize) -> &Node<T> {
132        match node.as_ref() {
133            Node::Terminal { .. } => node.as_ref(),
134            Node::Internal { l_ch, r_ch, .. } => {
135                let k = get_size(l_ch);
136                if i < k {
137                    Self::_get(l_ch.as_ref().unwrap(), i)
138                } else {
139                    Self::_get(r_ch.as_ref().unwrap(), i - k)
140                }
141            }
142        }
143    }
144
145    /// **Time complexity** $O(\log n)$
146    pub fn get(&self, i: usize) -> &T {
147        assert!(
148            i < self.size,
149            "index out of bounds: the len is {} but the index is {}",
150            self.size,
151            i
152        );
153        match Self::_get(self.root.as_ref().unwrap(), i) {
154            Node::Terminal { value } => value,
155            _ => unreachable!(),
156        }
157    }
158}
159
160impl<T: Clone> Index<usize> for PersistentArray<T> {
161    type Output = T;
162    fn index(&self, index: usize) -> &Self::Output {
163        self.get(index)
164    }
165}
166
167impl<T: Clone> From<&PersistentArray<T>> for Vec<T> {
168    fn from(from: &PersistentArray<T>) -> Vec<T> {
169        let mut ret = vec![];
170        PersistentArray::<T>::_traverse(&from.root, &mut ret);
171        ret
172    }
173}
174
175impl<T: Clone> From<Vec<T>> for PersistentArray<T> {
176    fn from(value: Vec<T>) -> Self {
177        let size = value.len();
178        if size == 0 {
179            Self {
180                size: 0,
181                root: None,
182            }
183        } else {
184            let depth = usize::BITS - (size - 1_usize).leading_zeros() + 1;
185            let root = Self::_init(0, size, &value, depth);
186
187            Self { size, root }
188        }
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    /// Snapshot of `a` as a plain vector, for easy comparison.
    fn to_vec(a: &PersistentArray<i32>) -> Vec<i32> {
        Vec::from(a)
    }

    #[test]
    fn test() {
        let a = PersistentArray::<i32>::from(vec![1, 2, 3, 4, 5]);
        assert_eq!(Vec::<i32>::from(&a), [1, 2, 3, 4, 5]);

        let b = a.set(0, 4);
        assert_eq!(Vec::<i32>::from(&b), [4, 2, 3, 4, 5]);

        let c = b.set(2, 6);
        assert_eq!(Vec::<i32>::from(&c), [4, 2, 6, 4, 5]);

        let d = b.set(2, 9);
        assert_eq!(Vec::<i32>::from(&d), [4, 2, 9, 4, 5]);

        let e = c.set(4, -3);

        // Every earlier version must be unaffected by later updates.
        assert_eq!(Vec::<i32>::from(&a), [1, 2, 3, 4, 5]);
        assert_eq!(Vec::<i32>::from(&b), [4, 2, 3, 4, 5]);
        assert_eq!(Vec::<i32>::from(&c), [4, 2, 6, 4, 5]);
        assert_eq!(Vec::<i32>::from(&d), [4, 2, 9, 4, 5]);
        assert_eq!(Vec::<i32>::from(&e), [4, 2, 6, 4, -3]);
    }

    #[test]
    fn test_new_and_index() {
        let a = PersistentArray::new(4, 7);
        assert_eq!(to_vec(&a), [7, 7, 7, 7]);
        for i in 0..4 {
            assert_eq!(a[i], 7);
            assert_eq!(*a.get(i), 7);
        }
        let b = a.set(3, 1);
        assert_eq!(b[3], 1);
        assert_eq!(a[3], 7);
    }

    #[test]
    fn test_empty_and_single() {
        let e = PersistentArray::<i32>::new(0, 1);
        assert!(to_vec(&e).is_empty());
        let f = PersistentArray::<i32>::from(vec![]);
        assert!(to_vec(&f).is_empty());

        let s = PersistentArray::from(vec![42]);
        assert_eq!(s[0], 42);
        assert_eq!(to_vec(&s.set(0, -1)), [-1]);
        assert_eq!(to_vec(&s), [42]);
    }

    #[test]
    fn test_sizes_around_powers_of_two() {
        // Covers sizes on, just below, and just above powers of two.
        for n in 1..=17_i32 {
            let v: Vec<i32> = (0..n).collect();
            let a = PersistentArray::from(v.clone());
            assert_eq!(to_vec(&a), v);
            for i in 0..n as usize {
                assert_eq!(a[i], i as i32);
            }
        }
    }

    #[test]
    #[should_panic(expected = "index out of bounds")]
    fn test_get_out_of_bounds() {
        let a = PersistentArray::from(vec![1, 2, 3]);
        let _ = a.get(3);
    }

    #[test]
    #[should_panic(expected = "index out of bounds")]
    fn test_set_out_of_bounds() {
        let a = PersistentArray::<i32>::new(0, 0);
        let _ = a.set(0, 1);
    }
}