haar_lib/ds/
ordered_map.rs

1//! 順序付き辞書
2//!
3//! # Problems
4//! - <https://judge.yosupo.jp/problem/ordered_set>
5
6use std::cell::Cell;
7use std::cmp::Ordering;
8use std::ptr;
9
10struct Node<K, V> {
11    key: K,
12    value: V,
13    size: usize,
14    lc: *mut Self,
15    rc: *mut Self,
16    par: *mut Self,
17}
18
19impl<K, V> Node<K, V> {
20    fn new(key: K, value: V) -> Self {
21        Self {
22            key,
23            value,
24            size: 1,
25            lc: ptr::null_mut(),
26            rc: ptr::null_mut(),
27            par: ptr::null_mut(),
28        }
29    }
30
31    fn set_value(this: *mut Self, mut value: V) -> V {
32        assert!(!this.is_null());
33        std::mem::swap(unsafe { &mut (*this).value }, &mut value);
34        value
35    }
36
37    fn rotate(this: *mut Self) {
38        let p = Self::get_par(this).unwrap();
39        let pp = Self::get_par(p).unwrap();
40
41        if Self::left_of(p).unwrap() == this {
42            let c = Self::right_of(this).unwrap();
43            Self::set_left(p, c);
44            Self::set_right(this, p);
45        } else {
46            let c = Self::left_of(this).unwrap();
47            Self::set_right(p, c);
48            Self::set_left(this, p);
49        }
50
51        unsafe {
52            if !pp.is_null() {
53                if (*pp).lc == p {
54                    (*pp).lc = this;
55                }
56                if (*pp).rc == p {
57                    (*pp).rc = this;
58                }
59            }
60
61            assert!(!this.is_null());
62            (*this).par = pp;
63        }
64
65        Self::update(p);
66        Self::update(this);
67    }
68
69    fn status(this: *mut Self) -> i32 {
70        let par = Self::get_par(this).unwrap();
71
72        if par.is_null() {
73            return 0;
74        }
75        if unsafe { (*par).lc } == this {
76            return 1;
77        }
78        if unsafe { (*par).rc } == this {
79            return -1;
80        }
81
82        unreachable!()
83    }
84
85    fn pushdown(this: *mut Self) {
86        if !this.is_null() {
87            Self::update(this);
88        }
89    }
90
91    fn update(this: *mut Self) {
92        assert!(!this.is_null());
93        unsafe {
94            (*this).size = 1 + Self::size_of((*this).lc) + Self::size_of((*this).rc);
95        }
96    }
97
98    fn splay(this: *mut Self) {
99        while Self::status(this) != 0 {
100            let par = Self::get_par(this).unwrap();
101
102            if Self::status(par) == 0 {
103                Self::rotate(this);
104            } else if Self::status(this) == Self::status(par) {
105                Self::rotate(par);
106                Self::rotate(this);
107            } else {
108                Self::rotate(this);
109                Self::rotate(this);
110            }
111        }
112    }
113
114    fn get(root: *mut Self, mut index: usize) -> *mut Self {
115        if root.is_null() {
116            return root;
117        }
118
119        let mut cur = root;
120
121        loop {
122            Self::pushdown(cur);
123
124            let left = Self::left_of(cur).unwrap();
125            let lsize = Self::size_of(left);
126
127            match index.cmp(&lsize) {
128                Ordering::Less => {
129                    cur = left;
130                }
131                Ordering::Equal => {
132                    Self::splay(cur);
133                    return cur;
134                }
135                Ordering::Greater => {
136                    cur = Self::right_of(cur).unwrap();
137                    index -= lsize + 1;
138                }
139            }
140        }
141    }
142
143    fn merge(left: *mut Self, right: *mut Self) -> *mut Self {
144        if left.is_null() {
145            return right;
146        }
147        if right.is_null() {
148            return left;
149        }
150
151        let cur = Self::get(left, Self::size_of(left) - 1);
152
153        Self::set_right(cur, right);
154        Self::update(right);
155        Self::update(cur);
156
157        cur
158    }
159
160    fn split(root: *mut Self, index: usize) -> (*mut Self, *mut Self) {
161        if root.is_null() {
162            return (ptr::null_mut(), ptr::null_mut());
163        }
164        if index >= Self::size_of(root) {
165            return (root, ptr::null_mut());
166        }
167
168        let cur = Self::get(root, index);
169        let left = Self::left_of(cur).unwrap();
170
171        if !left.is_null() {
172            unsafe {
173                (*left).par = ptr::null_mut();
174            }
175            Self::update(left);
176        }
177        assert!(!cur.is_null());
178        unsafe {
179            (*cur).lc = ptr::null_mut();
180        }
181        Self::update(cur);
182
183        (left, cur)
184    }
185
186    fn traverse(cur: *mut Self, f: &mut impl FnMut(&K, &mut V)) {
187        if !cur.is_null() {
188            Self::pushdown(cur);
189            Self::traverse(Self::left_of(cur).unwrap(), f);
190            f(unsafe { &(*cur).key }, unsafe { &mut (*cur).value });
191            Self::traverse(Self::right_of(cur).unwrap(), f);
192        }
193    }
194
195    fn set_left(this: *mut Self, left: *mut Self) {
196        assert!(!this.is_null());
197        unsafe {
198            (*this).lc = left;
199            if !left.is_null() {
200                (*left).par = this;
201            }
202        }
203    }
204
205    fn set_right(this: *mut Self, right: *mut Self) {
206        assert!(!this.is_null());
207        unsafe {
208            (*this).rc = right;
209            if !right.is_null() {
210                (*right).par = this;
211            }
212        }
213    }
214
215    fn size_of(this: *mut Self) -> usize {
216        if this.is_null() {
217            0
218        } else {
219            unsafe { (*this).size }
220        }
221    }
222
223    fn left_of(this: *mut Self) -> Option<*mut Self> {
224        (!this.is_null()).then_some(unsafe { (*this).lc })
225    }
226
227    fn right_of(this: *mut Self) -> Option<*mut Self> {
228        (!this.is_null()).then_some(unsafe { (*this).rc })
229    }
230
231    fn get_par(this: *mut Self) -> Option<*mut Self> {
232        (!this.is_null()).then_some(unsafe { (*this).par })
233    }
234
235    fn clear(this: *mut Self) {
236        if !this.is_null() {
237            let lc = Self::left_of(this).unwrap();
238            let rc = Self::right_of(this).unwrap();
239
240            let _ = unsafe { Box::from_raw(this) };
241
242            Self::clear(lc);
243            Self::clear(rc);
244        }
245    }
246
247    fn key_of<'a>(this: *mut Self) -> Option<&'a K> {
248        (!this.is_null()).then(|| unsafe { &(*this).key })
249    }
250}
251
252impl<K: Ord, V> Node<K, V> {
253    fn binary_search(this: *mut Self, key: &K) -> Result<usize, usize> {
254        if this.is_null() {
255            Err(0)
256        } else {
257            let left = Self::left_of(this).unwrap();
258            let right = Self::right_of(this).unwrap();
259            let c = Self::size_of(left);
260            match Self::key_of(this).unwrap().cmp(key) {
261                Ordering::Equal => Ok(c),
262                Ordering::Greater => Self::binary_search(left, key),
263                Ordering::Less => Self::binary_search(right, key)
264                    .map(|a| a + c + 1)
265                    .map_err(|a| a + c + 1),
266            }
267        }
268    }
269}
270
271/// 順序付き辞書
272pub struct OrderedMap<K, V> {
273    root: Cell<*mut Node<K, V>>,
274}
275
276impl<K: Ord, V> OrderedMap<K, V> {
277    /// 空の`OrderedMap`を返す。
278    pub fn new() -> Self {
279        Self {
280            root: Cell::new(ptr::null_mut()),
281        }
282    }
283
284    /// 要素数を返す。
285    pub fn len(&self) -> usize {
286        Node::size_of(self.root.get())
287    }
288
289    /// 要素数が`0`ならば`true`を返す。
290    pub fn is_empty(&self) -> bool {
291        self.root.get().is_null()
292    }
293
294    /// `key`が存在するとき、それが何番目のキーであるかを`Ok`で返す。
295    /// そうでないとき、仮に`key`があったとき何番目のキーであったか、を`Err`で返す。
296    pub fn binary_search(&self, key: &K) -> Result<usize, usize> {
297        Node::binary_search(self.root.get(), key)
298    }
299
300    /// `key`以下の最大のキーをもつキーと値のペアを返す。
301    pub fn max_le(&self, key: &K) -> Option<(&K, &V)> {
302        match self.binary_search(key) {
303            Ok(i) => self.get_by_index(i),
304            Err(i) => {
305                if i > 0 {
306                    self.get_by_index(i - 1)
307                } else {
308                    None
309                }
310            }
311        }
312    }
313
314    /// `key`以上の最小のキーをもつキーと値のペアを返す。
315    pub fn min_ge(&self, key: &K) -> Option<(&K, &V)> {
316        match self.binary_search(key) {
317            Ok(i) | Err(i) => self.get_by_index(i),
318        }
319    }
320
321    /// `key`をキーとして持つならば`true`を返す。
322    pub fn contains(&self, key: &K) -> bool {
323        Node::binary_search(self.root.get(), key).is_ok()
324    }
325
326    /// `key`がすでに存在している場合、値を`value`で更新して、古い値を`Some`で返す。
327    /// そうでないとき、`key`に`value`を紐付けて、`None`を返す。
328    pub fn insert(&mut self, key: K, value: V) -> Option<V> {
329        match Node::binary_search(self.root.get(), &key) {
330            Ok(i) => {
331                let (l, r) = Node::split(self.root.get(), i);
332                let (m, r) = Node::split(r, 1);
333                let old = Node::set_value(m, value);
334
335                let r = Node::merge(m, r);
336                let root = Node::merge(l, r);
337                self.root.set(root);
338                Some(old)
339            }
340            Err(i) => {
341                let (l, r) = Node::split(self.root.get(), i);
342                let node = Box::into_raw(Box::new(Node::new(key, value)));
343                let root = Node::merge(l, Node::merge(node, r));
344                self.root.set(root);
345                None
346            }
347        }
348    }
349
350    /// キー`key`に対応する値の参照を返す。
351    pub fn get(&self, key: &K) -> Option<&V> {
352        let k = Node::binary_search(self.root.get(), key).ok()?;
353        self.get_by_index(k).map(|(_, v)| v)
354    }
355
356    /// キー`key`に対応する値の可変参照を返す
357    pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
358        let k = Node::binary_search(self.root.get(), key).ok()?;
359        self.get_value_mut_by_index(k)
360    }
361
362    /// キー`key`があれば、そのキーと対応する値を削除して、その値を`Some`で返す。
363    pub fn remove(&mut self, key: &K) -> Option<V> {
364        let i = Node::binary_search(self.root.get(), key).ok()?;
365        self.remove_by_index(i).map(|(_, v)| v)
366    }
367
368    /// `i`番目のキーとその対応する値のペアへの参照を返す。
369    pub fn get_by_index(&self, i: usize) -> Option<(&K, &V)> {
370        if i >= self.len() {
371            None
372        } else {
373            let t = Node::get(self.root.get(), i);
374            self.root.set(t);
375            (!t.is_null()).then(|| unsafe { (&(*t).key, &(*t).value) })
376        }
377    }
378
379    /// `i`番目のキーへの参照を返す。
380    pub fn get_key_by_index(&self, i: usize) -> Option<&K> {
381        self.get_by_index(i).map(|(k, _)| k)
382    }
383
384    /// `i`番目のキーに対応する値への参照を返す。
385    pub fn get_value_by_index(&self, i: usize) -> Option<&V> {
386        self.get_by_index(i).map(|(_, v)| v)
387    }
388
389    /// `i`番目のキーに対応する値への可変参照を返す。
390    pub fn get_value_mut_by_index(&mut self, i: usize) -> Option<&mut V> {
391        if i >= self.len() {
392            None
393        } else {
394            let t = Node::get(self.root.get(), i);
395            self.root.set(t);
396            (!t.is_null()).then(|| unsafe { &mut (*t).value })
397        }
398    }
399
400    /// `i`番目の要素を削除して、そのキーと値のペアを返す。
401    pub fn remove_by_index(&mut self, i: usize) -> Option<(K, V)> {
402        let (l, r) = Node::split(self.root.get(), i);
403        let (m, r) = Node::split(r, 1);
404        self.root.set(Node::merge(l, r));
405
406        (!m.is_null()).then(|| unsafe {
407            let m = Box::from_raw(m);
408            let node = *m;
409            (node.key, node.value)
410        })
411    }
412
413    /// 順序付き辞書のすべての要素を順番に`f`に渡す。
414    pub fn for_each(&self, mut f: impl FnMut(&K, &mut V)) {
415        Node::traverse(self.root.get(), &mut f);
416    }
417
418    // pub fn pop_first(&mut self) -> Option<V>
419    // pub fn pop_last(&mut self) -> Option<V>
420    // pub fn first(&self) -> Option<&V>
421    // pub fn last(&self) -> Option<&V>
422    // pub fn first_mut(&mut self) -> Option<&mut V>
423    // pub fn last_mut(&mut self) -> Option<&mut V>
424}
425
426impl<K, V> std::ops::Drop for OrderedMap<K, V> {
427    fn drop(&mut self) {
428        Node::clear(self.root.get());
429    }
430}
431
432impl<K: Ord, V> Default for OrderedMap<K, V> {
433    fn default() -> Self {
434        Self::new()
435    }
436}
437
438#[cfg(test)]
439mod tests {
440    use rand::Rng;
441    use std::collections::BTreeMap;
442
443    use super::*;
444
445    #[test]
446    fn test() {
447        let mut rng = rand::thread_rng();
448
449        let mut map = OrderedMap::<u32, u32>::new();
450        let mut ans = BTreeMap::<u32, u32>::new();
451
452        let q = 10000;
453
454        for _ in 0..q {
455            let x: u32 = rng.gen_range(0..1000);
456            let y: u32 = rng.gen();
457
458            assert_eq!(map.insert(x, y), ans.insert(x, y));
459
460            let x = rng.gen_range(0..1000);
461
462            assert_eq!(map.remove(&x), ans.remove(&x));
463
464            let x = rng.gen_range(0..1000);
465
466            assert_eq!(map.get(&x), ans.get(&x));
467
468            assert_eq!(map.len(), ans.len());
469        }
470    }
471}