haar_lib/ds/
ordered_map.rs

1//! 順序付き辞書
2//!
3//! # Problems
4//! - <https://judge.yosupo.jp/problem/ordered_set>
5
6use std::cell::Cell;
7use std::cmp::Ordering;
8use std::ptr;
9
10struct Node<K, V> {
11    key: K,
12    value: V,
13    size: usize,
14    lc: *mut Self,
15    rc: *mut Self,
16    par: *mut Self,
17}
18
19impl<K, V> Node<K, V> {
20    fn new(key: K, value: V) -> Self {
21        Self {
22            key,
23            value,
24            size: 1,
25            lc: ptr::null_mut(),
26            rc: ptr::null_mut(),
27            par: ptr::null_mut(),
28        }
29    }
30
31    fn set_value(this: *mut Self, mut value: V) -> V {
32        assert!(!this.is_null());
33        std::mem::swap(unsafe { &mut (*this).value }, &mut value);
34        value
35    }
36
37    fn rotate(this: *mut Self) {
38        let p = Self::par_of(this).unwrap();
39        let pp = Self::par_of(p).unwrap();
40
41        if Self::left_of(p).unwrap() == this {
42            let c = Self::right_of(this).unwrap();
43            Self::set_left(p, c);
44            Self::set_right(this, p);
45        } else {
46            let c = Self::left_of(this).unwrap();
47            Self::set_right(p, c);
48            Self::set_left(this, p);
49        }
50
51        if !pp.is_null() {
52            let pp = unsafe { &mut *pp };
53            if pp.lc == p {
54                pp.lc = this;
55            }
56            if pp.rc == p {
57                pp.rc = this;
58            }
59        }
60
61        assert!(!this.is_null());
62        unsafe { (*this).par = pp };
63
64        Self::update(p);
65        Self::update(this);
66    }
67
68    fn status(this: *mut Self) -> i32 {
69        let par = Self::par_of(this).unwrap();
70        if par.is_null() {
71            return 0;
72        }
73        let par = unsafe { &mut *par };
74        if par.lc == this {
75            return 1;
76        }
77        if par.rc == this {
78            return -1;
79        }
80
81        unreachable!()
82    }
83
84    fn pushdown(this: *mut Self) {
85        if !this.is_null() {
86            Self::update(this);
87        }
88    }
89
90    fn update(this: *mut Self) {
91        assert!(!this.is_null());
92        let this = unsafe { &mut *this };
93        this.size = 1 + Self::size_of(this.lc) + Self::size_of(this.rc);
94    }
95
96    fn splay(this: *mut Self) {
97        while Self::status(this) != 0 {
98            let par = Self::par_of(this).unwrap();
99
100            if Self::status(par) == 0 {
101                Self::rotate(this);
102            } else if Self::status(this) == Self::status(par) {
103                Self::rotate(par);
104                Self::rotate(this);
105            } else {
106                Self::rotate(this);
107                Self::rotate(this);
108            }
109        }
110    }
111
112    fn get(root: *mut Self, mut index: usize) -> *mut Self {
113        if root.is_null() {
114            return root;
115        }
116
117        let mut cur = root;
118
119        loop {
120            Self::pushdown(cur);
121
122            let left = Self::left_of(cur).unwrap();
123            let lsize = Self::size_of(left);
124
125            match index.cmp(&lsize) {
126                Ordering::Less => {
127                    cur = left;
128                }
129                Ordering::Equal => {
130                    Self::splay(cur);
131                    return cur;
132                }
133                Ordering::Greater => {
134                    cur = Self::right_of(cur).unwrap();
135                    index -= lsize + 1;
136                }
137            }
138        }
139    }
140
141    fn merge(left: *mut Self, right: *mut Self) -> *mut Self {
142        if left.is_null() {
143            return right;
144        }
145        if right.is_null() {
146            return left;
147        }
148
149        let cur = Self::get(left, Self::size_of(left) - 1);
150
151        Self::set_right(cur, right);
152        Self::update(right);
153        Self::update(cur);
154
155        cur
156    }
157
158    fn split(root: *mut Self, index: usize) -> (*mut Self, *mut Self) {
159        if root.is_null() {
160            return (ptr::null_mut(), ptr::null_mut());
161        }
162        if index >= Self::size_of(root) {
163            return (root, ptr::null_mut());
164        }
165
166        let cur = Self::get(root, index);
167        let left = Self::left_of(cur).unwrap();
168
169        if !left.is_null() {
170            Self::set_par(left, ptr::null_mut());
171            Self::update(left);
172        }
173        Self::set_left(cur, ptr::null_mut());
174        Self::update(cur);
175
176        (left, cur)
177    }
178
179    fn traverse(cur: *mut Self, f: &mut impl FnMut(&K, &mut V)) {
180        if !cur.is_null() {
181            let cur = unsafe { &mut *cur };
182            Self::pushdown(cur);
183            Self::traverse(Self::left_of(cur).unwrap(), f);
184            f(&cur.key, &mut cur.value);
185            Self::traverse(Self::right_of(cur).unwrap(), f);
186        }
187    }
188
189    fn set_left(this: *mut Self, left: *mut Self) {
190        assert!(!this.is_null());
191        unsafe { (*this).lc = left };
192        if !left.is_null() {
193            unsafe { (*left).par = this };
194        }
195    }
196
197    fn set_right(this: *mut Self, right: *mut Self) {
198        assert!(!this.is_null());
199        unsafe { (*this).rc = right };
200        if !right.is_null() {
201            unsafe { (*right).par = this };
202        }
203    }
204
205    fn set_par(this: *mut Self, par: *mut Self) {
206        if !this.is_null() {
207            unsafe { (*this).par = par };
208        }
209    }
210
211    fn size_of(this: *mut Self) -> usize {
212        if this.is_null() {
213            0
214        } else {
215            unsafe { (*this).size }
216        }
217    }
218
219    fn left_of(this: *mut Self) -> Option<*mut Self> {
220        (!this.is_null()).then(|| unsafe { (*this).lc })
221    }
222
223    fn right_of(this: *mut Self) -> Option<*mut Self> {
224        (!this.is_null()).then(|| unsafe { (*this).rc })
225    }
226
227    fn par_of(this: *mut Self) -> Option<*mut Self> {
228        (!this.is_null()).then(|| unsafe { (*this).par })
229    }
230
231    fn clear(this: *mut Self) {
232        if !this.is_null() {
233            let lc = Self::left_of(this).unwrap();
234            let rc = Self::right_of(this).unwrap();
235
236            let _ = unsafe { Box::from_raw(this) };
237
238            Self::clear(lc);
239            Self::clear(rc);
240        }
241    }
242
243    fn key_of<'a>(this: *mut Self) -> Option<&'a K> {
244        (!this.is_null()).then(|| unsafe { &(*this).key })
245    }
246
247    fn val_of<'a>(this: *mut Self) -> Option<&'a V> {
248        (!this.is_null()).then(|| unsafe { &(*this).value })
249    }
250
251    fn val_mut_of<'a>(this: *mut Self) -> Option<&'a mut V> {
252        (!this.is_null()).then(|| unsafe { &mut (*this).value })
253    }
254
255    fn from_ptr(this: *mut Self) -> Self {
256        assert!(!this.is_null());
257        let this = unsafe { Box::from_raw(this) };
258        *this
259    }
260}
261
262impl<K: Ord, V> Node<K, V> {
263    fn binary_search(this: *mut Self, key: &K) -> (*mut Self, Result<usize, usize>) {
264        let mut cur = this;
265        let mut index = 0;
266        let mut prev = ptr::null_mut();
267
268        while !cur.is_null() {
269            let left = Self::left_of(cur).unwrap();
270            let c = Self::size_of(left);
271            prev = cur;
272
273            match Self::key_of(cur).unwrap().cmp(key) {
274                Ordering::Equal => {
275                    Self::splay(cur);
276                    return (cur, Ok(index + c));
277                }
278                Ordering::Greater => {
279                    cur = left;
280                }
281                Ordering::Less => {
282                    cur = Self::right_of(cur).unwrap();
283                    index += c + 1;
284                }
285            }
286        }
287
288        if !prev.is_null() {
289            Self::splay(prev);
290        }
291        (prev, Err(index))
292    }
293}
294
295/// 順序付き辞書
296pub struct OrderedMap<K, V> {
297    root: Cell<*mut Node<K, V>>,
298}
299
300impl<K: Ord, V> OrderedMap<K, V> {
301    /// 空の`OrderedMap`を返す。
302    pub fn new() -> Self {
303        Self {
304            root: Cell::new(ptr::null_mut()),
305        }
306    }
307
308    /// 要素数を返す。
309    pub fn len(&self) -> usize {
310        Node::size_of(self.root.get())
311    }
312
313    /// 要素数が`0`ならば`true`を返す。
314    pub fn is_empty(&self) -> bool {
315        self.root.get().is_null()
316    }
317
318    /// `key`が存在するとき、それが何番目のキーであるかと値への参照を`Ok`で返す。
319    /// そうでないとき、仮に`key`があったとき何番目のキーであったか、を`Err`で返す。
320    pub fn binary_search(&self, key: &K) -> Result<(usize, &V), usize> {
321        let (root, index) = Node::binary_search(self.root.get(), key);
322        self.root.set(root);
323        assert!(index.is_err() || Node::key_of(root).unwrap() == key);
324        index.map(|i| (i, Node::val_of(root).unwrap()))
325    }
326
327    /// `key`が存在するとき、それが何番目のキーであるかと値への可変参照を`Ok`で返す。
328    /// そうでないとき、仮に`key`があったとき何番目のキーであったか、を`Err`で返す。
329    pub fn binary_search_mut(&self, key: &K) -> Result<(usize, &mut V), usize> {
330        let (root, index) = Node::binary_search(self.root.get(), key);
331        self.root.set(root);
332        assert!(index.is_err() || Node::key_of(root).unwrap() == key);
333        index.map(|i| (i, Node::val_mut_of(root).unwrap()))
334    }
335
336    /// `key`以下の最大のキーをもつキーと値のペアを返す。
337    pub fn max_le<'a>(&'a self, key: &'a K) -> Option<(&'a K, &'a V)> {
338        match self.binary_search(key) {
339            Ok((_, value)) => Some((key, value)),
340            Err(i) => self.get_by_index(i.checked_sub(1)?),
341        }
342    }
343
344    /// `key`以上の最小のキーをもつキーと値のペアを返す。
345    pub fn min_ge<'a>(&'a self, key: &'a K) -> Option<(&'a K, &'a V)> {
346        match self.binary_search(key) {
347            Ok((_, value)) => Some((key, value)),
348            Err(i) => self.get_by_index(i),
349        }
350    }
351
352    /// `key`をキーとして持つならば`true`を返す。
353    pub fn contains(&self, key: &K) -> bool {
354        self.binary_search(key).is_ok()
355    }
356
357    /// `key`がすでに存在している場合、値を`value`で更新して、古い値を`Some`で返す。
358    /// そうでないとき、`key`に`value`を紐付けて、`None`を返す。
359    pub fn insert(&mut self, key: K, value: V) -> Option<V> {
360        let (new_root, index) = Node::binary_search(self.root.get(), &key);
361        match index {
362            Ok(_) => {
363                let old = Node::set_value(new_root, value);
364                self.root.set(new_root);
365                Some(old)
366            }
367            Err(i) => {
368                let (l, r) = Node::split(new_root, i);
369                let node = Box::into_raw(Box::new(Node::new(key, value)));
370                let root = Node::merge(l, Node::merge(node, r));
371                self.root.set(root);
372                None
373            }
374        }
375    }
376
377    /// キー`key`に対応する値の参照を返す。
378    pub fn get(&self, key: &K) -> Option<&V> {
379        let (_, value) = self.binary_search(key).ok()?;
380        Some(value)
381    }
382
383    /// キー`key`に対応する値の可変参照を返す
384    pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
385        let (_, value) = self.binary_search_mut(key).ok()?;
386        Some(value)
387    }
388
389    /// キー`key`があれば、そのキーと対応する値を削除して、その値を`Some`で返す。
390    pub fn remove(&mut self, key: &K) -> Option<V> {
391        let (m, index) = Node::binary_search(self.root.get(), key);
392
393        if index.is_err() {
394            self.root.set(m);
395            return None;
396        }
397
398        assert!(!m.is_null());
399        let left = Node::left_of(m).unwrap();
400        Node::set_par(left, ptr::null_mut());
401        let right = Node::right_of(m).unwrap();
402        Node::set_par(right, ptr::null_mut());
403
404        self.root.set(Node::merge(left, right));
405
406        (!m.is_null()).then(|| Node::from_ptr(m).value)
407    }
408
409    /// `i`番目のキーとその対応する値のペアへの参照を返す。
410    pub fn get_by_index(&self, i: usize) -> Option<(&K, &V)> {
411        if i >= self.len() {
412            None
413        } else {
414            let t = Node::get(self.root.get(), i);
415            self.root.set(t);
416            (!t.is_null()).then(|| (Node::key_of(t).unwrap(), Node::val_of(t).unwrap()))
417        }
418    }
419
420    /// `i`番目のキーとその対応する値の可変参照のペアを返す。
421    pub fn get_mut_by_index(&self, i: usize) -> Option<(&K, &mut V)> {
422        if i >= self.len() {
423            None
424        } else {
425            let t = Node::get(self.root.get(), i);
426            self.root.set(t);
427            (!t.is_null()).then(|| (Node::key_of(t).unwrap(), Node::val_mut_of(t).unwrap()))
428        }
429    }
430
431    /// `i`番目のキーへの参照を返す。
432    pub fn get_key_by_index(&self, i: usize) -> Option<&K> {
433        self.get_by_index(i).map(|(k, _)| k)
434    }
435
436    /// `i`番目のキーに対応する値への参照を返す。
437    pub fn get_value_by_index(&self, i: usize) -> Option<&V> {
438        self.get_by_index(i).map(|(_, v)| v)
439    }
440
441    /// `i`番目のキーに対応する値への可変参照を返す。
442    pub fn get_value_mut_by_index(&mut self, i: usize) -> Option<&mut V> {
443        self.get_mut_by_index(i).map(|(_, v)| v)
444    }
445
446    /// `i`番目の要素を削除して、そのキーと値のペアを返す。
447    pub fn remove_by_index(&mut self, i: usize) -> Option<(K, V)> {
448        let (l, r) = Node::split(self.root.get(), i);
449        let (m, r) = Node::split(r, 1);
450        self.root.set(Node::merge(l, r));
451
452        (!m.is_null()).then(|| {
453            let m = Node::from_ptr(m);
454            (m.key, m.value)
455        })
456    }
457
458    /// 順序付き辞書のすべての要素を順番に`f`に渡す。
459    pub fn for_each(&self, mut f: impl FnMut(&K, &mut V)) {
460        Node::traverse(self.root.get(), &mut f);
461    }
462
463    /// 先頭の要素を削除して返す。
464    pub fn pop_first(&mut self) -> Option<(K, V)> {
465        self.remove_by_index(0)
466    }
467    /// 末尾の要素を削除して返す。
468    pub fn pop_last(&mut self) -> Option<(K, V)> {
469        self.remove_by_index(self.len().checked_sub(1)?)
470    }
471    /// 先頭の要素の参照を返す。
472    pub fn first(&self) -> Option<(&K, &V)> {
473        self.get_by_index(0)
474    }
475    /// 末尾の要素の参照を返す。
476    pub fn last(&self) -> Option<(&K, &V)> {
477        self.get_by_index(self.len().checked_sub(1)?)
478    }
479    /// 先頭の要素の可変参照を返す。
480    pub fn first_mut(&mut self) -> Option<(&K, &mut V)> {
481        self.get_mut_by_index(0)
482    }
483    /// 末尾の要素の可変参照を返す。
484    pub fn last_mut(&mut self) -> Option<(&K, &mut V)> {
485        self.get_mut_by_index(self.len().checked_sub(1)?)
486    }
487}
488
489impl<K, V> std::ops::Drop for OrderedMap<K, V> {
490    fn drop(&mut self) {
491        Node::clear(self.root.get());
492    }
493}
494
495impl<K: Ord, V> Default for OrderedMap<K, V> {
496    fn default() -> Self {
497        Self::new()
498    }
499}
500
501#[cfg(test)]
502mod tests {
503    use rand::Rng;
504    use std::collections::BTreeMap;
505
506    use super::*;
507
508    #[test]
509    fn test_empty() {
510        let mut map = OrderedMap::<u32, Vec<u32>>::new();
511
512        assert!(map.first().is_none());
513        assert!(map.last().is_none());
514        assert!(map.first_mut().is_none());
515        assert!(map.last_mut().is_none());
516        assert!(map.pop_first().is_none());
517        assert!(map.pop_last().is_none());
518    }
519
520    #[test]
521    fn test() {
522        let mut rng = rand::thread_rng();
523
524        let mut map = OrderedMap::<u32, Vec<u32>>::new();
525        let mut ans = BTreeMap::<u32, Vec<u32>>::new();
526
527        let q = 10000;
528
529        for _ in 0..q {
530            let x: u32 = rng.gen_range(0..1000);
531            let y: Vec<u32> = vec![rng.gen()];
532
533            assert_eq!(map.insert(x, y.clone()), ans.insert(x, y));
534
535            let x = rng.gen_range(0..1000);
536
537            assert_eq!(map.remove(&x), ans.remove(&x));
538
539            let x = rng.gen_range(0..1000);
540
541            assert_eq!(map.get(&x), ans.get(&x));
542
543            assert_eq!(map.len(), ans.len());
544        }
545    }
546}