// haar_lib/ds/euler_tour_tree.rs

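//! Euler tour tree: a dynamic forest supporting edge insertion and deletion,
//! rerooting, vertex value updates and subtree aggregation.
//!
//! # References
//! - <https://qiita.com/hotman78/items/78cd3aa50b05a57738d4>
//!
//! # Problems
//! - <https://judge.yosupo.jp/problem/dynamic_tree_vertex_add_subtree_sum>
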
use std::collections::HashMap;
use std::ptr;

use crate::algebra::traits::Monoid;

/// Owner of the monoid used to combine node values.
///
/// The splay-tree primitives operating on `Node`s are methods of this type,
/// so they can reach the monoid's `op` and `id`.
struct Manager<M> {
    /// Monoid that aggregates node values.
    monoid: M,
}

18struct Node<M: Monoid> {
19    value: M::Element,
20    sum: M::Element,
21    size: usize,
22    lc: *mut Self,
23    rc: *mut Self,
24    par: *mut Self,
25}
26
27impl<M: Monoid> Manager<M>
28where
29    M::Element: Clone,
30{
31    fn new(monoid: M) -> Self {
32        Self { monoid }
33    }
34
35    fn create(&self, value: M::Element) -> Node<M> {
36        Node {
37            value,
38            sum: self.monoid.id(),
39            size: 1,
40            lc: ptr::null_mut(),
41            rc: ptr::null_mut(),
42            par: ptr::null_mut(),
43        }
44    }
45
46    fn get_sum(&self, this: *mut Node<M>) -> M::Element {
47        if this.is_null() {
48            self.monoid.id()
49        } else {
50            unsafe { (*this).sum.clone() }
51        }
52    }
53
54    fn set_value(&self, this: *mut Node<M>, value: M::Element) {
55        assert!(!this.is_null());
56        unsafe {
57            (*this).value = value;
58        }
59    }
60
61    fn update_value(&self, this: *mut Node<M>, value: M::Element) {
62        assert!(!this.is_null());
63        let this = unsafe { &mut *this };
64        this.value = self.monoid.op(this.value.clone(), value);
65    }
66
67    fn rotate(&self, this: *mut Node<M>) {
68        let p = self.par_of(this).unwrap();
69        let pp = self.par_of(p).unwrap();
70
71        if self.left_of(p).unwrap() == this {
72            let c = self.right_of(this).unwrap();
73            self.set_left(p, c);
74            self.set_right(this, p);
75        } else {
76            let c = self.left_of(this).unwrap();
77            self.set_right(p, c);
78            self.set_left(this, p);
79        }
80
81        if !pp.is_null() {
82            let pp = unsafe { &mut *pp };
83            if pp.lc == p {
84                pp.lc = this;
85            }
86            if pp.rc == p {
87                pp.rc = this;
88            }
89        }
90
91        assert!(!this.is_null());
92        let this = unsafe { &mut *this };
93        this.par = pp;
94
95        self.update(p);
96        self.update(this);
97    }
98
99    fn status(&self, this: *mut Node<M>) -> i32 {
100        let par = self.par_of(this).unwrap();
101        if par.is_null() {
102            return 0;
103        }
104        let par = unsafe { &*par };
105        if par.lc == this {
106            return 1;
107        }
108        if par.rc == this {
109            return -1;
110        }
111
112        unreachable!()
113    }
114
115    fn pushdown(&self, this: *mut Node<M>) {
116        if !this.is_null() {
117            self.update(this);
118        }
119    }
120
121    fn update(&self, this: *mut Node<M>) {
122        assert!(!this.is_null());
123        let this = unsafe { &mut *this };
124
125        this.size = 1 + self.size_of(this.lc) + self.size_of(this.rc);
126
127        this.sum = this.value.clone();
128        if !this.lc.is_null() {
129            this.sum = self.monoid.op(this.sum.clone(), self.get_sum(this.lc));
130        }
131        if !this.rc.is_null() {
132            this.sum = self.monoid.op(this.sum.clone(), self.get_sum(this.rc));
133        }
134    }
135
136    fn splay(&self, this: *mut Node<M>) {
137        while self.status(this) != 0 {
138            let par = self.par_of(this).unwrap();
139
140            if self.status(par) == 0 {
141                self.rotate(this);
142            } else if self.status(this) == self.status(par) {
143                self.rotate(par);
144                self.rotate(this);
145            } else {
146                self.rotate(this);
147                self.rotate(this);
148            }
149        }
150    }
151
152    fn get_first(&self, root: *mut Node<M>) -> *mut Node<M> {
153        if root.is_null() {
154            return root;
155        }
156
157        let mut cur = root;
158
159        loop {
160            self.pushdown(cur);
161
162            let left = self.left_of(cur).unwrap();
163
164            if left.is_null() {
165                self.splay(cur);
166                return cur;
167            }
168            cur = left;
169        }
170    }
171
172    fn get_last(&self, root: *mut Node<M>) -> *mut Node<M> {
173        if root.is_null() {
174            return root;
175        }
176
177        let mut cur = root;
178
179        loop {
180            self.pushdown(cur);
181
182            let right = self.right_of(cur).unwrap();
183
184            if right.is_null() {
185                self.splay(cur);
186                return cur;
187            }
188            cur = right;
189        }
190    }
191
192    fn merge(&self, left: *mut Node<M>, right: *mut Node<M>) -> *mut Node<M> {
193        if left.is_null() {
194            return right;
195        }
196        if right.is_null() {
197            return left;
198        }
199
200        let cur = self.get_last(left);
201
202        self.set_right(cur, right);
203        self.update(right);
204        self.update(cur);
205
206        cur
207    }
208
209    fn split_left(&self, root: *mut Node<M>) -> (*mut Node<M>, *mut Node<M>) {
210        if root.is_null() {
211            return (ptr::null_mut(), ptr::null_mut());
212        }
213
214        let cur = root;
215        let left = self.left_of(cur).unwrap();
216
217        if !left.is_null() {
218            self.set_par(left, ptr::null_mut());
219            self.update(left);
220        }
221        self.set_left(cur, ptr::null_mut());
222        self.update(cur);
223
224        (left, cur)
225    }
226
227    fn split_right(&self, root: *mut Node<M>) -> (*mut Node<M>, *mut Node<M>) {
228        if root.is_null() {
229            return (ptr::null_mut(), ptr::null_mut());
230        }
231
232        let cur = root;
233        let right = self.right_of(cur).unwrap();
234
235        if !right.is_null() {
236            self.set_par(right, ptr::null_mut());
237            self.update(right);
238        }
239        self.set_right(cur, ptr::null_mut());
240        self.update(cur);
241
242        (cur, right)
243    }
244
245    fn set_left(&self, this: *mut Node<M>, left: *mut Node<M>) {
246        assert!(!this.is_null());
247        unsafe { (*this).lc = left };
248        if !left.is_null() {
249            unsafe { (*left).par = this };
250        }
251    }
252
253    fn set_right(&self, this: *mut Node<M>, right: *mut Node<M>) {
254        assert!(!this.is_null());
255        unsafe { (*this).rc = right };
256        if !right.is_null() {
257            unsafe { (*right).par = this };
258        }
259    }
260
261    fn set_par(&self, this: *mut Node<M>, par: *mut Node<M>) {
262        assert!(!this.is_null());
263        unsafe { (*this).par = par };
264    }
265
266    fn size_of(&self, this: *mut Node<M>) -> usize {
267        if this.is_null() {
268            0
269        } else {
270            unsafe { (*this).size }
271        }
272    }
273
274    fn left_of(&self, this: *mut Node<M>) -> Option<*mut Node<M>> {
275        (!this.is_null()).then(|| unsafe { (*this).lc })
276    }
277
278    fn right_of(&self, this: *mut Node<M>) -> Option<*mut Node<M>> {
279        (!this.is_null()).then(|| unsafe { (*this).rc })
280    }
281
282    fn par_of(&self, this: *mut Node<M>) -> Option<*mut Node<M>> {
283        (!this.is_null()).then(|| unsafe { (*this).par })
284    }
285}
286
287/// Euler tour tree
288pub struct EulerTourTree<M: Monoid> {
289    man: Manager<M>,
290    vertices: Vec<*mut Node<M>>,
291    edges: Vec<HashMap<usize, *mut Node<M>>>,
292}
293
294impl<M: Monoid + Clone> EulerTourTree<M>
295where
296    M::Element: Clone,
297{
298    /// `n`個の頂点のみからなる森を構築する。
299    pub fn new(monoid: M, n: usize) -> Self {
300        let man = Manager::new(monoid);
301        let vertices = std::iter::repeat_with(|| {
302            let p = Box::new(man.create(man.monoid.id()));
303            Box::into_raw(p)
304        })
305        .take(n)
306        .collect::<Vec<_>>();
307
308        let edges = (0..n).map(|i| HashMap::from([(i, vertices[i])])).collect();
309
310        Self {
311            man,
312            vertices,
313            edges,
314        }
315    }
316
317    /// 頂点`r`をそれの属する木の根にする。
318    pub fn reroot(&mut self, r: usize) {
319        let p = self.vertices[r];
320
321        self.man.splay(p);
322        let (l, r) = self.man.split_left(p);
323        self.man.merge(r, l);
324    }
325
326    /// 2つの頂点が同一の木に属するかどうかを判定する。
327    pub fn is_same_tree(&self, i: usize, j: usize) -> bool {
328        if i == j {
329            return true;
330        }
331
332        let pi = self.vertices[i];
333        let pj = self.vertices[j];
334
335        self.man.splay(pi);
336        let ri = self.man.get_first(pi);
337
338        self.man.splay(pj);
339        let rj = self.man.get_first(pj);
340
341        ptr::eq(ri, rj)
342    }
343
344    /// 異なる木にそれぞれ属する2頂点間に辺を張る。
345    pub fn link(&mut self, i: usize, j: usize) -> Result<(), &'static str> {
346        if self.is_same_tree(i, j) {
347            return Err("既に同一の木に属している。");
348        }
349
350        self.reroot(i);
351        self.reroot(j);
352
353        let pi = self.vertices[i];
354        let pj = self.vertices[j];
355
356        self.man.splay(pi);
357        self.man.splay(pj);
358
359        let eij = Box::into_raw(Box::new(self.man.create(self.man.monoid.id())));
360        self.edges[i].insert(j, eij);
361
362        let eji = Box::into_raw(Box::new(self.man.create(self.man.monoid.id())));
363        self.edges[j].insert(i, eji);
364
365        let t = self.man.merge(pi, eij);
366        let t = self.man.merge(t, pj);
367        self.man.merge(t, eji);
368
369        Ok(())
370    }
371
372    /// 2頂点間を張る辺を削除する。
373    pub fn cut(&mut self, i: usize, j: usize) -> Result<(), &'static str> {
374        if i == j {
375            return Err("同一頂点で`cut`は不可。");
376        }
377        match (self.edges[i].get(&j), self.edges[j].get(&i)) {
378            (Some(&eij), Some(&eji)) => {
379                self.reroot(i);
380
381                self.man.splay(eij);
382                let (s, a) = self.man.split_left(eij);
383                self.man.split_right(a);
384
385                self.man.splay(eji);
386                let (_, a) = self.man.split_left(eji);
387                let (_, u) = self.man.split_right(a);
388
389                self.man.merge(s, u);
390
391                self.edges[i].remove(&j);
392                self.edges[j].remove(&i);
393
394                unsafe {
395                    let _ = Box::from_raw(eij);
396                    let _ = Box::from_raw(eji);
397                }
398            }
399            _ => return Err("2頂点をつなぐ辺が存在しない。"),
400        }
401
402        Ok(())
403    }
404
405    /// 頂点`i`の値を`value`に設定する。
406    pub fn set(&mut self, i: usize, value: M::Element) {
407        let p = self.vertices[i];
408        self.man.splay(p);
409        self.man.set_value(p, value);
410        self.man.pushdown(p);
411    }
412
413    /// 頂点`i`の値をモノイドの演算と値`value`で更新する。
414    pub fn update(&mut self, i: usize, value: M::Element) {
415        let p = self.vertices[i];
416        self.man.splay(p);
417        self.man.update_value(p, value);
418        self.man.pushdown(p);
419    }
420
421    /// 頂点`p`を親とする頂点`v`について、`v`を根とする部分木の値を集積して返す。
422    pub fn subtree_sum(&mut self, v: usize, p: usize) -> Result<M::Element, &'static str> {
423        self.cut(v, p)?;
424
425        let rv = self.vertices[v];
426        self.man.splay(rv);
427        let ret = self.man.get_sum(rv);
428
429        self.link(v, p)?;
430
431        Ok(ret)
432    }
433}
434
#[cfg(test)]
mod tests {
    use crate::algebra::trivial::Trivial;

    use super::*;

    #[test]
    fn test() {
        let mut ett = EulerTourTree::new(Trivial, 10);

        ett.link(1, 2).unwrap();
        ett.link(3, 5).unwrap();
        ett.link(1, 5).unwrap();

        // Edges 1-2, 3-5, 1-5 form a single tree {1, 2, 3, 5}.
        assert!(ett.is_same_tree(2, 3));
        assert!(!ett.is_same_tree(2, 4));
        assert!(ett.link(2, 3).is_err());
        assert!(ett.link(4, 4).is_err());

        ett.reroot(2);
        assert!(ett.is_same_tree(2, 5));

        ett.cut(1, 2).unwrap();
        assert!(!ett.is_same_tree(1, 2));
        assert!(ett.is_same_tree(1, 3));

        // Removed edges, missing edges and self-loops cannot be cut.
        assert!(ett.cut(1, 2).is_err());
        assert!(ett.cut(2, 4).is_err());
        assert!(ett.cut(3, 3).is_err());

        // `subtree_sum` requires an existing edge and leaves it in place.
        assert!(ett.subtree_sum(3, 5).is_ok());
        assert!(ett.is_same_tree(3, 1));
        assert!(ett.subtree_sum(2, 4).is_err());
    }
}