// haar_lib/ds/splay_tree.rs

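//! Splay Tree
//!
//! A sequence stored as an implicit-key splay tree, supporting insertion,
//! removal, range reversal and range folding with a monoid.
//!
//! # Problems
//! - <https://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=1508>
//! - <https://judge.yosupo.jp/problem/range_reverse_range_sum>
//!
//! # References
//! - <https://en.wikipedia.org/wiki/Splay_tree>
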
10use std::cell::Cell;
11use std::cmp::Ordering;
12use std::ops::Range;
13use std::ptr;
14
15use crate::algebra::traits::Monoid;
16
/// Owns the monoid used to combine node values and implements the raw
/// node operations shared by every `SplayTree`.
struct Manager<M> {
    monoid: M,
}

impl<M: Clone> Clone for Manager<M> {
    fn clone(&self) -> Self {
        Manager {
            monoid: self.monoid.clone(),
        }
    }
}

22struct Node<M: Monoid> {
23    value: M::Element,
24    sum: M::Element,
25    size: usize,
26    rev: bool,
27    lc: *mut Self,
28    rc: *mut Self,
29    par: *mut Self,
30}
31
32impl<M: Monoid> Manager<M>
33where
34    M::Element: Clone,
35{
36    fn new(monoid: M) -> Self {
37        Self { monoid }
38    }
39
40    fn create(&self, value: M::Element) -> Node<M> {
41        Node {
42            value,
43            sum: self.monoid.id(),
44            size: 1,
45            rev: false,
46            lc: ptr::null_mut(),
47            rc: ptr::null_mut(),
48            par: ptr::null_mut(),
49        }
50    }
51
52    fn get_sum(&self, this: *mut Node<M>) -> M::Element {
53        if this.is_null() {
54            self.monoid.id()
55        } else {
56            unsafe { (*this).sum.clone() }
57        }
58    }
59
60    fn set_value(&self, this: *mut Node<M>, value: M::Element) {
61        assert!(!this.is_null());
62        unsafe {
63            (*this).value = value;
64        }
65    }
66
67    fn rotate(&self, this: *mut Node<M>) {
68        let p = self.par_of(this).unwrap();
69        let pp = self.par_of(p).unwrap();
70
71        if self.left_of(p).unwrap() == this {
72            let c = self.right_of(this).unwrap();
73            self.set_left(p, c);
74            self.set_right(this, p);
75        } else {
76            let c = self.left_of(this).unwrap();
77            self.set_right(p, c);
78            self.set_left(this, p);
79        }
80
81        if !pp.is_null() {
82            let pp = unsafe { &mut *pp };
83            if pp.lc == p {
84                pp.lc = this;
85            }
86            if pp.rc == p {
87                pp.rc = this;
88            }
89        }
90
91        assert!(!this.is_null());
92        let this = unsafe { &mut *this };
93        this.par = pp;
94
95        self.update(p);
96        self.update(this);
97    }
98
99    fn status(&self, this: *mut Node<M>) -> i32 {
100        let par = self.par_of(this).unwrap();
101        if par.is_null() {
102            return 0;
103        }
104        let par = unsafe { &*par };
105        if par.lc == this {
106            return 1;
107        }
108        if par.rc == this {
109            return -1;
110        }
111
112        unreachable!()
113    }
114
115    fn reverse(&self, this: *mut Node<M>) {
116        if !this.is_null() {
117            unsafe {
118                (*this).rev ^= true;
119            }
120        }
121    }
122
123    fn pushdown(&self, this: *mut Node<M>) {
124        if !this.is_null() {
125            let this = unsafe { &mut *this };
126            if this.rev {
127                std::mem::swap(&mut this.lc, &mut this.rc);
128                self.reverse(this.lc);
129                self.reverse(this.rc);
130                this.rev = false;
131            }
132            self.update(this);
133        }
134    }
135
136    fn update(&self, this: *mut Node<M>) {
137        assert!(!this.is_null());
138        let this = unsafe { &mut *this };
139        this.size = 1 + self.size_of(this.lc) + self.size_of(this.rc);
140
141        this.sum = this.value.clone();
142        if !this.lc.is_null() {
143            this.sum = self.monoid.op(this.sum.clone(), self.get_sum(this.lc));
144        }
145        if !this.rc.is_null() {
146            this.sum = self.monoid.op(this.sum.clone(), self.get_sum(this.rc));
147        }
148    }
149
150    fn splay(&self, this: *mut Node<M>) {
151        while self.status(this) != 0 {
152            let par = self.par_of(this).unwrap();
153
154            if self.status(par) == 0 {
155                self.rotate(this);
156            } else if self.status(this) == self.status(par) {
157                self.rotate(par);
158                self.rotate(this);
159            } else {
160                self.rotate(this);
161                self.rotate(this);
162            }
163        }
164    }
165
166    fn get(&self, root: *mut Node<M>, mut index: usize) -> *mut Node<M> {
167        if root.is_null() {
168            return root;
169        }
170
171        let mut cur = root;
172
173        loop {
174            self.pushdown(cur);
175
176            let left = self.left_of(cur).unwrap();
177            let lsize = self.size_of(left);
178
179            match index.cmp(&lsize) {
180                Ordering::Less => {
181                    cur = left;
182                }
183                Ordering::Equal => {
184                    self.splay(cur);
185                    return cur;
186                }
187                Ordering::Greater => {
188                    cur = self.right_of(cur).unwrap();
189                    index -= lsize + 1;
190                }
191            }
192        }
193    }
194
195    fn merge(&self, left: *mut Node<M>, right: *mut Node<M>) -> *mut Node<M> {
196        if left.is_null() {
197            return right;
198        }
199        if right.is_null() {
200            return left;
201        }
202
203        let cur = self.get(left, self.size_of(left) - 1);
204
205        self.set_right(cur, right);
206        self.update(right);
207        self.update(cur);
208
209        cur
210    }
211
212    fn split(&self, root: *mut Node<M>, index: usize) -> (*mut Node<M>, *mut Node<M>) {
213        if root.is_null() {
214            return (ptr::null_mut(), ptr::null_mut());
215        }
216        if index >= self.size_of(root) {
217            return (root, ptr::null_mut());
218        }
219
220        let cur = self.get(root, index);
221        let left = self.left_of(cur).unwrap();
222
223        if !left.is_null() {
224            self.set_par(left, ptr::null_mut());
225            self.update(left);
226        }
227        self.set_left(cur, ptr::null_mut());
228        self.update(cur);
229
230        (left, cur)
231    }
232
233    fn traverse(&self, cur: *mut Node<M>, f: &mut impl FnMut(&M::Element)) {
234        if !cur.is_null() {
235            let cur = unsafe { &mut *cur };
236            self.pushdown(cur);
237            self.traverse(self.left_of(cur).unwrap(), f);
238            f(&cur.value);
239            self.traverse(self.right_of(cur).unwrap(), f);
240        }
241    }
242}
243
244impl<M: Monoid> Manager<M> {
245    fn set_left(&self, this: *mut Node<M>, left: *mut Node<M>) {
246        assert!(!this.is_null());
247        unsafe { (*this).lc = left };
248        if !left.is_null() {
249            unsafe { (*left).par = this };
250        }
251    }
252
253    fn set_right(&self, this: *mut Node<M>, right: *mut Node<M>) {
254        assert!(!this.is_null());
255        unsafe { (*this).rc = right };
256        if !right.is_null() {
257            unsafe { (*right).par = this };
258        }
259    }
260
261    fn set_par(&self, this: *mut Node<M>, par: *mut Node<M>) {
262        assert!(!this.is_null());
263        unsafe { (*this).par = par };
264    }
265
266    fn size_of(&self, this: *mut Node<M>) -> usize {
267        if this.is_null() {
268            0
269        } else {
270            unsafe { (*this).size }
271        }
272    }
273
274    fn left_of(&self, this: *mut Node<M>) -> Option<*mut Node<M>> {
275        (!this.is_null()).then(|| unsafe { (*this).lc })
276    }
277
278    fn right_of(&self, this: *mut Node<M>) -> Option<*mut Node<M>> {
279        (!this.is_null()).then(|| unsafe { (*this).rc })
280    }
281
282    fn par_of(&self, this: *mut Node<M>) -> Option<*mut Node<M>> {
283        (!this.is_null()).then(|| unsafe { (*this).par })
284    }
285
286    fn clear(&self, this: *mut Node<M>) {
287        if !this.is_null() {
288            let lc = self.left_of(this).unwrap();
289            let rc = self.right_of(this).unwrap();
290
291            let _ = unsafe { Box::from_raw(this) };
292
293            self.clear(lc);
294            self.clear(rc);
295        }
296    }
297}
298
299/// スプレーツリー
300pub struct SplayTree<M: Monoid> {
301    man: Manager<M>,
302    root: Cell<*mut Node<M>>,
303}
304
305impl<M: Monoid + Clone> SplayTree<M>
306where
307    M::Element: Clone,
308{
309    /// モノイド`m`をもつ`SplayTree<M>`を生成
310    pub fn new(monoid: M) -> Self {
311        Self {
312            man: Manager::new(monoid),
313            root: Cell::new(ptr::null_mut()),
314        }
315    }
316
317    /// 値`value`をもつノード一つのみからなる`SplayTree<M>`を生成
318    pub fn singleton(monoid: M, value: M::Element) -> Self {
319        let man = Manager::new(monoid);
320        let root = Box::new(man.create(value));
321
322        Self {
323            man,
324            root: Cell::new(Box::into_raw(root)),
325        }
326    }
327
328    /// スプレーツリーの要素数を返す
329    pub fn len(&self) -> usize {
330        self.man.size_of(self.root.get())
331    }
332
333    /// スプレーツリーが要素を持たなければ`true`を返す
334    pub fn is_empty(&self) -> bool {
335        self.root.get().is_null()
336    }
337
338    /// `index`番目の要素の参照を返す
339    pub fn get(&self, index: usize) -> Option<&M::Element> {
340        self.root.set(self.man.get(self.root.get(), index));
341        let node = self.root.get();
342
343        (!node.is_null()).then(|| unsafe { &(*node).value })
344    }
345
346    /// `index`番目の要素を`value`に変更する
347    pub fn set(&mut self, index: usize, value: M::Element) {
348        let root = self.man.get(self.root.get(), index);
349        self.man.set_value(root, value);
350        self.man.update(root);
351        self.root.set(root);
352    }
353
354    /// 右側にスプレーツリーを結合する
355    pub fn merge_right(&mut self, right: Self) {
356        let root = self.man.merge(self.root.get(), right.root.get());
357        right.root.set(ptr::null_mut());
358        self.root.set(root);
359    }
360
361    /// 左側にスプレーツリーを結合する
362    pub fn merge_left(&mut self, left: Self) {
363        let root = self.man.merge(left.root.get(), self.root.get());
364        left.root.set(ptr::null_mut());
365        self.root.set(root);
366    }
367
368    /// 左側に`index`個の要素があるように、左右で分割する
369    pub fn split(self, index: usize) -> (Self, Self) {
370        let (l, r) = self.man.split(self.root.get(), index);
371        self.root.set(ptr::null_mut());
372        (
373            Self {
374                man: self.man.clone(),
375                root: Cell::new(l),
376            },
377            Self {
378                man: self.man.clone(),
379                root: Cell::new(r),
380            },
381        )
382    }
383
384    /// 要素を`index`番目になるように挿入する
385    pub fn insert(&mut self, index: usize, value: M::Element) {
386        let (l, r) = self.man.split(self.root.get(), index);
387        let node = Box::into_raw(Box::new(self.man.create(value)));
388        let root = self.man.merge(l, self.man.merge(node, r));
389        self.root.set(root);
390    }
391
392    /// `index`番目の要素を削除して、値を返す
393    pub fn remove(&mut self, index: usize) -> Option<M::Element> {
394        let (l, r) = self.man.split(self.root.get(), index);
395        let (m, r) = self.man.split(r, 1);
396
397        self.root.set(self.man.merge(l, r));
398
399        (!m.is_null()).then(|| {
400            let m = unsafe { Box::from_raw(m) };
401            m.value
402        })
403    }
404
405    /// `start..end`の範囲を反転させる
406    pub fn reverse(&mut self, Range { start, end }: Range<usize>) {
407        let (m, r) = self.man.split(self.root.get(), end);
408        let (l, m) = self.man.split(m, start);
409
410        self.man.reverse(m);
411
412        let m = self.man.merge(l, m);
413        let root = self.man.merge(m, r);
414        self.root.set(root);
415    }
416
417    /// `start..end`の範囲でのモノイドの演算の結果を返す
418    pub fn fold(&self, Range { start, end }: Range<usize>) -> M::Element {
419        let (m, r) = self.man.split(self.root.get(), end);
420        let (l, m) = self.man.split(m, start);
421
422        let ret = self.man.get_sum(m);
423
424        let m = self.man.merge(l, m);
425        let root = self.man.merge(m, r);
426        self.root.set(root);
427
428        ret
429    }
430
431    /// 先頭に値を追加する
432    pub fn push_first(&mut self, value: M::Element) {
433        let left = Self::singleton(self.man.monoid.clone(), value);
434        self.merge_left(left);
435    }
436    /// 末尾に値を追加する
437    pub fn push_last(&mut self, value: M::Element) {
438        let right = Self::singleton(self.man.monoid.clone(), value);
439        self.merge_right(right);
440    }
441    /// 先頭の値を削除する
442    pub fn pop_first(&mut self) -> Option<M::Element> {
443        self.remove(0)
444    }
445    /// 末尾の値を削除する
446    pub fn pop_last(&mut self) -> Option<M::Element> {
447        self.remove(self.len().checked_sub(1)?)
448    }
449
450    /// 列の要素を始めから辿り、その参照を`f`に渡す。
451    pub fn for_each(&self, mut f: impl FnMut(&M::Element)) {
452        self.man.traverse(self.root.get(), &mut f);
453    }
454}
455
456impl<M: Monoid> std::ops::Drop for SplayTree<M> {
457    fn drop(&mut self) {
458        self.man.clear(self.root.get());
459    }
460}
461
#[cfg(test)]
mod tests {
    use crate::algebra::sum::*;
    use my_testtools::rand_range;

    use rand::Rng;

    use super::*;

    #[test]
    fn test_empty() {
        let monoid = Sum::<u64>::new();
        let mut s = SplayTree::new(monoid);

        assert!(s.is_empty());
        assert_eq!(s.len(), 0);
        assert!(s.get(0).is_none());
        assert!(s.pop_first().is_none());
        assert!(s.pop_last().is_none());
    }

    #[test]
    fn test() {
        let t = 100;

        let mut rng = rand::thread_rng();

        let m = Sum::<u64>::new();
        let mut a = vec![];
        let mut st = SplayTree::new(m);

        for _ in 0..t {
            assert_eq!(a.len(), st.len());
            let n = a.len();

            let i = rng.gen_range(0..=n);
            let x = rng.gen::<u32>() as u64;

            a.insert(i, x);
            st.insert(i, x);

            assert_eq!(a.len(), st.len());
            let n = a.len();

            let Range { start: l, end: r } = rand_range(&mut rng, 0..n);
            assert_eq!(a[l..r].iter().cloned().fold_m(&m), st.fold(l..r));
        }
    }

    /// Cross-checks insert/remove/reverse/set/get/for_each/fold/pop_last
    /// against a `Vec` model, using only in-range indices.
    #[test]
    fn test_random_operations() {
        let mut rng = rand::thread_rng();

        let m = Sum::<u64>::new();
        let mut a: Vec<u64> = vec![];
        let mut st = SplayTree::new(m);

        for _ in 0..300 {
            let n = a.len();
            match rng.gen_range(0..6) {
                0 | 1 => {
                    let i = rng.gen_range(0..=n);
                    let x = rng.gen::<u32>() as u64;
                    a.insert(i, x);
                    st.insert(i, x);
                }
                2 if n > 0 => {
                    let i = rng.gen_range(0..n);
                    assert_eq!(Some(a.remove(i)), st.remove(i));
                }
                3 if n > 0 => {
                    let Range { start: l, end: r } = rand_range(&mut rng, 0..n);
                    a[l..r].reverse();
                    st.reverse(l..r);
                }
                4 if n > 0 => {
                    let i = rng.gen_range(0..n);
                    let x = rng.gen::<u32>() as u64;
                    a[i] = x;
                    st.set(i, x);
                }
                5 if n > 0 => {
                    let i = rng.gen_range(0..n);
                    assert_eq!(Some(&a[i]), st.get(i));
                }
                _ => {}
            }

            assert_eq!(a.len(), st.len());

            let mut b = vec![];
            st.for_each(|&x| b.push(x));
            assert_eq!(a, b);

            if !a.is_empty() {
                let Range { start: l, end: r } = rand_range(&mut rng, 0..a.len());
                assert_eq!(a[l..r].iter().cloned().fold_m(&m), st.fold(l..r));
            }
        }

        while let Some(x) = st.pop_last() {
            assert_eq!(a.pop(), Some(x));
        }
        assert!(a.is_empty());
        assert!(st.is_empty());
    }
}