haar_lib/ds/
ordered_map.rs

1//! 順序付き辞書
2//!
3//! # Problems
4//! - <https://judge.yosupo.jp/problem/ordered_set>
5
6use std::cell::Cell;
7use std::cmp::Ordering;
8use std::ptr;
9
10struct Node<K, V> {
11    key: K,
12    value: V,
13    size: usize,
14    lc: *mut Self,
15    rc: *mut Self,
16    par: *mut Self,
17}
18
19impl<K, V> Node<K, V> {
20    fn new(key: K, value: V) -> Self {
21        Self {
22            key,
23            value,
24            size: 1,
25            lc: ptr::null_mut(),
26            rc: ptr::null_mut(),
27            par: ptr::null_mut(),
28        }
29    }
30
31    fn set_value(this: *mut Self, mut value: V) -> V {
32        assert!(!this.is_null());
33        std::mem::swap(unsafe { &mut (*this).value }, &mut value);
34        value
35    }
36
37    fn rotate(this: *mut Self) {
38        let p = Self::get_par(this).unwrap();
39        let pp = Self::get_par(p).unwrap();
40
41        if Self::left_of(p).unwrap() == this {
42            let c = Self::right_of(this).unwrap();
43            Self::set_left(p, c);
44            Self::set_right(this, p);
45        } else {
46            let c = Self::left_of(this).unwrap();
47            Self::set_right(p, c);
48            Self::set_left(this, p);
49        }
50
51        unsafe {
52            if !pp.is_null() {
53                if (*pp).lc == p {
54                    (*pp).lc = this;
55                }
56                if (*pp).rc == p {
57                    (*pp).rc = this;
58                }
59            }
60
61            assert!(!this.is_null());
62            (*this).par = pp;
63        }
64
65        Self::update(p);
66        Self::update(this);
67    }
68
69    fn status(this: *mut Self) -> i32 {
70        let par = Self::get_par(this).unwrap();
71
72        if par.is_null() {
73            return 0;
74        }
75        if unsafe { (*par).lc } == this {
76            return 1;
77        }
78        if unsafe { (*par).rc } == this {
79            return -1;
80        }
81
82        unreachable!()
83    }
84
85    fn pushdown(this: *mut Self) {
86        if !this.is_null() {
87            Self::update(this);
88        }
89    }
90
91    fn update(this: *mut Self) {
92        assert!(!this.is_null());
93        unsafe {
94            (*this).size = 1 + Self::size_of((*this).lc) + Self::size_of((*this).rc);
95        }
96    }
97
98    fn splay(this: *mut Self) {
99        while Self::status(this) != 0 {
100            let par = Self::get_par(this).unwrap();
101
102            if Self::status(par) == 0 {
103                Self::rotate(this);
104            } else if Self::status(this) == Self::status(par) {
105                Self::rotate(par);
106                Self::rotate(this);
107            } else {
108                Self::rotate(this);
109                Self::rotate(this);
110            }
111        }
112    }
113
114    fn get(root: *mut Self, mut index: usize) -> *mut Self {
115        if root.is_null() {
116            return root;
117        }
118
119        let mut cur = root;
120
121        loop {
122            Self::pushdown(cur);
123
124            let left = Self::left_of(cur).unwrap();
125            let lsize = Self::size_of(left);
126
127            match index.cmp(&lsize) {
128                Ordering::Less => {
129                    cur = left;
130                }
131                Ordering::Equal => {
132                    Self::splay(cur);
133                    return cur;
134                }
135                Ordering::Greater => {
136                    cur = Self::right_of(cur).unwrap();
137                    index -= lsize + 1;
138                }
139            }
140        }
141    }
142
143    fn merge(left: *mut Self, right: *mut Self) -> *mut Self {
144        if left.is_null() {
145            return right;
146        }
147        if right.is_null() {
148            return left;
149        }
150
151        let cur = Self::get(left, Self::size_of(left) - 1);
152
153        Self::set_right(cur, right);
154        Self::update(right);
155        Self::update(cur);
156
157        cur
158    }
159
160    fn split(root: *mut Self, index: usize) -> (*mut Self, *mut Self) {
161        if root.is_null() {
162            return (ptr::null_mut(), ptr::null_mut());
163        }
164        if index >= Self::size_of(root) {
165            return (root, ptr::null_mut());
166        }
167
168        let cur = Self::get(root, index);
169        let left = Self::left_of(cur).unwrap();
170
171        if !left.is_null() {
172            unsafe {
173                (*left).par = ptr::null_mut();
174            }
175            Self::update(left);
176        }
177        assert!(!cur.is_null());
178        unsafe {
179            (*cur).lc = ptr::null_mut();
180        }
181        Self::update(cur);
182
183        (left, cur)
184    }
185
186    fn traverse(cur: *mut Self, f: &mut impl FnMut(&K, &mut V)) {
187        if !cur.is_null() {
188            Self::pushdown(cur);
189            Self::traverse(Self::left_of(cur).unwrap(), f);
190            f(unsafe { &(*cur).key }, unsafe { &mut (*cur).value });
191            Self::traverse(Self::right_of(cur).unwrap(), f);
192        }
193    }
194
195    fn set_left(this: *mut Self, left: *mut Self) {
196        assert!(!this.is_null());
197        unsafe {
198            (*this).lc = left;
199            if !left.is_null() {
200                (*left).par = this;
201            }
202        }
203    }
204
205    fn set_right(this: *mut Self, right: *mut Self) {
206        assert!(!this.is_null());
207        unsafe {
208            (*this).rc = right;
209            if !right.is_null() {
210                (*right).par = this;
211            }
212        }
213    }
214
215    fn set_par(this: *mut Self, par: *mut Self) {
216        if !this.is_null() {
217            unsafe {
218                (*this).par = par;
219            }
220        }
221    }
222
223    fn size_of(this: *mut Self) -> usize {
224        if this.is_null() {
225            0
226        } else {
227            unsafe { (*this).size }
228        }
229    }
230
231    fn left_of(this: *mut Self) -> Option<*mut Self> {
232        (!this.is_null()).then_some(unsafe { (*this).lc })
233    }
234
235    fn right_of(this: *mut Self) -> Option<*mut Self> {
236        (!this.is_null()).then_some(unsafe { (*this).rc })
237    }
238
239    fn get_par(this: *mut Self) -> Option<*mut Self> {
240        (!this.is_null()).then_some(unsafe { (*this).par })
241    }
242
243    fn clear(this: *mut Self) {
244        if !this.is_null() {
245            let lc = Self::left_of(this).unwrap();
246            let rc = Self::right_of(this).unwrap();
247
248            let _ = unsafe { Box::from_raw(this) };
249
250            Self::clear(lc);
251            Self::clear(rc);
252        }
253    }
254
255    fn key_of<'a>(this: *mut Self) -> Option<&'a K> {
256        (!this.is_null()).then(|| unsafe { &(*this).key })
257    }
258}
259
260impl<K: Ord, V> Node<K, V> {
261    fn binary_search(this: *mut Self, key: &K) -> (*mut Self, Result<usize, usize>) {
262        let mut cur = this;
263        let mut index = 0;
264        let mut prev = ptr::null_mut();
265
266        while !cur.is_null() {
267            let left = Self::left_of(cur).unwrap();
268            let c = Self::size_of(left);
269            prev = cur;
270
271            match Self::key_of(cur).unwrap().cmp(key) {
272                Ordering::Equal => {
273                    Self::splay(cur);
274                    return (cur, Ok(index + c));
275                }
276                Ordering::Greater => {
277                    cur = left;
278                }
279                Ordering::Less => {
280                    cur = Self::right_of(cur).unwrap();
281                    index += c + 1;
282                }
283            }
284        }
285
286        if !prev.is_null() {
287            Self::splay(prev);
288        }
289        (prev, Err(index))
290    }
291}
292
293/// 順序付き辞書
294pub struct OrderedMap<K, V> {
295    root: Cell<*mut Node<K, V>>,
296}
297
298impl<K: Ord, V> OrderedMap<K, V> {
299    /// 空の`OrderedMap`を返す。
300    pub fn new() -> Self {
301        Self {
302            root: Cell::new(ptr::null_mut()),
303        }
304    }
305
306    /// 要素数を返す。
307    pub fn len(&self) -> usize {
308        Node::size_of(self.root.get())
309    }
310
311    /// 要素数が`0`ならば`true`を返す。
312    pub fn is_empty(&self) -> bool {
313        self.root.get().is_null()
314    }
315
316    /// `key`が存在するとき、それが何番目のキーであるかを`Ok`で返す。
317    /// そうでないとき、仮に`key`があったとき何番目のキーであったか、を`Err`で返す。
318    pub fn binary_search(&self, key: &K) -> Result<usize, usize> {
319        let (root, index) = Node::binary_search(self.root.get(), key);
320        self.root.set(root);
321        index
322    }
323
324    /// `key`以下の最大のキーをもつキーと値のペアを返す。
325    pub fn max_le(&self, key: &K) -> Option<(&K, &V)> {
326        match self.binary_search(key) {
327            Ok(i) => self.get_by_index(i),
328            Err(i) => {
329                if i > 0 {
330                    self.get_by_index(i - 1)
331                } else {
332                    None
333                }
334            }
335        }
336    }
337
338    /// `key`以上の最小のキーをもつキーと値のペアを返す。
339    pub fn min_ge(&self, key: &K) -> Option<(&K, &V)> {
340        match self.binary_search(key) {
341            Ok(i) | Err(i) => self.get_by_index(i),
342        }
343    }
344
345    /// `key`をキーとして持つならば`true`を返す。
346    pub fn contains(&self, key: &K) -> bool {
347        self.binary_search(key).is_ok()
348    }
349
350    /// `key`がすでに存在している場合、値を`value`で更新して、古い値を`Some`で返す。
351    /// そうでないとき、`key`に`value`を紐付けて、`None`を返す。
352    pub fn insert(&mut self, key: K, value: V) -> Option<V> {
353        let (new_root, index) = Node::binary_search(self.root.get(), &key);
354        match index {
355            Ok(_) => {
356                let old = Node::set_value(new_root, value);
357                self.root.set(new_root);
358                Some(old)
359            }
360            Err(i) => {
361                let (l, r) = Node::split(new_root, i);
362                let node = Box::into_raw(Box::new(Node::new(key, value)));
363                let root = Node::merge(l, Node::merge(node, r));
364                self.root.set(root);
365                None
366            }
367        }
368    }
369
370    /// キー`key`に対応する値の参照を返す。
371    pub fn get(&self, key: &K) -> Option<&V> {
372        let k = self.binary_search(key).ok()?;
373        self.get_by_index(k).map(|(_, v)| v)
374    }
375
376    /// キー`key`に対応する値の可変参照を返す
377    pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
378        let k = self.binary_search(key).ok()?;
379        self.get_value_mut_by_index(k)
380    }
381
382    /// キー`key`があれば、そのキーと対応する値を削除して、その値を`Some`で返す。
383    pub fn remove(&mut self, key: &K) -> Option<V> {
384        let (m, index) = Node::binary_search(self.root.get(), key);
385
386        if index.is_err() {
387            self.root.set(m);
388            return None;
389        }
390
391        assert!(!m.is_null());
392        let left = Node::left_of(m).unwrap();
393        Node::set_par(left, ptr::null_mut());
394        let right = Node::right_of(m).unwrap();
395        Node::set_par(right, ptr::null_mut());
396
397        self.root.set(Node::merge(left, right));
398
399        (!m.is_null()).then(|| unsafe {
400            let m = Box::from_raw(m);
401            m.value
402        })
403    }
404
405    /// `i`番目のキーとその対応する値のペアへの参照を返す。
406    pub fn get_by_index(&self, i: usize) -> Option<(&K, &V)> {
407        if i >= self.len() {
408            None
409        } else {
410            let t = Node::get(self.root.get(), i);
411            self.root.set(t);
412            (!t.is_null()).then(|| unsafe { (&(*t).key, &(*t).value) })
413        }
414    }
415
416    /// `i`番目のキーへの参照を返す。
417    pub fn get_key_by_index(&self, i: usize) -> Option<&K> {
418        self.get_by_index(i).map(|(k, _)| k)
419    }
420
421    /// `i`番目のキーに対応する値への参照を返す。
422    pub fn get_value_by_index(&self, i: usize) -> Option<&V> {
423        self.get_by_index(i).map(|(_, v)| v)
424    }
425
426    /// `i`番目のキーに対応する値への可変参照を返す。
427    pub fn get_value_mut_by_index(&mut self, i: usize) -> Option<&mut V> {
428        if i >= self.len() {
429            None
430        } else {
431            let t = Node::get(self.root.get(), i);
432            self.root.set(t);
433            (!t.is_null()).then(|| unsafe { &mut (*t).value })
434        }
435    }
436
437    /// `i`番目の要素を削除して、そのキーと値のペアを返す。
438    pub fn remove_by_index(&mut self, i: usize) -> Option<(K, V)> {
439        let (l, r) = Node::split(self.root.get(), i);
440        let (m, r) = Node::split(r, 1);
441        self.root.set(Node::merge(l, r));
442
443        (!m.is_null()).then(|| unsafe {
444            let m = Box::from_raw(m);
445            let node = *m;
446            (node.key, node.value)
447        })
448    }
449
450    /// 順序付き辞書のすべての要素を順番に`f`に渡す。
451    pub fn for_each(&self, mut f: impl FnMut(&K, &mut V)) {
452        Node::traverse(self.root.get(), &mut f);
453    }
454
455    // pub fn pop_first(&mut self) -> Option<V>
456    // pub fn pop_last(&mut self) -> Option<V>
457    // pub fn first(&self) -> Option<&V>
458    // pub fn last(&self) -> Option<&V>
459    // pub fn first_mut(&mut self) -> Option<&mut V>
460    // pub fn last_mut(&mut self) -> Option<&mut V>
461}
462
463impl<K, V> std::ops::Drop for OrderedMap<K, V> {
464    fn drop(&mut self) {
465        Node::clear(self.root.get());
466    }
467}
468
469impl<K: Ord, V> Default for OrderedMap<K, V> {
470    fn default() -> Self {
471        Self::new()
472    }
473}
474
475#[cfg(test)]
476mod tests {
477    use rand::Rng;
478    use std::collections::BTreeMap;
479
480    use super::*;
481
482    #[test]
483    fn test() {
484        let mut rng = rand::thread_rng();
485
486        let mut map = OrderedMap::<u32, u32>::new();
487        let mut ans = BTreeMap::<u32, u32>::new();
488
489        let q = 10000;
490
491        for _ in 0..q {
492            let x: u32 = rng.gen_range(0..1000);
493            let y: u32 = rng.gen();
494
495            assert_eq!(map.insert(x, y), ans.insert(x, y));
496
497            let x = rng.gen_range(0..1000);
498
499            assert_eq!(map.remove(&x), ans.remove(&x));
500
501            let x = rng.gen_range(0..1000);
502
503            assert_eq!(map.get(&x), ans.get(&x));
504
505            assert_eq!(map.len(), ans.len());
506        }
507    }
508}