use std::cell::Cell;
use std::cmp::Ordering;
use std::ops::Range;
use std::ptr;

use crate::algebra::traits::Monoid;

/// Helper that carries the monoid and implements the raw-pointer operations
/// on `Node`s.
///
/// The manager stores nothing but the monoid instance; it never holds a node
/// itself and only works on the pointers passed to its methods.
#[derive(Clone)]
struct Manager<M> {
    /// Monoid used to obtain the identity element and to combine elements.
    monoid: M,
}

22struct Node<M: Monoid> {
23 value: M::Element,
24 sum: M::Element,
25 size: usize,
26 rev: bool,
27 lc: *mut Self,
28 rc: *mut Self,
29 par: *mut Self,
30}
31
32impl<M: Monoid> Manager<M>
33where
34 M::Element: Clone,
35{
36 fn new(monoid: M) -> Self {
37 Self { monoid }
38 }
39
40 fn create(&self, value: M::Element) -> Node<M> {
41 Node {
42 value,
43 sum: self.monoid.id(),
44 size: 1,
45 rev: false,
46 lc: ptr::null_mut(),
47 rc: ptr::null_mut(),
48 par: ptr::null_mut(),
49 }
50 }
51
52 fn get_sum(&self, this: *mut Node<M>) -> M::Element {
53 if this.is_null() {
54 self.monoid.id()
55 } else {
56 unsafe { (*this).sum.clone() }
57 }
58 }
59
60 fn set_value(&self, this: *mut Node<M>, value: M::Element) {
61 assert!(!this.is_null());
62 unsafe {
63 (*this).value = value;
64 }
65 }
66
67 fn rotate(&self, this: *mut Node<M>) {
68 let p = self.par_of(this).unwrap();
69 let pp = self.par_of(p).unwrap();
70
71 if self.left_of(p).unwrap() == this {
72 let c = self.right_of(this).unwrap();
73 self.set_left(p, c);
74 self.set_right(this, p);
75 } else {
76 let c = self.left_of(this).unwrap();
77 self.set_right(p, c);
78 self.set_left(this, p);
79 }
80
81 if !pp.is_null() {
82 let pp = unsafe { &mut *pp };
83 if pp.lc == p {
84 pp.lc = this;
85 }
86 if pp.rc == p {
87 pp.rc = this;
88 }
89 }
90
91 assert!(!this.is_null());
92 let this = unsafe { &mut *this };
93 this.par = pp;
94
95 self.update(p);
96 self.update(this);
97 }
98
99 fn status(&self, this: *mut Node<M>) -> i32 {
100 let par = self.par_of(this).unwrap();
101 if par.is_null() {
102 return 0;
103 }
104 let par = unsafe { &*par };
105 if par.lc == this {
106 return 1;
107 }
108 if par.rc == this {
109 return -1;
110 }
111
112 unreachable!()
113 }
114
115 fn reverse(&self, this: *mut Node<M>) {
116 if !this.is_null() {
117 unsafe {
118 (*this).rev ^= true;
119 }
120 }
121 }
122
123 fn pushdown(&self, this: *mut Node<M>) {
124 if !this.is_null() {
125 let this = unsafe { &mut *this };
126 if this.rev {
127 std::mem::swap(&mut this.lc, &mut this.rc);
128 self.reverse(this.lc);
129 self.reverse(this.rc);
130 this.rev = false;
131 }
132 self.update(this);
133 }
134 }
135
136 fn update(&self, this: *mut Node<M>) {
137 assert!(!this.is_null());
138 let this = unsafe { &mut *this };
139 this.size = 1 + self.size_of(this.lc) + self.size_of(this.rc);
140
141 this.sum = this.value.clone();
142 if !this.lc.is_null() {
143 this.sum = self.monoid.op(this.sum.clone(), self.get_sum(this.lc));
144 }
145 if !this.rc.is_null() {
146 this.sum = self.monoid.op(this.sum.clone(), self.get_sum(this.rc));
147 }
148 }
149
150 fn splay(&self, this: *mut Node<M>) {
151 while self.status(this) != 0 {
152 let par = self.par_of(this).unwrap();
153
154 if self.status(par) == 0 {
155 self.rotate(this);
156 } else if self.status(this) == self.status(par) {
157 self.rotate(par);
158 self.rotate(this);
159 } else {
160 self.rotate(this);
161 self.rotate(this);
162 }
163 }
164 }
165
166 fn get(&self, root: *mut Node<M>, mut index: usize) -> *mut Node<M> {
167 if root.is_null() {
168 return root;
169 }
170
171 let mut cur = root;
172
173 loop {
174 self.pushdown(cur);
175
176 let left = self.left_of(cur).unwrap();
177 let lsize = self.size_of(left);
178
179 match index.cmp(&lsize) {
180 Ordering::Less => {
181 cur = left;
182 }
183 Ordering::Equal => {
184 self.splay(cur);
185 return cur;
186 }
187 Ordering::Greater => {
188 cur = self.right_of(cur).unwrap();
189 index -= lsize + 1;
190 }
191 }
192 }
193 }
194
195 fn merge(&self, left: *mut Node<M>, right: *mut Node<M>) -> *mut Node<M> {
196 if left.is_null() {
197 return right;
198 }
199 if right.is_null() {
200 return left;
201 }
202
203 let cur = self.get(left, self.size_of(left) - 1);
204
205 self.set_right(cur, right);
206 self.update(right);
207 self.update(cur);
208
209 cur
210 }
211
212 fn split(&self, root: *mut Node<M>, index: usize) -> (*mut Node<M>, *mut Node<M>) {
213 if root.is_null() {
214 return (ptr::null_mut(), ptr::null_mut());
215 }
216 if index >= self.size_of(root) {
217 return (root, ptr::null_mut());
218 }
219
220 let cur = self.get(root, index);
221 let left = self.left_of(cur).unwrap();
222
223 if !left.is_null() {
224 self.set_par(left, ptr::null_mut());
225 self.update(left);
226 }
227 self.set_left(cur, ptr::null_mut());
228 self.update(cur);
229
230 (left, cur)
231 }
232
233 fn traverse(&self, cur: *mut Node<M>, f: &mut impl FnMut(&M::Element)) {
234 if !cur.is_null() {
235 let cur = unsafe { &mut *cur };
236 self.pushdown(cur);
237 self.traverse(self.left_of(cur).unwrap(), f);
238 f(&cur.value);
239 self.traverse(self.right_of(cur).unwrap(), f);
240 }
241 }
242}
243
244impl<M: Monoid> Manager<M> {
245 fn set_left(&self, this: *mut Node<M>, left: *mut Node<M>) {
246 assert!(!this.is_null());
247 unsafe { (*this).lc = left };
248 if !left.is_null() {
249 unsafe { (*left).par = this };
250 }
251 }
252
253 fn set_right(&self, this: *mut Node<M>, right: *mut Node<M>) {
254 assert!(!this.is_null());
255 unsafe { (*this).rc = right };
256 if !right.is_null() {
257 unsafe { (*right).par = this };
258 }
259 }
260
261 fn set_par(&self, this: *mut Node<M>, par: *mut Node<M>) {
262 assert!(!this.is_null());
263 unsafe { (*this).par = par };
264 }
265
266 fn size_of(&self, this: *mut Node<M>) -> usize {
267 if this.is_null() {
268 0
269 } else {
270 unsafe { (*this).size }
271 }
272 }
273
274 fn left_of(&self, this: *mut Node<M>) -> Option<*mut Node<M>> {
275 (!this.is_null()).then(|| unsafe { (*this).lc })
276 }
277
278 fn right_of(&self, this: *mut Node<M>) -> Option<*mut Node<M>> {
279 (!this.is_null()).then(|| unsafe { (*this).rc })
280 }
281
282 fn par_of(&self, this: *mut Node<M>) -> Option<*mut Node<M>> {
283 (!this.is_null()).then(|| unsafe { (*this).par })
284 }
285
286 fn clear(&self, this: *mut Node<M>) {
287 if !this.is_null() {
288 let lc = self.left_of(this).unwrap();
289 let rc = self.right_of(this).unwrap();
290
291 let _ = unsafe { Box::from_raw(this) };
292
293 self.clear(lc);
294 self.clear(rc);
295 }
296 }
297}
298
299pub struct SplayTree<M: Monoid> {
301 man: Manager<M>,
302 root: Cell<*mut Node<M>>,
303}
304
305impl<M: Monoid + Clone> SplayTree<M>
306where
307 M::Element: Clone,
308{
309 pub fn new(monoid: M) -> Self {
311 Self {
312 man: Manager::new(monoid),
313 root: Cell::new(ptr::null_mut()),
314 }
315 }
316
317 pub fn singleton(monoid: M, value: M::Element) -> Self {
319 let man = Manager::new(monoid);
320 let root = Box::new(man.create(value));
321
322 Self {
323 man,
324 root: Cell::new(Box::into_raw(root)),
325 }
326 }
327
328 pub fn len(&self) -> usize {
330 self.man.size_of(self.root.get())
331 }
332
333 pub fn is_empty(&self) -> bool {
335 self.root.get().is_null()
336 }
337
338 pub fn get(&self, index: usize) -> Option<&M::Element> {
340 self.root.set(self.man.get(self.root.get(), index));
341 let node = self.root.get();
342
343 (!node.is_null()).then(|| unsafe { &(*node).value })
344 }
345
346 pub fn set(&mut self, index: usize, value: M::Element) {
348 let root = self.man.get(self.root.get(), index);
349 self.man.set_value(root, value);
350 self.man.update(root);
351 self.root.set(root);
352 }
353
354 pub fn merge_right(&mut self, right: Self) {
356 let root = self.man.merge(self.root.get(), right.root.get());
357 right.root.set(ptr::null_mut());
358 self.root.set(root);
359 }
360
361 pub fn merge_left(&mut self, left: Self) {
363 let root = self.man.merge(left.root.get(), self.root.get());
364 left.root.set(ptr::null_mut());
365 self.root.set(root);
366 }
367
368 pub fn split(self, index: usize) -> (Self, Self) {
370 let (l, r) = self.man.split(self.root.get(), index);
371 self.root.set(ptr::null_mut());
372 (
373 Self {
374 man: self.man.clone(),
375 root: Cell::new(l),
376 },
377 Self {
378 man: self.man.clone(),
379 root: Cell::new(r),
380 },
381 )
382 }
383
384 pub fn insert(&mut self, index: usize, value: M::Element) {
386 let (l, r) = self.man.split(self.root.get(), index);
387 let node = Box::into_raw(Box::new(self.man.create(value)));
388 let root = self.man.merge(l, self.man.merge(node, r));
389 self.root.set(root);
390 }
391
392 pub fn remove(&mut self, index: usize) -> Option<M::Element> {
394 let (l, r) = self.man.split(self.root.get(), index);
395 let (m, r) = self.man.split(r, 1);
396
397 self.root.set(self.man.merge(l, r));
398
399 (!m.is_null()).then(|| {
400 let m = unsafe { Box::from_raw(m) };
401 m.value
402 })
403 }
404
405 pub fn reverse(&mut self, Range { start, end }: Range<usize>) {
407 let (m, r) = self.man.split(self.root.get(), end);
408 let (l, m) = self.man.split(m, start);
409
410 self.man.reverse(m);
411
412 let m = self.man.merge(l, m);
413 let root = self.man.merge(m, r);
414 self.root.set(root);
415 }
416
417 pub fn fold(&self, Range { start, end }: Range<usize>) -> M::Element {
419 let (m, r) = self.man.split(self.root.get(), end);
420 let (l, m) = self.man.split(m, start);
421
422 let ret = self.man.get_sum(m);
423
424 let m = self.man.merge(l, m);
425 let root = self.man.merge(m, r);
426 self.root.set(root);
427
428 ret
429 }
430
431 pub fn push_first(&mut self, value: M::Element) {
433 let left = Self::singleton(self.man.monoid.clone(), value);
434 self.merge_left(left);
435 }
436 pub fn push_last(&mut self, value: M::Element) {
438 let right = Self::singleton(self.man.monoid.clone(), value);
439 self.merge_right(right);
440 }
441 pub fn pop_first(&mut self) -> Option<M::Element> {
443 self.remove(0)
444 }
445 pub fn pop_last(&mut self) -> Option<M::Element> {
447 self.remove(self.len().checked_sub(1)?)
448 }
449
450 pub fn for_each(&self, mut f: impl FnMut(&M::Element)) {
452 self.man.traverse(self.root.get(), &mut f);
453 }
454}
455
456impl<M: Monoid> std::ops::Drop for SplayTree<M> {
457 fn drop(&mut self) {
458 self.man.clear(self.root.get());
459 }
460}
461
#[cfg(test)]
mod tests {
    use crate::algebra::sum::*;
    use my_testtools::rand_range;

    use rand::Rng;

    use super::*;

    /// Popping from an empty tree yields `None` at both ends.
    #[test]
    fn test_empty() {
        let monoid = Sum::<u64>::new();
        let mut s = SplayTree::new(monoid);

        assert!(s.pop_first().is_none());
        assert!(s.pop_last().is_none());
    }

    /// Random inserts checked against a `Vec` reference model: after every
    /// insertion the lengths must agree and a random range fold must match
    /// the fold over the same slice of the vector.
    #[test]
    fn test() {
        let t = 100;

        let mut rng = rand::thread_rng();

        let m = Sum::<u64>::new();
        let mut a = vec![];
        let mut st = SplayTree::new(m);

        for _ in 0..t {
            assert_eq!(a.len(), st.len());
            let n = a.len();

            // Insert a random value at a random position (end included).
            let i = rng.gen_range(0..=n);
            let x = u64::from(rng.gen::<u32>());

            a.insert(i, x);
            st.insert(i, x);

            assert_eq!(a.len(), st.len());
            let n = a.len();

            let Range { start: l, end: r } = rand_range(&mut rng, 0..n);
            assert_eq!(a[l..r].iter().cloned().fold_m(&m), st.fold(l..r));
        }
    }
}