1use std::cell::Cell;
7use std::cmp::Ordering;
8use std::ptr;
9
/// A node of the splay tree that backs `OrderedMap`.
///
/// Nodes are linked through raw pointers; a null pointer means "no node". Every non-null
/// pointer handled by the functions below is expected to come from `Box::into_raw` and to
/// still be live — `clear` is the function that hands nodes back to `Box`.
struct Node<K, V> {
    /// Key of this entry.
    key: K,
    /// Value stored alongside `key`.
    value: V,
    /// Number of nodes in the subtree rooted here, this node included.
    size: usize,
    /// Left child, or null.
    lc: *mut Self,
    /// Right child, or null.
    rc: *mut Self,
    /// Parent, or null when this node is the root of its tree.
    par: *mut Self,
}

impl<K, V> Node<K, V> {
    /// Builds a detached node: size 1, no children, no parent.
    fn new(key: K, value: V) -> Self {
        Self {
            key,
            value,
            size: 1,
            lc: ptr::null_mut(),
            rc: ptr::null_mut(),
            par: ptr::null_mut(),
        }
    }

    /// Stores `value` in `this` and returns the value that was there before.
    ///
    /// # Panics
    /// Panics if `this` is null.
    fn set_value(this: *mut Self, value: V) -> V {
        assert!(!this.is_null());
        // SAFETY: `this` is non-null (checked above) and callers only pass live nodes.
        unsafe { std::mem::replace(&mut (*this).value, value) }
    }

    /// Rotates `this` one level up, above its parent, and refreshes both subtree sizes.
    ///
    /// # Panics
    /// Panics if `this` is null or has no parent.
    fn rotate(this: *mut Self) {
        let p = Self::get_par(this).expect("rotate called on a null node");
        let pp = Self::get_par(p).expect("rotate called on a node without a parent");

        if Self::left_of(p) == Some(this) {
            // `this` is a left child: its right subtree becomes the parent's left subtree.
            let c = Self::right_of(this).expect("`this` was checked to be non-null");
            Self::set_left(p, c);
            Self::set_right(this, p);
        } else {
            // Mirror image for a right child.
            let c = Self::left_of(this).expect("`this` was checked to be non-null");
            Self::set_right(p, c);
            Self::set_left(this, p);
        }

        // SAFETY: `this` is non-null (the first `expect` above), `pp` is checked before it is
        // dereferenced, and both are live nodes of the same tree.
        unsafe {
            if !pp.is_null() {
                // Re-hang `this` in the slot the old parent occupied.
                if (*pp).lc == p {
                    (*pp).lc = this;
                } else if (*pp).rc == p {
                    (*pp).rc = this;
                }
            }
            (*this).par = pp;
        }

        // `p` is now below `this`, so its size has to be recomputed first.
        Self::update(p);
        Self::update(this);
    }

    /// Describes how `this` hangs off its parent: `0` for a root (null parent), `1` for a
    /// left child, `-1` for a right child.
    ///
    /// # Panics
    /// Panics if `this` is null, or if the parent does not list `this` as a child.
    fn status(this: *mut Self) -> i32 {
        let par = Self::get_par(this).expect("status called on a null node");

        if par.is_null() {
            return 0;
        }

        // SAFETY: `par` is non-null and is a live node of the same tree as `this`.
        let (lc, rc) = unsafe { ((*par).lc, (*par).rc) };
        if lc == this {
            1
        } else if rc == this {
            -1
        } else {
            unreachable!()
        }
    }

    /// Hook executed before descending into a node; currently it only refreshes the cached
    /// size. A null `this` is ignored.
    fn pushdown(this: *mut Self) {
        if !this.is_null() {
            Self::update(this);
        }
    }

    /// Recomputes `size` of `this` from the sizes of its two children.
    ///
    /// # Panics
    /// Panics if `this` is null.
    fn update(this: *mut Self) {
        assert!(!this.is_null());
        // SAFETY: `this` is non-null (checked above) and live; `size_of` tolerates null children.
        unsafe {
            (*this).size = 1 + Self::size_of((*this).lc) + Self::size_of((*this).rc);
        }
    }

    /// Rotates `this` upwards until it has no parent, i.e. until it is the root.
    ///
    /// # Panics
    /// Panics if `this` is null.
    fn splay(this: *mut Self) {
        while Self::status(this) != 0 {
            let par = Self::get_par(this).expect("status already rejected a null node");

            if Self::status(par) == 0 {
                // Parent is the root: a single rotation finishes the job.
                Self::rotate(this);
            } else if Self::status(this) == Self::status(par) {
                // Node and parent lean the same way: rotate the parent first.
                Self::rotate(par);
                Self::rotate(this);
            } else {
                // Node and parent lean opposite ways: rotate the node twice.
                Self::rotate(this);
                Self::rotate(this);
            }
        }
    }

