// haar_lib/ds/splay_tree.rs

//! Splay Tree
//!
//! # Problems
//! - <https://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=1508>
//! - <https://judge.yosupo.jp/problem/range_reverse_range_sum>
//!
//! # References
//! - <https://en.wikipedia.org/wiki/Splay_tree>

10use std::cell::Cell;
11use std::cmp::Ordering;
12use std::ops::Range;
13use std::ptr;
14
15use crate::algebra::traits::Monoid;
16
/// A node of the splay tree. Nodes are heap-allocated (`Box::into_raw`) and
/// linked to each other through raw pointers.
struct Node<M: Monoid> {
    /// The element stored at this node.
    value: M,
    /// Cached monoid aggregate of the subtree rooted here (recomputed by `update`).
    sum: M,
    /// Number of nodes in the subtree rooted here, including this node.
    size: usize,
    /// Pending lazy reversal: when set, this node's children have not yet been
    /// swapped and the flag not yet propagated to them (see `pushdown`).
    rev: bool,
    /// Left child, or null.
    lc: *mut Node<M>,
    /// Right child, or null.
    rc: *mut Node<M>,
    /// Parent, or null for a tree root.
    par: *mut Node<M>,
}

27impl<M: Monoid + Clone> Node<M> {
28    fn new(value: M) -> Self {
29        Self {
30            value,
31            sum: M::id(),
32            size: 1,
33            rev: false,
34            lc: ptr::null_mut(),
35            rc: ptr::null_mut(),
36            par: ptr::null_mut(),
37        }
38    }
39
40    fn get_sum(this: *mut Self) -> M {
41        assert!(!this.is_null());
42        unsafe { (*this).sum.clone() }
43    }
44
45    fn set_value(this: *mut Self, value: M) {
46        assert!(!this.is_null());
47        unsafe {
48            (*this).value = value;
49        }
50    }
51
52    fn rotate(this: *mut Self) {
53        let p = Self::get_par(this).unwrap();
54        let pp = Self::get_par(p).unwrap();
55
56        if Self::left_of(p).unwrap() == this {
57            let c = Self::right_of(this).unwrap();
58            Self::set_left(p, c);
59            Self::set_right(this, p);
60        } else {
61            let c = Self::left_of(this).unwrap();
62            Self::set_right(p, c);
63            Self::set_left(this, p);
64        }
65
66        unsafe {
67            if !pp.is_null() {
68                if (*pp).lc == p {
69                    (*pp).lc = this;
70                }
71                if (*pp).rc == p {
72                    (*pp).rc = this;
73                }
74            }
75
76            assert!(!this.is_null());
77            (*this).par = pp;
78        }
79
80        Self::update(p);
81        Self::update(this);
82    }
83
84    fn status(this: *mut Self) -> i32 {
85        let par = Self::get_par(this).unwrap();
86
87        if par.is_null() {
88            return 0;
89        }
90        if unsafe { (*par).lc } == this {
91            return 1;
92        }
93        if unsafe { (*par).rc } == this {
94            return -1;
95        }
96
97        unreachable!()
98    }
99
100    fn reverse(this: *mut Self) {
101        if !this.is_null() {
102            unsafe {
103                (*this).rev ^= true;
104            }
105        }
106    }
107
108    fn pushdown(this: *mut Self) {
109        if !this.is_null() {
110            unsafe {
111                if (*this).rev {
112                    std::mem::swap(&mut (*this).lc, &mut (*this).rc);
113                    Self::reverse((*this).lc);
114                    Self::reverse((*this).rc);
115                    (*this).rev = false;
116                }
117            }
118            Self::update(this);
119        }
120    }
121
122    fn update(this: *mut Self) {
123        assert!(!this.is_null());
124        unsafe {
125            (*this).size = 1 + Self::size_of((*this).lc) + Self::size_of((*this).rc);
126
127            (*this).sum = (*this).value.clone();
128            if !(*this).lc.is_null() {
129                (*this).sum = M::op(Self::get_sum(this), Self::get_sum((*this).lc));
130            }
131            if !(*this).rc.is_null() {
132                (*this).sum = M::op(Self::get_sum(this), Self::get_sum((*this).rc));
133            }
134        }
135    }
136
137    fn splay(this: *mut Self) {
138        while Self::status(this) != 0 {
139            let par = Self::get_par(this).unwrap();
140
141            if Self::status(par) == 0 {
142                Self::rotate(this);
143            } else if Self::status(this) == Self::status(par) {
144                Self::rotate(par);
145                Self::rotate(this);
146            } else {
147                Self::rotate(this);
148                Self::rotate(this);
149            }
150        }
151    }
152
153    fn get(root: *mut Self, mut index: usize) -> *mut Self {
154        if root.is_null() {
155            return root;
156        }
157
158        let mut cur = root;
159
160        loop {
161            Self::pushdown(cur);
162
163            let left = Self::left_of(cur).unwrap();
164            let lsize = Self::size_of(left);
165
166            match index.cmp(&lsize) {
167                Ordering::Less => {
168                    cur = left;
169                }
170                Ordering::Equal => {
171                    Self::splay(cur);
172                    return cur;
173                }
174                Ordering::Greater => {
175                    cur = Self::right_of(cur).unwrap();
176                    index -= lsize + 1;
177                }
178            }
179        }
180    }
181
182    fn merge(left: *mut Self, right: *mut Self) -> *mut Self {
183        if left.is_null() {
184            return right;
185        }
186        if right.is_null() {
187            return left;
188        }
189
190        let cur = Self::get(left, Self::size_of(left) - 1);
191
192        Self::set_right(cur, right);
193        Self::update(right);
194        Self::update(cur);
195
196        cur
197    }
198
199    fn split(root: *mut Self, index: usize) -> (*mut Self, *mut Self) {
200        if root.is_null() {
201            return (ptr::null_mut(), ptr::null_mut());
202        }
203        if index >= Self::size_of(root) {
204            return (root, ptr::null_mut());
205        }
206
207        let cur = Self::get(root, index);
208        let left = Self::left_of(cur).unwrap();
209
210        if !left.is_null() {
211            unsafe {
212                (*left).par = ptr::null_mut();
213            }
214            Self::update(left);
215        }
216        assert!(!cur.is_null());
217        unsafe {
218            (*cur).lc = ptr::null_mut();
219        }
220        Self::update(cur);
221
222        (left, cur)
223    }
224
225    fn traverse(cur: *mut Self, f: &mut impl FnMut(&M)) {
226        if !cur.is_null() {
227            Self::pushdown(cur);
228            Self::traverse(Self::left_of(cur).unwrap(), f);
229            f(unsafe { &(*cur).value });
230            Self::traverse(Self::right_of(cur).unwrap(), f);
231        }
232    }
233}
234
235impl<M: Monoid> Node<M> {
236    fn set_left(this: *mut Self, left: *mut Self) {
237        assert!(!this.is_null());
238        unsafe {
239            (*this).lc = left;
240            if !left.is_null() {
241                (*left).par = this;
242            }
243        }
244    }
245
246    fn set_right(this: *mut Self, right: *mut Self) {
247        assert!(!this.is_null());
248        unsafe {
249            (*this).rc = right;
250            if !right.is_null() {
251                (*right).par = this;
252            }
253        }
254    }
255
256    fn size_of(this: *mut Self) -> usize {
257        if this.is_null() {
258            0
259        } else {
260            unsafe { (*this).size }
261        }
262    }
263
264    fn left_of(this: *mut Self) -> Option<*mut Self> {
265        (!this.is_null()).then_some(unsafe { (*this).lc })
266    }
267
268    fn right_of(this: *mut Self) -> Option<*mut Self> {
269        (!this.is_null()).then_some(unsafe { (*this).rc })
270    }
271
272    fn get_par(this: *mut Self) -> Option<*mut Self> {
273        (!this.is_null()).then_some(unsafe { (*this).par })
274    }
275
276    fn clear(this: *mut Self) {
277        if !this.is_null() {
278            let lc = Self::left_of(this).unwrap();
279            let rc = Self::right_of(this).unwrap();
280
281            let _ = unsafe { Box::from_raw(this) };
282
283            Self::clear(lc);
284            Self::clear(rc);
285        }
286    }
287}
288
/// Splay tree: a sequence of monoid elements supporting positional access,
/// insertion/removal, splitting/merging, range reversal and range folds.
pub struct SplayTree<M: Monoid> {
    /// Root node, or null when the tree is empty. It is wrapped in a `Cell`
    /// so that `&self` methods such as `get` and `fold` can still re-root
    /// the tree.
    root: Cell<*mut Node<M>>,
}

