1use std::cell::Cell;
11use std::cmp::Ordering;
12use std::ops::Range;
13use std::ptr;
14
15use crate::algebra::traits::Monoid;
16
17struct Node<M: Monoid> {
18 value: M,
19 sum: M,
20 size: usize,
21 rev: bool,
22 lc: *mut Node<M>,
23 rc: *mut Node<M>,
24 par: *mut Node<M>,
25}
26
27impl<M: Monoid + Clone> Node<M> {
28 fn new(value: M) -> Self {
29 Self {
30 value,
31 sum: M::id(),
32 size: 1,
33 rev: false,
34 lc: ptr::null_mut(),
35 rc: ptr::null_mut(),
36 par: ptr::null_mut(),
37 }
38 }
39
40 fn get_sum(this: *mut Self) -> M {
41 assert!(!this.is_null());
42 unsafe { (*this).sum.clone() }
43 }
44
45 fn set_value(this: *mut Self, value: M) {
46 assert!(!this.is_null());
47 unsafe {
48 (*this).value = value;
49 }
50 }
51
52 fn rotate(this: *mut Self) {
53 let p = Self::get_par(this).unwrap();
54 let pp = Self::get_par(p).unwrap();
55
56 if Self::left_of(p).unwrap() == this {
57 let c = Self::right_of(this).unwrap();
58 Self::set_left(p, c);
59 Self::set_right(this, p);
60 } else {
61 let c = Self::left_of(this).unwrap();
62 Self::set_right(p, c);
63 Self::set_left(this, p);
64 }
65
66 unsafe {
67 if !pp.is_null() {
68 if (*pp).lc == p {
69 (*pp).lc = this;
70 }
71 if (*pp).rc == p {
72 (*pp).rc = this;
73 }
74 }
75
76 assert!(!this.is_null());
77 (*this).par = pp;
78 }
79
80 Self::update(p);
81 Self::update(this);
82 }
83
84 fn status(this: *mut Self) -> i32 {
85 let par = Self::get_par(this).unwrap();
86
87 if par.is_null() {
88 return 0;
89 }
90 if unsafe { (*par).lc } == this {
91 return 1;
92 }
93 if unsafe { (*par).rc } == this {
94 return -1;
95 }
96
97 unreachable!()
98 }
99
100 fn reverse(this: *mut Self) {
101 if !this.is_null() {
102 unsafe {
103 (*this).rev ^= true;
104 }
105 }
106 }
107
108 fn pushdown(this: *mut Self) {
109 if !this.is_null() {
110 unsafe {
111 if (*this).rev {
112 std::mem::swap(&mut (*this).lc, &mut (*this).rc);
113 Self::reverse((*this).lc);
114 Self::reverse((*this).rc);
115 (*this).rev = false;
116 }
117 }
118 Self::update(this);
119 }
120 }
121
122 fn update(this: *mut Self) {
123 assert!(!this.is_null());
124 unsafe {
125 (*this).size = 1 + Self::size_of((*this).lc) + Self::size_of((*this).rc);
126
127 (*this).sum = (*this).value.clone();
128 if !(*this).lc.is_null() {
129 (*this).sum = M::op(Self::get_sum(this), Self::get_sum((*this).lc));
130 }
131 if !(*this).rc.is_null() {
132 (*this).sum = M::op(Self::get_sum(this), Self::get_sum((*this).rc));
133 }
134 }
135 }
136
137 fn splay(this: *mut Self) {
138 while Self::status(this) != 0 {
139 let par = Self::get_par(this).unwrap();
140
141 if Self::status(par) == 0 {
142 Self::rotate(this);
143 } else if Self::status(this) == Self::status(par) {
144 Self::rotate(par);
145 Self::rotate(this);
146 } else {
147 Self::rotate(this);
148 Self::rotate(this);
149 }
150 }
151 }
152
153 fn get(root: *mut Self, mut index: usize) -> *mut Self {
154 if root.is_null() {
155 return root;
156 }
157
158 let mut cur = root;
159
160 loop {
161 Self::pushdown(cur);
162
163 let left = Self::left_of(cur).unwrap();
164 let lsize = Self::size_of(left);
165
166 match index.cmp(&lsize) {
167 Ordering::Less => {
168 cur = left;
169 }
170 Ordering::Equal => {
171 Self::splay(cur);
172 return cur;
173 }
174 Ordering::Greater => {
175 cur = Self::right_of(cur).unwrap();
176 index -= lsize + 1;
177 }
178 }
179 }
180 }
181
182 fn merge(left: *mut Self, right: *mut Self) -> *mut Self {
183 if left.is_null() {
184 return right;
185 }
186 if right.is_null() {
187 return left;
188 }
189
190 let cur = Self::get(left, Self::size_of(left) - 1);
191
192 Self::set_right(cur, right);
193 Self::update(right);
194 Self::update(cur);
195
196 cur
197 }
198
199 fn split(root: *mut Self, index: usize) -> (*mut Self, *mut Self) {
200 if root.is_null() {
201 return (ptr::null_mut(), ptr::null_mut());
202 }
203 if index >= Self::size_of(root) {
204 return (root, ptr::null_mut());
205 }
206
207 let cur = Self::get(root, index);
208 let left = Self::left_of(cur).unwrap();
209
210 if !left.is_null() {
211 unsafe {
212 (*left).par = ptr::null_mut();
213 }
214 Self::update(left);
215 }
216 assert!(!cur.is_null());
217 unsafe {
218 (*cur).lc = ptr::null_mut();
219 }
220 Self::update(cur);
221
222 (left, cur)
223 }
224
225 fn traverse(cur: *mut Self, f: &mut impl FnMut(&M)) {
226 if !cur.is_null() {
227 Self::pushdown(cur);
228 Self::traverse(Self::left_of(cur).unwrap(), f);
229 f(unsafe { &(*cur).value });
230 Self::traverse(Self::right_of(cur).unwrap(), f);
231 }
232 }
233}
234
235impl<M: Monoid> Node<M> {
236 fn set_left(this: *mut Self, left: *mut Self) {
237 assert!(!this.is_null());
238 unsafe {
239 (*this).lc = left;
240 if !left.is_null() {
241 (*left).par = this;
242 }
243 }
244 }
245
246 fn set_right(this: *mut Self, right: *mut Self) {
247 assert!(!this.is_null());
248 unsafe {
249 (*this).rc = right;
250 if !right.is_null() {
251 (*right).par = this;
252 }
253 }
254 }
255
256 fn size_of(this: *mut Self) -> usize {
257 if this.is_null() {
258 0
259 } else {
260 unsafe { (*this).size }
261 }
262 }
263
264 fn left_of(this: *mut Self) -> Option<*mut Self> {
265 (!this.is_null()).then_some(unsafe { (*this).lc })
266 }
267
268 fn right_of(this: *mut Self) -> Option<*mut Self> {
269 (!this.is_null()).then_some(unsafe { (*this).rc })
270 }
271
272 fn get_par(this: *mut Self) -> Option<*mut Self> {
273 (!this.is_null()).then_some(unsafe { (*this).par })
274 }
275
276 fn clear(this: *mut Self) {
277 if !this.is_null() {
278 let lc = Self::left_of(this).unwrap();
279 let rc = Self::right_of(this).unwrap();
280
281 let _ = unsafe { Box::from_raw(this) };
282
283 Self::clear(lc);
284 Self::clear(rc);
285 }
286 }
287}
288
289pub struct SplayTree<M: Monoid> {
291 root: Cell<*mut Node<M>>,
292}
293
294impl<M: Monoid + Clone> Default for SplayTree<M> {
295 fn default() -> Self {
296 Self::new()
297 }
298}
299
300impl<M: Monoid + Clone> SplayTree<M> {
301 pub fn new() -> Self {
303 Self {
304 root: Cell::new(ptr::null_mut()),
305 }
306 }
307
308 pub fn singleton(value: M) -> Self {
310 let root = Box::new(Node::new(value));
311
312 Self {
313 root: Cell::new(Box::into_raw(root)),
314 }
315 }
316
317 pub fn len(&self) -> usize {
319 Node::size_of(self.root.get())
320 }
321
322 pub fn is_empty(&self) -> bool {
324 self.root.get().is_null()
325 }
326
327 pub fn get(&self, index: usize) -> Option<&M> {
329 self.root.set(Node::get(self.root.get(), index));
330 let node = self.root.get();
331
332 if node.is_null() {
333 None
334 } else {
335 unsafe { Some(&(*node).value) }
336 }
337 }
338
339 pub fn set(&mut self, index: usize, value: M) {
341 let root = Node::get(self.root.get(), index);
342 Node::set_value(root, value);
343 Node::update(root);
344 self.root.set(root);
345 }
346
347 pub fn merge_right(&mut self, right: Self) {
349 let root = Node::merge(self.root.get(), right.root.get());
350 right.root.set(ptr::null_mut());
351 self.root.set(root);
352 }
353
354 pub fn merge_left(&mut self, left: Self) {
356 let root = Node::merge(left.root.get(), self.root.get());
357 left.root.set(ptr::null_mut());
358 self.root.set(root);
359 }
360
361 pub fn split(self, index: usize) -> (Self, Self) {
363 let (l, r) = Node::split(self.root.get(), index);
364 self.root.set(ptr::null_mut());
365 (Self { root: Cell::new(l) }, Self { root: Cell::new(r) })
366 }
367
368 pub fn insert(&mut self, index: usize, value: M) {
370 let (l, r) = Node::split(self.root.get(), index);
371 let node = Box::into_raw(Box::new(Node::new(value)));
372 let root = Node::merge(l, Node::merge(node, r));
373 self.root.set(root);
374 }
375
376 pub fn remove(&mut self, index: usize) -> Option<M> {
378 let (l, r) = Node::split(self.root.get(), index);
379 let (m, r) = Node::split(r, 1);
380
381 if m.is_null() {
382 return None;
383 }
384
385 let value = unsafe {
386 let m = Box::from_raw(m);
387 m.value
388 };
389
390 self.root.set(Node::merge(l, r));
391
392 Some(value)
393 }
394
395 pub fn reverse(&mut self, Range { start, end }: Range<usize>) {
397 let (m, r) = Node::split(self.root.get(), end);
398 let (l, m) = Node::split(m, start);
399
400 Node::reverse(m);
401
402 let m = Node::merge(l, m);
403 let root = Node::merge(m, r);
404 self.root.set(root);
405 }
406
407 pub fn fold(&self, Range { start, end }: Range<usize>) -> M {
409 let (m, r) = Node::split(self.root.get(), end);
410 let (l, m) = Node::split(m, start);
411
412 let ret = if m.is_null() {
413 M::id()
414 } else {
415 Node::get_sum(m)
416 };
417
418 let m = Node::merge(l, m);
419 let root = Node::merge(m, r);
420 self.root.set(root);
421
422 ret
423 }
424
425 pub fn push_first(&mut self, value: M) {
427 let left = Self::singleton(value);
428 self.merge_left(left);
429 }
430 pub fn push_last(&mut self, value: M) {
432 let right = Self::singleton(value);
433 self.merge_right(right);
434 }
435 pub fn pop_first(&mut self) -> Option<M> {
437 self.remove(0)
438 }
439 pub fn pop_last(&mut self) -> Option<M> {
441 if self.is_empty() {
442 None
443 } else {
444 self.remove(self.len() - 1)
445 }
446 }
447
448 pub fn for_each(&self, mut f: impl FnMut(&M)) {
450 Node::traverse(self.root.get(), &mut f);
451 }
452}
453
454impl<M: Monoid> std::ops::Drop for SplayTree<M> {
455 fn drop(&mut self) {
456 Node::clear(self.root.get());
457 }
458}
459
#[cfg(test)]
mod tests {
    use crate::algebra::sum::*;
    use my_testtools::rand_range;