    /// Finds the node at in-order position `index` in the tree rooted at `root`, splays it to
    /// the root and returns it. A null `root` is returned unchanged.
    ///
    /// # Panics
    /// Panics if `root` is non-null and `index` is not smaller than the size of the tree.
    fn get(root: *mut Self, mut index: usize) -> *mut Self {
        if root.is_null() {
            return root;
        }
        // Reject bad indices up front instead of walking off the bottom of the tree.
        assert!(
            index < Self::size_of(root),
            "index {} is out of range for a tree of size {}",
            index,
            Self::size_of(root)
        );

        let mut cur = root;

        loop {
            Self::pushdown(cur);

            let left = Self::left_of(cur).expect("cached subtree sizes are inconsistent");
            let lsize = Self::size_of(left);

            match index.cmp(&lsize) {
                Ordering::Less => {
                    cur = left;
                }
                Ordering::Equal => {
                    Self::splay(cur);
                    return cur;
                }
                Ordering::Greater => {
                    cur = Self::right_of(cur).expect("cached subtree sizes are inconsistent");
                    // Skip the whole left subtree plus the node just passed.
                    index -= lsize + 1;
                }
            }
        }
    }

    /// Concatenates two trees, every entry of `left` coming before every entry of `right`,
    /// and returns the root of the result. Either side may be null.
    fn merge(left: *mut Self, right: *mut Self) -> *mut Self {
        if left.is_null() {
            return right;
        }
        if right.is_null() {
            return left;
        }

        // Bring the last node of `left` to the root; being last, it has no right child.
        let cur = Self::get(left, Self::size_of(left) - 1);

        Self::set_right(cur, right);
        Self::update(cur);

        cur
    }

    /// Splits the tree into `(first index entries, remaining entries)`.
    ///
    /// Returns `(root, null)` when `index` is at least the size of the tree and
    /// `(null, null)` for an empty tree.
    fn split(root: *mut Self, index: usize) -> (*mut Self, *mut Self) {
        if root.is_null() {
            return (ptr::null_mut(), ptr::null_mut());
        }
        if index >= Self::size_of(root) {
            return (root, ptr::null_mut());
        }

        // After `get`, `cur` is the root and its left subtree holds exactly `index` entries.
        let cur = Self::get(root, index);
        assert!(!cur.is_null());

        // SAFETY: `cur` is non-null (checked above) and live.
        let left = unsafe { (*cur).lc };
        Self::set_par(left, ptr::null_mut());
        // SAFETY: as above.
        unsafe {
            (*cur).lc = ptr::null_mut();
        }
        Self::update(cur);

        (left, cur)
    }

    /// Calls `f` on every entry of the subtree rooted at `cur`, in in-order sequence.
    ///
    /// The walk keeps its own stack on the heap, so a degenerate (chain-shaped) tree cannot
    /// overflow the call stack.
    fn traverse(cur: *mut Self, f: &mut impl FnMut(&K, &mut V)) {
        let mut pending: Vec<*mut Self> = Vec::new();
        let mut cur = cur;

        loop {
            // Walk down the left spine, remembering every node passed on the way.
            while !cur.is_null() {
                Self::pushdown(cur);
                pending.push(cur);
                // SAFETY: `cur` is non-null (loop condition) and live.
                cur = unsafe { (*cur).lc };
            }

            match pending.pop() {
                None => break,
                Some(node) => {
                    // SAFETY: only non-null live nodes are pushed; `key` and `value` are
                    // distinct fields, so the two borrows do not overlap.
                    unsafe {
                        f(&(*node).key, &mut (*node).value);
                        cur = (*node).rc;
                    }
                }
            }
        }
    }

    /// Makes `left` the left child of `this`, fixing the child's parent link when non-null.
    ///
    /// # Panics
    /// Panics if `this` is null.
    fn set_left(this: *mut Self, left: *mut Self) {
        assert!(!this.is_null());
        // SAFETY: `this` is non-null (checked above); `left` is checked before use.
        unsafe {
            (*this).lc = left;
            if !left.is_null() {
                (*left).par = this;
            }
        }
    }

    /// Makes `right` the right child of `this`, fixing the child's parent link when non-null.
    ///
    /// # Panics
    /// Panics if `this` is null.
    fn set_right(this: *mut Self, right: *mut Self) {
        assert!(!this.is_null());
        // SAFETY: `this` is non-null (checked above); `right` is checked before use.
        unsafe {
            (*this).rc = right;
            if !right.is_null() {
                (*right).par = this;
            }
        }
    }

    /// Overwrites the parent link of `this`; does nothing when `this` is null.
    fn set_par(this: *mut Self, par: *mut Self) {
        if !this.is_null() {
            // SAFETY: `this` is non-null (checked above) and live.
            unsafe {
                (*this).par = par;
            }
        }
    }