294impl<M: Monoid + Clone> Default for SplayTree<M> {
295    fn default() -> Self {
296        Self::new()
297    }
298}
299
300impl<M: Monoid + Clone> SplayTree<M> {
301    /// モノイド`m`をもつ`SplayTree<M>`を生成
302    pub fn new() -> Self {
303        Self {
304            root: Cell::new(ptr::null_mut()),
305        }
306    }
307
308    /// 値`value`をもつノード一つのみからなる`SplayTree<M>`を生成
309    pub fn singleton(value: M) -> Self {
310        let root = Box::new(Node::new(value));
311
312        Self {
313            root: Cell::new(Box::into_raw(root)),
314        }
315    }
316
317    /// スプレーツリーの要素数を返す
318    pub fn len(&self) -> usize {
319        Node::size_of(self.root.get())
320    }
321
322    /// スプレーツリーが要素を持たなければ`true`を返す
323    pub fn is_empty(&self) -> bool {
324        self.root.get().is_null()
325    }
326
327    /// `index`番目の要素の参照を返す
328    pub fn get(&self, index: usize) -> Option<&M> {
329        self.root.set(Node::get(self.root.get(), index));
330        let node = self.root.get();
331
332        if node.is_null() {
333            None
334        } else {
335            unsafe { Some(&(*node).value) }
336        }
337    }
338
339    /// `index`番目の要素を`value`に変更する
340    pub fn set(&mut self, index: usize, value: M) {
341        let root = Node::get(self.root.get(), index);
342        Node::set_value(root, value);
343        Node::update(root);
344        self.root.set(root);
345    }
346
347    /// 右側にスプレーツリーを結合する
348    pub fn merge_right(&mut self, right: Self) {
349        let root = Node::merge(self.root.get(), right.root.get());
350        right.root.set(ptr::null_mut());
351        self.root.set(root);
352    }
353
354    /// 左側にスプレーツリーを結合する
355    pub fn merge_left(&mut self, left: Self) {
356        let root = Node::merge(left.root.get(), self.root.get());
357        left.root.set(ptr::null_mut());
358        self.root.set(root);
359    }
360
361    /// 左側に`index`個の要素があるように、左右で分割する
362    pub fn split(self, index: usize) -> (Self, Self) {
363        let (l, r) = Node::split(self.root.get(), index);
364        self.root.set(ptr::null_mut());
365        (Self { root: Cell::new(l) }, Self { root: Cell::new(r) })
366    }
367
368    /// 要素を`index`番目になるように挿入する
369    pub fn insert(&mut self, index: usize, value: M) {
370        let (l, r) = Node::split(self.root.get(), index);
371        let node = Box::into_raw(Box::new(Node::new(value)));
372        let root = Node::merge(l, Node::merge(node, r));
373        self.root.set(root);
374    }
375
376    /// `index`番目の要素を削除して、値を返す
377    pub fn remove(&mut self, index: usize) -> Option<M> {
378        let (l, r) = Node::split(self.root.get(), index);
379        let (m, r) = Node::split(r, 1);
380
381        if m.is_null() {
382            return None;
383        }
384
385        let value = unsafe {
386            let m = Box::from_raw(m);
387            m.value
388        };
389
390        self.root.set(Node::merge(l, r));
391
392        Some(value)
393    }
394
395    /// `start..end`の範囲を反転させる
396    pub fn reverse(&mut self, Range { start, end }: Range<usize>) {
397        let (m, r) = Node::split(self.root.get(), end);
398        let (l, m) = Node::split(m, start);
399
400        Node::reverse(m);
401
402        let m = Node::merge(l, m);
403        let root = Node::merge(m, r);
404        self.root.set(root);
405    }
406
407    /// `start..end`の範囲でのモノイドの演算の結果を返す
408    pub fn fold(&self, Range { start, end }: Range<usize>) -> M {
409        let (m, r) = Node::split(self.root.get(), end);
410        let (l, m) = Node::split(m, start);
411
412        let ret = if m.is_null() {
413            M::id()
414        } else {
415            Node::get_sum(m)
416        };
417
418        let m = Node::merge(l, m);
419        let root = Node::merge(m, r);
420        self.root.set(root);
421
422        ret
423    }
424
425    /// 先頭に値を追加する
426    pub fn push_first(&mut self, value: M) {
427        let left = Self::singleton(value);
428        self.merge_left(left);
429    }
430    /// 末尾に値を追加する
431    pub fn push_last(&mut self, value: M) {
432        let right = Self::singleton(value);
433        self.merge_right(right);
434    }
435    /// 先頭の値を削除する
436    pub fn pop_first(&mut self) -> Option<M> {
437        self.remove(0)
438    }
439    /// 末尾の値を削除する
440    pub fn pop_last(&mut self) -> Option<M> {
441        if self.is_empty() {
442            None
443        } else {
444            self.remove(self.len() - 1)
445        }
446    }
447
448    /// 列の要素を始めから辿り、その参照を`f`に渡す。
449    pub fn for_each(&self, mut f: impl FnMut(&M)) {
450        Node::traverse(self.root.get(), &mut f);
451    }
452}
453
454impl<M: Monoid> std::ops::Drop for SplayTree<M> {
455    fn drop(&mut self) {
456        Node::clear(self.root.get());
457    }
458}
459
#[cfg(test)]
mod tests {
    use crate::algebra::sum::*;
    use my_testtools::rand_range;

