1use std::cell::Cell;
7use std::cmp::Ordering;
8use std::ptr;
9
/// A node of a size-augmented splay tree.
///
/// Nodes are allocated with `Box::into_raw` and linked through raw pointers;
/// a null pointer means "no child" or "no parent". `size` caches the number
/// of nodes in the subtree rooted here, which is what lets the tree be
/// addressed by in-order position (`get`, `split`).
struct Node<K, V> {
    key: K,
    value: V,
    /// Number of nodes in this subtree, including this one.
    size: usize,
    /// Left child, or null.
    lc: *mut Self,
    /// Right child, or null.
    rc: *mut Self,
    /// Parent, or null when this node is the root of its tree.
    par: *mut Self,
}

// Contract shared by every function below: a non-null `*mut Self` argument
// must point to a live node obtained from `Box::into_raw`, whose `lc`/`rc`/
// `par` links and cached `size` are consistent with the rest of its tree.
impl<K, V> Node<K, V> {
    /// Creates a detached leaf node (no children, no parent, `size == 1`).
    fn new(key: K, value: V) -> Self {
        Self {
            key,
            value,
            size: 1,
            lc: ptr::null_mut(),
            rc: ptr::null_mut(),
            par: ptr::null_mut(),
        }
    }

    /// Replaces the value stored in `this` and returns the previous value.
    ///
    /// # Panics
    /// Panics if `this` is null.
    fn set_value(this: *mut Self, value: V) -> V {
        assert!(!this.is_null());
        // SAFETY: `this` is non-null and live (impl-level contract).
        std::mem::replace(unsafe { &mut (*this).value }, value)
    }

    /// Rotates `this` above its parent, keeping the in-order sequence intact
    /// and refreshing the sizes of the two nodes whose subtrees changed.
    ///
    /// # Panics
    /// Panics if `this` is null or has no parent.
    fn rotate(this: *mut Self) {
        let p = Self::get_par(this).unwrap();
        let pp = Self::get_par(p).unwrap();

        if Self::left_of(p).unwrap() == this {
            // `this` is a left child: its right subtree moves under `p`.
            let c = Self::right_of(this).unwrap();
            Self::set_left(p, c);
            Self::set_right(this, p);
        } else {
            // `this` is a right child: its left subtree moves under `p`.
            let c = Self::left_of(this).unwrap();
            Self::set_right(p, c);
            Self::set_left(this, p);
        }

        // SAFETY: `this` is non-null (it had a parent) and `pp` is only
        // dereferenced when non-null.
        unsafe {
            // Re-point the grandparent's child link (if any) from `p` to `this`.
            if !pp.is_null() {
                if (*pp).lc == p {
                    (*pp).lc = this;
                } else {
                    (*pp).rc = this;
                }
            }
            (*this).par = pp;
        }

        // `p` is now a child of `this`, so its size must be refreshed first.
        Self::update(p);
        Self::update(this);
    }

    /// Which side of its parent `this` hangs on: `1` for a left child, `-1`
    /// for a right child, `0` when it is a root.
    fn status(this: *mut Self) -> i32 {
        let par = Self::get_par(this).unwrap();

        if par.is_null() {
            0
        } else if Self::left_of(par).unwrap() == this {
            1
        } else if Self::right_of(par).unwrap() == this {
            -1
        } else {
            unreachable!()
        }
    }

    /// Hook run on a node before its children are inspected on the way down.
    /// Currently it only refreshes the cached `size`; null is ignored.
    fn pushdown(this: *mut Self) {
        if !this.is_null() {
            Self::update(this);
        }
    }

    /// Recomputes the cached `size` of `this` from its children.
    ///
    /// # Panics
    /// Panics if `this` is null.
    fn update(this: *mut Self) {
        assert!(!this.is_null());
        // SAFETY: `this` is non-null and live (impl-level contract).
        unsafe {
            (*this).size = 1 + Self::size_of((*this).lc) + Self::size_of((*this).rc);
        }
    }