    /// Cached size of the subtree rooted at `this`; `0` for null.
    fn size_of(this: *mut Self) -> usize {
        if this.is_null() {
            0
        } else {
            // SAFETY: `this` is non-null (checked above) and live.
            unsafe { (*this).size }
        }
    }

    /// Left child of `this`, or `None` when `this` itself is null.
    fn left_of(this: *mut Self) -> Option<*mut Self> {
        // `then` (lazy) rather than `then_some` (eager): the pointer must not be read when
        // it is null.
        // SAFETY: the closure only runs when `this` is non-null.
        (!this.is_null()).then(|| unsafe { (*this).lc })
    }

    /// Right child of `this`, or `None` when `this` itself is null.
    fn right_of(this: *mut Self) -> Option<*mut Self> {
        // SAFETY: the closure only runs when `this` is non-null.
        (!this.is_null()).then(|| unsafe { (*this).rc })
    }

    /// Parent of `this`, or `None` when `this` itself is null.
    fn get_par(this: *mut Self) -> Option<*mut Self> {
        // SAFETY: the closure only runs when `this` is non-null.
        (!this.is_null()).then(|| unsafe { (*this).par })
    }

    /// Frees every node of the subtree rooted at `this`; a null `this` is a no-op.
    ///
    /// Nodes are released parent first, then the left subtree, then the right subtree. The
    /// worklist lives on the heap, so tree depth does not affect call-stack usage.
    fn clear(this: *mut Self) {
        let mut pending = Vec::new();
        if !this.is_null() {
            pending.push(this);
        }

        while let Some(node) = pending.pop() {
            // SAFETY: only non-null nodes are pushed, each node is reachable from exactly one
            // parent link, and nodes were allocated with `Box`, so each is reclaimed once.
            let owned = unsafe { Box::from_raw(node) };
            let (lc, rc) = (owned.lc, owned.rc);
            drop(owned);

            // Right goes in first so that the left subtree is popped (and freed) first.
            if !rc.is_null() {
                pending.push(rc);
            }
            if !lc.is_null() {
                pending.push(lc);
            }
        }
    }