    use rand::Rng;

    use super::*;

    /// Random insertions checked against `Vec` with range folds.
    #[test]
    fn test() {
        let t = 100;

        let mut rng = rand::thread_rng();

        let mut a = vec![];
        let mut st = SplayTree::<Sum<u64>>::new();

        for _ in 0..t {
            assert_eq!(a.len(), st.len());
            let n = a.len();

            let i = rng.gen_range(0..=n);
            let x = Sum(rng.gen::<u32>() as u64);

            a.insert(i, x);
            st.insert(i, x);

            assert_eq!(a.len(), st.len());
            let n = a.len();

            let Range { start: l, end: r } = rand_range(&mut rng, 0..n);
            assert_eq!(a[l..r].iter().cloned().fold_m(), st.fold(l..r));
        }
    }

    /// Randomized differential test of every operation against `Vec`,
    /// including out-of-range `get`/`remove`.
    #[test]
    fn test_ops() {
        let t = 300;

        let mut rng = rand::thread_rng();

        let mut a: Vec<Sum<u64>> = vec![];
        let mut st = SplayTree::<Sum<u64>>::new();

        for _ in 0..t {
            let n = a.len();

            if n == 0 {
                let x = Sum(rng.gen::<u32>() as u64);
                a.push(x);
                st.insert(0, x);
            } else {
                match rng.gen_range(0..6) {
                    0 => {
                        let i = rng.gen_range(0..=n);
                        let x = Sum(rng.gen::<u32>() as u64);
                        a.insert(i, x);
                        st.insert(i, x);
                    }
                    1 => {
                        let i = rng.gen_range(0..=n);
                        let expected = (i < n).then(|| a.remove(i));
                        assert_eq!(expected, st.remove(i));
                    }
                    2 => {
                        let Range { start: l, end: r } = rand_range(&mut rng, 0..n);
                        a[l..r].reverse();
                        st.reverse(l..r);
                    }
                    3 => {
                        let i = rng.gen_range(0..=n);
                        assert_eq!(a.get(i), st.get(i));
                    }
                    4 => {
                        let i = rng.gen_range(0..n);
                        let x = Sum(rng.gen::<u32>() as u64);
                        a[i] = x;
                        st.set(i, x);
                    }
                    _ => {
                        let Range { start: l, end: r } = rand_range(&mut rng, 0..n);
                        assert_eq!(a[l..r].iter().cloned().fold_m(), st.fold(l..r));
                    }
                }
            }

            assert_eq!(a.len(), st.len());
            let mut b = vec![];
            st.for_each(|x| b.push(*x));
            assert_eq!(a, b);
        }
    }

    /// Appending builds a list-shaped tree; traversal and drop must not
    /// overflow the stack.
    #[test]
    fn test_deep_tree() {
        let n: u64 = 200_000;

        let mut st = SplayTree::<Sum<u64>>::new();
        for i in 0..n {
            st.push_last(Sum(i));
        }

        let mut count = 0u64;
        st.for_each(|x| {
            assert_eq!(*x, Sum(count));
            count += 1;
        });
        assert_eq!(count, n);
    }
}