    /// Moves `this` to the root of its tree using zig / zig-zig / zig-zag steps.
    fn splay(this: *mut Self) {
        while Self::status(this) != 0 {
            let par = Self::get_par(this).unwrap();

            if Self::status(par) == 0 {
                // zig: the parent is the root.
                Self::rotate(this);
            } else if Self::status(this) == Self::status(par) {
                // zig-zig: same side twice, rotate the parent first.
                Self::rotate(par);
                Self::rotate(this);
            } else {
                // zig-zag: alternating sides, rotate `this` twice.
                Self::rotate(this);
                Self::rotate(this);
            }
        }
    }

    /// Returns the node at in-order position `index` (0-based) of the tree
    /// rooted at `root`, after splaying it to the root. Returns null when
    /// `root` is null.
    ///
    /// # Panics
    /// `index` must be `< size_of(root)`; otherwise the walk falls off the
    /// tree and panics.
    fn get(root: *mut Self, mut index: usize) -> *mut Self {
        if root.is_null() {
            return root;
        }

        let mut cur = root;

        loop {
            Self::pushdown(cur);

            let left = Self::left_of(cur).unwrap();
            let lsize = Self::size_of(left);

            match index.cmp(&lsize) {
                Ordering::Less => cur = left,
                Ordering::Equal => {
                    Self::splay(cur);
                    return cur;
                }
                Ordering::Greater => {
                    index -= lsize + 1;
                    cur = Self::right_of(cur).unwrap();
                }
            }
        }
    }

    /// Concatenates two trees so every node of `left` precedes every node of
    /// `right`, returning the new root. Either argument may be null.
    ///
    /// The last node of `left` is splayed to its root (it then has no right
    /// child) and `right` is hung there.
    fn merge(left: *mut Self, right: *mut Self) -> *mut Self {
        if left.is_null() {
            return right;
        }
        if right.is_null() {
            return left;
        }

        let cur = Self::get(left, Self::size_of(left) - 1);

        Self::set_right(cur, right);
        Self::update(right);
        Self::update(cur);

        cur
    }

    /// Splits the tree rooted at `root` into its first `index` nodes and the
    /// rest, returning both roots (either may be null).
    fn split(root: *mut Self, index: usize) -> (*mut Self, *mut Self) {
        if root.is_null() {
            return (ptr::null_mut(), ptr::null_mut());
        }
        if index >= Self::size_of(root) {
            return (root, ptr::null_mut());
        }

        // Bring node `index` to the root; its left subtree is then exactly
        // the first `index` nodes.
        let cur = Self::get(root, index);
        let left = Self::left_of(cur).unwrap();

        if !left.is_null() {
            // SAFETY: `left` is non-null and live.
            unsafe {
                (*left).par = ptr::null_mut();
            }
            Self::update(left);
        }
        // SAFETY: `cur` is non-null: `left_of(cur).unwrap()` above succeeded.
        unsafe {
            (*cur).lc = ptr::null_mut();
        }
        Self::update(cur);

        (left, cur)
    }

    /// Returns the first (leftmost) node of the non-null subtree `this`,
    /// running `pushdown` on every node it passes on the way down.
    fn leftmost(this: *mut Self) -> *mut Self {
        let mut cur = this;
        loop {
            Self::pushdown(cur);
            let left = Self::left_of(cur).unwrap();
            if left.is_null() {
                return cur;
            }
            cur = left;
        }
    }

    /// Calls `f` on every node of the subtree rooted at `cur`, in order.
    ///
    /// Iterative: it follows parent links instead of recursing. The stack
    /// use is therefore constant even when the tree has degenerated into a
    /// long path, which splay trees routinely do (e.g. after ascending
    /// inserts).
    fn traverse(cur: *mut Self, f: &mut impl FnMut(&K, &mut V)) {
        if cur.is_null() {
            return;
        }

        let root = cur;
        let mut node = Self::leftmost(root);

        loop {
            // SAFETY: `node` is a non-null, live node of `root`'s subtree and
            // the two references point at disjoint fields.
            let (key, value) = unsafe { (&(*node).key, &mut (*node).value) };
            f(key, value);

            let right = Self::right_of(node).unwrap();
            if !right.is_null() {
                node = Self::leftmost(right);
                continue;
            }

            // No right subtree: climb until we come up out of a left child;
            // that parent is the in-order successor. Reaching `root` instead
            // means the whole subtree has been visited.
            loop {
                if node == root {
                    return;
                }
                let par = Self::get_par(node).unwrap();
                let came_from_left = Self::left_of(par).unwrap() == node;
                node = par;
                if came_from_left {
                    break;
                }
            }
        }
    }