    /// Borrows the key of `this`, or returns `None` for null.
    ///
    /// The lifetime `'a` is chosen by the caller; the caller must not let the reference
    /// outlive the node.
    fn key_of<'a>(this: *mut Self) -> Option<&'a K> {
        // SAFETY: the closure only runs when `this` is non-null.
        (!this.is_null()).then(|| unsafe { &(*this).key })
    }
}
259
260impl<K: Ord, V> Node<K, V> {
261 fn binary_search(this: *mut Self, key: &K) -> (*mut Self, Result<usize, usize>) {
262 let mut cur = this;
263 let mut index = 0;
264 let mut prev = ptr::null_mut();
265
266 while !cur.is_null() {
267 let left = Self::left_of(cur).unwrap();
268 let c = Self::size_of(left);
269 prev = cur;
270
271 match Self::key_of(cur).unwrap().cmp(key) {
272 Ordering::Equal => {
273 Self::splay(cur);
274 return (cur, Ok(index + c));
275 }
276 Ordering::Greater => {
277 cur = left;
278 }
279 Ordering::Less => {
280 cur = Self::right_of(cur).unwrap();
281 index += c + 1;
282 }
283 }
284 }
285
286 if !prev.is_null() {
287 Self::splay(prev);
288 }
289 (prev, Err(index))
290 }
291}
292
293pub struct OrderedMap<K, V> {
295 root: Cell<*mut Node<K, V>>,
296}
297
298impl<K: Ord, V> OrderedMap<K, V> {
299 pub fn new() -> Self {
301 Self {
302 root: Cell::new(ptr::null_mut()),
303 }
304 }
305
306 pub fn len(&self) -> usize {
308 Node::size_of(self.root.get())
309 }
310
311 pub fn is_empty(&self) -> bool {
313 self.root.get().is_null()
314 }
315
316 pub fn binary_search(&self, key: &K) -> Result<usize, usize> {
319 let (root, index) = Node::binary_search(self.root.get(), key);
320 self.root.set(root);
321 index
322 }
323
324 pub fn max_le(&self, key: &K) -> Option<(&K, &V)> {
326 match self.binary_search(key) {
327 Ok(i) => self.get_by_index(i),
328 Err(i) => {
329 if i > 0 {
330 self.get_by_index(i - 1)
331 } else {
332 None
333 }
334 }
335 }
336 }
337
338 pub fn min_ge(&self, key: &K) -> Option<(&K, &V)> {
340 match self.binary_search(key) {
341 Ok(i) | Err(i) => self.get_by_index(i),
342 }
343 }
344
345 pub fn contains(&self, key: &K) -> bool {
347 self.binary_search(key).is_ok()
348 }
349
350 pub fn insert(&mut self, key: K, value: V) -> Option<V> {
353 let (new_root, index) = Node::binary_search(self.root.get(), &key);
354 match index {
355 Ok(_) => {
356 let old = Node::set_value(new_root, value);
357 self.root.set(new_root);
358 Some(old)
359 }
360 Err(i) => {
361 let (l, r) = Node::split(new_root, i);
362 let node = Box::into_raw(Box::new(Node::new(key, value)));
363 let root = Node::merge(l, Node::merge(node, r));
364 self.root.set(root);
365 None
366 }
367 }
368 }
369
370 pub fn get(&self, key: &K) -> Option<&V> {
372 let k = self.binary_search(key).ok()?;
373 self.get_by_index(k).map(|(_, v)| v)
374 }
375
376 pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
378 let k = self.binary_search(key).ok()?;
379 self.get_value_mut_by_index(k)
380 }
381
382 pub fn remove(&mut self, key: &K) -> Option<V> {
384 let (m, index) = Node::binary_search(self.root.get(), key);
385
386 if index.is_err() {
387 self.root.set(m);
388 return None;
389 }
390
391 assert!(!m.is_null());
392 let left = Node::left_of(m).unwrap();
393 Node::set_par(left, ptr::null_mut());
394 let right = Node::right_of(m).unwrap();
395 Node::set_par(right, ptr::null_mut());
396
397 self.root.set(Node::merge(left, right));
398
399 (!m.is_null()).then(|| unsafe {
400 let m = Box::from_raw(m);
401 m.value
402 })
403 }
404
405 pub fn get_by_index(&self, i: usize) -> Option<(&K, &V)> {
407 if i >= self.len() {
408 None
409 } else {
410 let t = Node::get(self.root.get(), i);
411 self.root.set(t);
412 (!t.is_null()).then(|| unsafe { (&(*t).key, &(*t).value) })
413 }
414 }
415
416 pub fn get_key_by_index(&self, i: usize) -> Option<&K> {
418 self.get_by_index(i).map(|(k, _)| k)
419 }
420
421 pub fn get_value_by_index(&self, i: usize) -> Option<&V> {
423 self.get_by_index(i).map(|(_, v)| v)
424 }
425
426 pub fn get_value_mut_by_index(&mut self, i: usize) -> Option<&mut V> {
428 if i >= self.len() {
429 None
430 } else {
431 let t = Node::get(self.root.get(), i);
432 self.root.set(t);
433 (!t.is_null()).then(|| unsafe { &mut (*t).value })
434 }
435 }
436
437 pub fn remove_by_index(&mut self, i: usize) -> Option<(K, V)> {
439 let (l, r) = Node::split(self.root.get(), i);
440 let (m, r) = Node::split(r, 1);
441 self.root.set(Node::merge(l, r));
442
443 (!m.is_null()).then(|| unsafe {
444 let m = Box::from_raw(m);
445 let node = *m;
446 (node.key, node.value)
447 })
448 }
449
450 pub fn for_each(&self, mut f: impl FnMut(&K, &mut V)) {
452 Node::traverse(self.root.get(), &mut f);
453 }
454
455 }
462
463impl<K, V> std::ops::Drop for OrderedMap<K, V> {
464 fn drop(&mut self) {
465 Node::clear(self.root.get());
466 }
467}
468
469impl<K: Ord, V> Default for OrderedMap<K, V> {
470 fn default() -> Self {
471 Self::new()
472 }
473}
474
#[cfg(test)]
mod tests {
    use rand::Rng;
    use std::collections::BTreeMap;

    use super::*;

    /// Randomised differential test: applies the same inserts, removals and lookups to
    /// `OrderedMap` and to `BTreeMap` and requires identical results after every step.
    #[test]
    fn test() {
        let rounds = 10000;

        let mut rng = rand::thread_rng();
        let mut map = OrderedMap::<u32, u32>::new();
        let mut ans = BTreeMap::<u32, u32>::new();

        for _ in 0..rounds {
            // Insert a random pair; both maps must report the same displaced value.
            let insert_key: u32 = rng.gen_range(0..1000);
            let insert_value: u32 = rng.gen();
            assert_eq!(
                map.insert(insert_key, insert_value),
                ans.insert(insert_key, insert_value)
            );

            // Remove a random (possibly absent) key.
            let remove_key = rng.gen_range(0..1000);
            assert_eq!(map.remove(&remove_key), ans.remove(&remove_key));

            // Look up a random (possibly absent) key.
            let probe_key = rng.gen_range(0..1000);
            assert_eq!(map.get(&probe_key), ans.get(&probe_key));

            assert_eq!(map.len(), ans.len());
        }
    }
}