haar_lib/ds/
euler_tour_tree.rs

//! Euler tour tree
//!
//! # References
//! - <https://qiita.com/hotman78/items/78cd3aa50b05a57738d4>
//!
//! # Problems
//! - <https://judge.yosupo.jp/problem/dynamic_tree_vertex_add_subtree_sum>
8
9use std::collections::HashMap;
10use std::ptr;
11
12use crate::algebra::traits::Monoid;
13
14struct Node<M> {
15    value: M,
16    sum: M,
17    size: usize,
18    lc: *mut Node<M>,
19    rc: *mut Node<M>,
20    par: *mut Node<M>,
21}
22
23impl<M: Monoid + Clone> Node<M> {
24    fn new(value: M) -> Self {
25        Self {
26            value,
27            sum: M::id(),
28            size: 1,
29            lc: ptr::null_mut(),
30            rc: ptr::null_mut(),
31            par: ptr::null_mut(),
32        }
33    }
34
35    fn get_sum(this: *mut Self) -> M {
36        assert!(!this.is_null());
37        unsafe { (*this).sum.clone() }
38    }
39
40    fn set_value(this: *mut Self, value: M) {
41        assert!(!this.is_null());
42        unsafe {
43            (*this).value = value;
44        }
45    }
46
47    fn update_value(this: *mut Self, value: M) {
48        assert!(!this.is_null());
49        unsafe {
50            (*this).value = (*this).value.clone().op(value);
51        }
52    }
53
54    fn rotate(this: *mut Self) {
55        let p = Self::get_par(this).unwrap();
56        let pp = Self::get_par(p).unwrap();
57
58        if Self::left_of(p).unwrap() == this {
59            let c = Self::right_of(this).unwrap();
60            Self::set_left(p, c);
61            Self::set_right(this, p);
62        } else {
63            let c = Self::left_of(this).unwrap();
64            Self::set_right(p, c);
65            Self::set_left(this, p);
66        }
67
68        unsafe {
69            if !pp.is_null() {
70                if (*pp).lc == p {
71                    (*pp).lc = this;
72                }
73                if (*pp).rc == p {
74                    (*pp).rc = this;
75                }
76            }
77
78            assert!(!this.is_null());
79            (*this).par = pp;
80        }
81
82        Self::update(p);
83        Self::update(this);
84    }
85
86    fn status(this: *mut Self) -> i32 {
87        let par = Self::get_par(this).unwrap();
88
89        if par.is_null() {
90            return 0;
91        }
92        if unsafe { (*par).lc } == this {
93            return 1;
94        }
95        if unsafe { (*par).rc } == this {
96            return -1;
97        }
98
99        unreachable!()
100    }
101
102    fn pushdown(this: *mut Self) {
103        if !this.is_null() {
104            Self::update(this);
105        }
106    }
107
108    fn update(this: *mut Self) {
109        assert!(!this.is_null());
110        unsafe {
111            (*this).size = 1 + Self::size_of((*this).lc) + Self::size_of((*this).rc);
112
113            (*this).sum = (*this).value.clone();
114            if !(*this).lc.is_null() {
115                (*this).sum = M::op(Self::get_sum(this), Self::get_sum((*this).lc));
116            }
117            if !(*this).rc.is_null() {
118                (*this).sum = M::op(Self::get_sum(this), Self::get_sum((*this).rc));
119            }
120        }
121    }
122
123    fn splay(this: *mut Self) {
124        while Self::status(this) != 0 {
125            let par = Self::get_par(this).unwrap();
126
127            if Self::status(par) == 0 {
128                Self::rotate(this);
129            } else if Self::status(this) == Self::status(par) {
130                Self::rotate(par);
131                Self::rotate(this);
132            } else {
133                Self::rotate(this);
134                Self::rotate(this);
135            }
136        }
137    }
138
139    fn get_first(root: *mut Self) -> *mut Self {
140        if root.is_null() {
141            return root;
142        }
143
144        let mut cur = root;
145
146        loop {
147            Self::pushdown(cur);
148
149            let left = Self::left_of(cur).unwrap();
150
151            if left.is_null() {
152                Self::splay(cur);
153                return cur;
154            }
155            cur = left;
156        }
157    }
158
159    fn get_last(root: *mut Self) -> *mut Self {
160        if root.is_null() {
161            return root;
162        }
163
164        let mut cur = root;
165
166        loop {
167            Self::pushdown(cur);
168
169            let right = Self::right_of(cur).unwrap();
170
171            if right.is_null() {
172                Self::splay(cur);
173                return cur;
174            }
175            cur = right;
176        }
177    }
178
179    fn merge(left: *mut Self, right: *mut Self) -> *mut Self {
180        if left.is_null() {
181            return right;
182        }
183        if right.is_null() {
184            return left;
185        }
186
187        let cur = Self::get_last(left);
188
189        Self::set_right(cur, right);
190        Self::update(right);
191        Self::update(cur);
192
193        cur
194    }
195
196    fn split_left(root: *mut Self) -> (*mut Self, *mut Self) {
197        if root.is_null() {
198            return (ptr::null_mut(), ptr::null_mut());
199        }
200
201        let cur = root;
202        let left = Self::left_of(cur).unwrap();
203
204        if !left.is_null() {
205            unsafe {
206                (*left).par = ptr::null_mut();
207            }
208            Self::update(left);
209        }
210        assert!(!cur.is_null());
211        unsafe {
212            (*cur).lc = ptr::null_mut();
213        }
214        Self::update(cur);
215
216        (left, cur)
217    }
218
219    fn split_right(root: *mut Self) -> (*mut Self, *mut Self) {
220        if root.is_null() {
221            return (ptr::null_mut(), ptr::null_mut());
222        }
223
224        let cur = root;
225        let right = Self::right_of(cur).unwrap();
226
227        if !right.is_null() {
228            unsafe {
229                (*right).par = ptr::null_mut();
230            }
231            Self::update(right);
232        }
233        assert!(!cur.is_null());
234        unsafe {
235            (*cur).rc = ptr::null_mut();
236        }
237        Self::update(cur);
238
239        (cur, right)
240    }
241
242    fn set_left(this: *mut Self, left: *mut Self) {
243        assert!(!this.is_null());
244        unsafe {
245            (*this).lc = left;
246            if !left.is_null() {
247                (*left).par = this;
248            }
249        }
250    }
251
252    fn set_right(this: *mut Self, right: *mut Self) {
253        assert!(!this.is_null());
254        unsafe {
255            (*this).rc = right;
256            if !right.is_null() {
257                (*right).par = this;
258            }
259        }
260    }
261
262    fn size_of(this: *mut Self) -> usize {
263        if this.is_null() {
264            0
265        } else {
266            unsafe { (*this).size }
267        }
268    }
269
270    fn left_of(this: *mut Self) -> Option<*mut Self> {
271        (!this.is_null()).then_some(unsafe { (*this).lc })
272    }
273
274    fn right_of(this: *mut Self) -> Option<*mut Self> {
275        (!this.is_null()).then_some(unsafe { (*this).rc })
276    }
277
278    fn get_par(this: *mut Self) -> Option<*mut Self> {
279        (!this.is_null()).then_some(unsafe { (*this).par })
280    }
281}
282
283/// Euler tour tree
284pub struct EulerTourTree<M> {
285    vertices: Vec<*mut Node<M>>,
286    edges: Vec<HashMap<usize, *mut Node<M>>>,
287}
288
289impl<M: Monoid + Clone> EulerTourTree<M> {
290    /// `n`個の頂点のみからなる森を構築する。
291    pub fn new(n: usize) -> Self {
292        let vertices = (0..n)
293            .map(|_| {
294                let p = Box::new(Node::new(M::id()));
295                Box::into_raw(p)
296            })
297            .collect::<Vec<_>>();
298
299        let edges = (0..n).map(|i| HashMap::from([(i, vertices[i])])).collect();
300
301        Self { vertices, edges }
302    }
303
304    /// 頂点`r`をそれの属する木の根にする。
305    pub fn reroot(&mut self, r: usize) {
306        let p = self.vertices[r];
307
308        Node::splay(p);
309        let (l, r) = Node::split_left(p);
310        Node::merge(r, l);
311    }
312
313    /// 2つの頂点が同一の木に属するかどうかを判定する。
314    pub fn is_same_tree(&self, i: usize, j: usize) -> bool {
315        if i == j {
316            return true;
317        }
318
319        let pi = self.vertices[i];
320        let pj = self.vertices[j];
321
322        Node::splay(pi);
323        let ri = Node::get_first(pi);
324
325        Node::splay(pj);
326        let rj = Node::get_first(pj);
327
328        ptr::eq(ri, rj)
329    }
330
331    /// 異なる木にそれぞれ属する2頂点間に辺を張る。
332    pub fn link(&mut self, i: usize, j: usize) -> Result<(), &'static str> {
333        if self.is_same_tree(i, j) {
334            return Err("既に同一の木に属している。");
335        }
336
337        self.reroot(i);
338        self.reroot(j);
339
340        let pi = self.vertices[i];
341        let pj = self.vertices[j];
342
343        Node::splay(pi);
344        Node::splay(pj);
345
346        let eij = Box::into_raw(Box::new(Node::new(M::id())));
347        self.edges[i].insert(j, eij);
348
349        let eji = Box::into_raw(Box::new(Node::new(M::id())));
350        self.edges[j].insert(i, eji);
351
352        let t = Node::merge(pi, eij);
353        let t = Node::merge(t, pj);
354        Node::merge(t, eji);
355
356        Ok(())
357    }
358
359    /// 2頂点間を張る辺を削除する。
360    pub fn cut(&mut self, i: usize, j: usize) -> Result<(), &'static str> {
361        if i == j {
362            return Err("同一頂点で`cut`は不可。");
363        }
364        match (self.edges[i].get(&j), self.edges[j].get(&i)) {
365            (Some(&eij), Some(&eji)) => {
366                self.reroot(i);
367
368                Node::splay(eij);
369                let (s, a) = Node::split_left(eij);
370                Node::split_right(a);
371
372                Node::splay(eji);
373                let (_, a) = Node::split_left(eji);
374                let (_, u) = Node::split_right(a);
375
376                Node::merge(s, u);
377
378                self.edges[i].remove(&j);
379                self.edges[j].remove(&i);
380
381                unsafe {
382                    let _ = Box::from_raw(eij);
383                    let _ = Box::from_raw(eji);
384                }
385            }
386            _ => return Err("2頂点をつなぐ辺が存在しない。"),
387        }
388
389        Ok(())
390    }
391
392    /// 頂点`i`の値を`value`に設定する。
393    pub fn set(&mut self, i: usize, value: M) {
394        let p = self.vertices[i];
395        Node::splay(p);
396        Node::set_value(p, value);
397        Node::pushdown(p);
398    }
399
400    /// 頂点`i`の値をモノイドの演算と値`value`で更新する。
401    pub fn update(&mut self, i: usize, value: M) {
402        let p = self.vertices[i];
403        Node::splay(p);
404        Node::update_value(p, value);
405        Node::pushdown(p);
406    }
407
408    /// 頂点`p`を親とする頂点`v`について、`v`を根とする部分木の値を集積して返す。
409    pub fn subtree_sum(&mut self, v: usize, p: usize) -> Result<M, &'static str> {
410        self.cut(v, p)?;
411
412        let rv = self.vertices[v];
413        Node::splay(rv);
414        let ret = Node::get_sum(rv);
415
416        self.link(v, p)?;
417
418        Ok(ret)
419    }
420}
421
#[cfg(test)]
mod tests {
    use crate::algebra::trivial::Trivial;

    use super::*;

    /// Exercises link / reroot / cut and checks connectivity and error paths.
    #[test]
    fn test() {
        let mut ett = EulerTourTree::<Trivial>::new(10);

        ett.link(1, 2).unwrap();
        ett.link(3, 5).unwrap();
        ett.link(1, 5).unwrap();
        assert!(ett.is_same_tree(2, 3));
        assert!(!ett.is_same_tree(2, 4));
        // Linking two vertices that are already connected must fail.
        assert!(ett.link(2, 3).is_err());

        ett.reroot(2);
        // Rerooting does not change connectivity.
        assert!(ett.is_same_tree(2, 5));

        ett.cut(1, 2).unwrap();
        assert!(!ett.is_same_tree(1, 2));
        assert!(ett.is_same_tree(1, 3));
        // The edge no longer exists, so cutting it again must fail.
        assert!(ett.cut(1, 2).is_err());
        // Non-adjacent vertices and a single vertex cannot be cut.
        assert!(ett.cut(2, 3).is_err());
        assert!(ett.cut(4, 4).is_err());
    }
}