    /// Makes `left` (possibly null) the left child of `this`, fixing its parent link.
    ///
    /// # Panics
    /// Panics if `this` is null.
    fn set_left(this: *mut Self, left: *mut Self) {
        assert!(!this.is_null());
        // SAFETY: `this` is non-null; `left` is only dereferenced when non-null.
        unsafe {
            (*this).lc = left;
            if !left.is_null() {
                (*left).par = this;
            }
        }
    }

    /// Makes `right` (possibly null) the right child of `this`, fixing its parent link.
    ///
    /// # Panics
    /// Panics if `this` is null.
    fn set_right(this: *mut Self, right: *mut Self) {
        assert!(!this.is_null());
        // SAFETY: `this` is non-null; `right` is only dereferenced when non-null.
        unsafe {
            (*this).rc = right;
            if !right.is_null() {
                (*right).par = this;
            }
        }
    }

    /// Number of nodes in the subtree `this` (0 for null).
    fn size_of(this: *mut Self) -> usize {
        if this.is_null() {
            0
        } else {
            // SAFETY: `this` is non-null and live.
            unsafe { (*this).size }
        }
    }

    /// Left child of `this`, or `None` if `this` itself is null.
    fn left_of(this: *mut Self) -> Option<*mut Self> {
        // `then`, not `then_some`: the dereference must only run when non-null.
        (!this.is_null()).then(|| unsafe { (*this).lc })
    }

    /// Right child of `this`, or `None` if `this` itself is null.
    fn right_of(this: *mut Self) -> Option<*mut Self> {
        // `then`, not `then_some`: the dereference must only run when non-null.
        (!this.is_null()).then(|| unsafe { (*this).rc })
    }

    /// Parent of `this`, or `None` if `this` itself is null.
    fn get_par(this: *mut Self) -> Option<*mut Self> {
        // `then`, not `then_some`: the dereference must only run when non-null.
        (!this.is_null()).then(|| unsafe { (*this).par })
    }

    /// Frees every node of the subtree rooted at `this` (null is a no-op).
    ///
    /// Instead of recursing (stack depth = tree height, up to O(n)), this
    /// rotates left children upward until the current node has none, frees
    /// it, and continues with its right child. It uses constant stack and
    /// O(n) time. Parent links and sizes are not maintained, because every
    /// node is freed.
    fn clear(this: *mut Self) {
        let mut cur = this;
        while !cur.is_null() {
            // SAFETY: `cur` is the root of the not-yet-freed remainder of the
            // subtree; a node is freed only once nothing left refers to it.
            unsafe {
                let left = (*cur).lc;
                if left.is_null() {
                    let next = (*cur).rc;
                    drop(Box::from_raw(cur));
                    cur = next;
                } else {
                    // Right-rotate so `left` becomes the top of the remainder.
                    (*cur).lc = (*left).rc;
                    (*left).rc = cur;
                    cur = left;
                }
            }
        }
    }

    /// Key of `this`, or `None` if null. The lifetime `'a` is unbounded; the
    /// caller must not keep the reference past the node's lifetime.
    fn key_of<'a>(this: *mut Self) -> Option<&'a K> {
        (!this.is_null()).then(|| unsafe { &(*this).key })
    }
}