    use rand::Rng;

    use super::*;

    /// Random inserts, each followed by a random range fold, cross-checked
    /// against a `Vec`.
    #[test]
    fn test() {
        let t = 100;

        let mut rng = rand::thread_rng();

        let mut a = vec![];
        let mut st = SplayTree::<Sum<u64>>::new();

        for _ in 0..t {
            assert_eq!(a.len(), st.len());
            let n = a.len();

            let i = rng.gen_range(0..=n);
            let x = Sum(rng.gen::<u32>() as u64);

            a.insert(i, x);
            st.insert(i, x);

            assert_eq!(a.len(), st.len());
            let n = a.len();

            let Range { start: l, end: r } = rand_range(&mut rng, 0..n);
            assert_eq!(a[l..r].iter().cloned().fold_m(), st.fold(l..r));
        }
    }

    /// Random mix of insert / remove / set / reverse / push_last.
    ///
    /// After every step the tree is cross-checked against a `Vec` through
    /// `len`, `get`, `for_each` and `fold`.
    #[test]
    fn test_mixed_operations() {
        let t = 300;

        let mut rng = rand::thread_rng();

        let mut a: Vec<Sum<u64>> = vec![];
        let mut st = SplayTree::<Sum<u64>>::new();

        for _ in 0..t {
            let n = a.len();
            let x = Sum(rng.gen::<u32>() as u64);

            match rng.gen_range(0..4) {
                0 => {
                    let i = rng.gen_range(0..=n);
                    a.insert(i, x);
                    st.insert(i, x);
                }
                1 if n > 0 => {
                    let i = rng.gen_range(0..n);
                    assert_eq!(Some(a.remove(i)), st.remove(i));
                }
                2 if n > 0 => {
                    let i = rng.gen_range(0..n);
                    a[i] = x;
                    st.set(i, x);
                }
                3 if n > 0 => {
                    let Range { start: l, end: r } = rand_range(&mut rng, 0..n);
                    a[l..r].reverse();
                    st.reverse(l..r);
                }
                _ => {
                    a.push(x);
                    st.push_last(x);
                }
            }

            assert_eq!(a.len(), st.len());

            for (i, v) in a.iter().enumerate() {
                assert_eq!(Some(v), st.get(i));
            }

            let mut seen = vec![];
            st.for_each(|v| seen.push(*v));
            assert_eq!(a, seen);

            if !a.is_empty() {
                let Range { start: l, end: r } = rand_range(&mut rng, 0..a.len());
                assert_eq!(a[l..r].iter().cloned().fold_m(), st.fold(l..r));
            }
        }
    }
}