impl<K: Ord, V> Node<K, V> {
    /// Locates `key` in the BST-ordered tree rooted at `this`.
    ///
    /// Returns `Ok(i)` when `key` is at in-order position `i`, or `Err(i)`
    /// with the position at which it would be inserted. It is read-only (no
    /// splaying) and iterative, so a path-shaped tree cannot overflow the
    /// stack.
    fn binary_search(this: *mut Self, key: &K) -> Result<usize, usize> {
        let mut cur = this;
        // Number of nodes known to come before `cur`'s subtree.
        let mut offset = 0;

        while let Some(cur_key) = Self::key_of(cur) {
            let left = Self::left_of(cur).unwrap();
            let lsize = Self::size_of(left);
            match cur_key.cmp(key) {
                Ordering::Equal => return Ok(offset + lsize),
                Ordering::Greater => cur = left,
                Ordering::Less => {
                    offset += lsize + 1;
                    cur = Self::right_of(cur).unwrap();
                }
            }
        }

        Err(offset)
    }
}
270
/// An ordered map backed by a size-augmented splay tree.
///
/// Besides key lookup it supports rank queries (`get_by_index`,
/// `binary_search`, `remove_by_index`). Index lookups restructure the tree
/// (splaying), which is why the root lives in a [`Cell`]: methods taking
/// `&self` still need to update it.
pub struct OrderedMap<K, V> {
    // Root of the tree, or null when the map is empty. Every node reachable
    // from here is owned by the map and freed when it is dropped.
    root: Cell<*mut Node<K, V>>,
}
275
276impl<K: Ord, V> OrderedMap<K, V> {
277 pub fn new() -> Self {
279 Self {
280 root: Cell::new(ptr::null_mut()),
281 }
282 }
283
284 pub fn len(&self) -> usize {
286 Node::size_of(self.root.get())
287 }
288
289 pub fn is_empty(&self) -> bool {
291 self.root.get().is_null()
292 }
293
294 pub fn binary_search(&self, key: &K) -> Result<usize, usize> {
297 Node::binary_search(self.root.get(), key)
298 }
299
300 pub fn max_le(&self, key: &K) -> Option<(&K, &V)> {
302 match self.binary_search(key) {
303 Ok(i) => self.get_by_index(i),
304 Err(i) => {
305 if i > 0 {
306 self.get_by_index(i - 1)
307 } else {
308 None
309 }
310 }
311 }
312 }
313
314 pub fn min_ge(&self, key: &K) -> Option<(&K, &V)> {
316 match self.binary_search(key) {
317 Ok(i) | Err(i) => self.get_by_index(i),
318 }
319 }
320
321 pub fn contains(&self, key: &K) -> bool {
323 Node::binary_search(self.root.get(), key).is_ok()
324 }
325
326 pub fn insert(&mut self, key: K, value: V) -> Option<V> {
329 match Node::binary_search(self.root.get(), &key) {
330 Ok(i) => {
331 let (l, r) = Node::split(self.root.get(), i);
332 let (m, r) = Node::split(r, 1);
333 let old = Node::set_value(m, value);
334
335 let r = Node::merge(m, r);
336 let root = Node::merge(l, r);
337 self.root.set(root);
338 Some(old)
339 }
340 Err(i) => {
341 let (l, r) = Node::split(self.root.get(), i);
342 let node = Box::into_raw(Box::new(Node::new(key, value)));
343 let root = Node::merge(l, Node::merge(node, r));
344 self.root.set(root);
345 None
346 }
347 }
348 }
349
350 pub fn get(&self, key: &K) -> Option<&V> {
352 let k = Node::binary_search(self.root.get(), key).ok()?;
353 self.get_by_index(k).map(|(_, v)| v)
354 }
355
356 pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
358 let k = Node::binary_search(self.root.get(), key).ok()?;
359 self.get_value_mut_by_index(k)
360 }
361
362 pub fn remove(&mut self, key: &K) -> Option<V> {
364 let i = Node::binary_search(self.root.get(), key).ok()?;
365 self.remove_by_index(i).map(|(_, v)| v)
366 }
367
368 pub fn get_by_index(&self, i: usize) -> Option<(&K, &V)> {
370 if i >= self.len() {
371 None
372 } else {
373 let t = Node::get(self.root.get(), i);
374 self.root.set(t);
375 (!t.is_null()).then(|| unsafe { (&(*t).key, &(*t).value) })
376 }
377 }
378
379 pub fn get_key_by_index(&self, i: usize) -> Option<&K> {
381 self.get_by_index(i).map(|(k, _)| k)
382 }
383
384 pub fn get_value_by_index(&self, i: usize) -> Option<&V> {
386 self.get_by_index(i).map(|(_, v)| v)
387 }
388
389 pub fn get_value_mut_by_index(&mut self, i: usize) -> Option<&mut V> {
391 if i >= self.len() {
392 None
393 } else {
394 let t = Node::get(self.root.get(), i);
395 self.root.set(t);
396 (!t.is_null()).then(|| unsafe { &mut (*t).value })
397 }
398 }
399
400 pub fn remove_by_index(&mut self, i: usize) -> Option<(K, V)> {
402 let (l, r) = Node::split(self.root.get(), i);
403 let (m, r) = Node::split(r, 1);
404 self.root.set(Node::merge(l, r));
405
406 (!m.is_null()).then(|| unsafe {
407 let m = Box::from_raw(m);
408 let node = *m;
409 (node.key, node.value)
410 })
411 }
412
413 pub fn for_each(&self, mut f: impl FnMut(&K, &mut V)) {
415 Node::traverse(self.root.get(), &mut f);
416 }
417
418 }
425
426impl<K, V> std::ops::Drop for OrderedMap<K, V> {
427 fn drop(&mut self) {
428 Node::clear(self.root.get());
429 }
430}
431
432impl<K: Ord, V> Default for OrderedMap<K, V> {
433 fn default() -> Self {
434 Self::new()
435 }
436}
437
#[cfg(test)]
mod tests {
    use rand::Rng;
    use std::collections::BTreeMap;

    use super::*;

    /// Random insert / remove / get, checked against `BTreeMap` as an oracle.
    #[test]
    fn test() {
        let mut rng = rand::thread_rng();

        let mut map = OrderedMap::<u32, u32>::new();
        let mut ans = BTreeMap::<u32, u32>::new();

        let q = 10000;

        for _ in 0..q {
            let x: u32 = rng.gen_range(0..1000);
            let y: u32 = rng.gen();

            assert_eq!(map.insert(x, y), ans.insert(x, y));

            let x = rng.gen_range(0..1000);

            assert_eq!(map.remove(&x), ans.remove(&x));

            let x = rng.gen_range(0..1000);

            assert_eq!(map.get(&x), ans.get(&x));

            assert_eq!(map.len(), ans.len());
        }
    }

    /// Rank-based and ordered queries, checked against `BTreeMap`.
    #[test]
    fn ordered_queries() {
        let mut rng = rand::thread_rng();

        let mut map = OrderedMap::<u32, u32>::new();
        let mut ans = BTreeMap::<u32, u32>::new();

        for _ in 0..2000 {
            let x: u32 = rng.gen_range(0..1000);
            let y: u32 = rng.gen();
            assert_eq!(map.insert(x, y), ans.insert(x, y));
        }

        // Probe a little past the key range so "no such entry" paths run too.
        for x in 0..1100u32 {
            let rank = ans.range(..x).count();
            let expected = if ans.contains_key(&x) { Ok(rank) } else { Err(rank) };
            assert_eq!(map.binary_search(&x), expected);
            assert_eq!(map.contains(&x), ans.contains_key(&x));
            assert_eq!(map.max_le(&x), ans.range(..=x).next_back());
            assert_eq!(map.min_ge(&x), ans.range(x..).next());
        }

        for (i, (k, v)) in ans.iter().enumerate() {
            assert_eq!(map.get_by_index(i), Some((k, v)));
            assert_eq!(map.get_key_by_index(i), Some(k));
            assert_eq!(map.get_value_by_index(i), Some(v));
        }
        assert_eq!(map.get_by_index(ans.len()), None);

        let mut seen = Vec::new();
        map.for_each(|k, v| seen.push((*k, *v)));
        let expected: Vec<(u32, u32)> = ans.iter().map(|(&k, &v)| (k, v)).collect();
        assert_eq!(seen, expected);

        while !ans.is_empty() {
            let i = rng.gen_range(0..ans.len());
            let k = *ans.keys().nth(i).unwrap();
            let v = ans.remove(&k).unwrap();
            assert_eq!(map.remove_by_index(i), Some((k, v)));
            assert_eq!(map.len(), ans.len());
        }
        assert!(map.is_empty());
        assert_eq!(map.remove_by_index(0), None);
    }
}