haar_lib/tree/
hld.rs

1//! 重軽分解
2use crate::tree::*;
3use std::cmp::max;
4
/// Heavy-light decomposition of a rooted tree.
///
/// Vertices are renumbered in DFS preorder, visiting the heavy child (the
/// child with the largest subtree) first, so that every heavy path and every
/// subtree maps to a contiguous range of ids.
#[derive(Clone, Debug)]
pub struct HLD {
    /// Number of vertices of the tree this decomposition was built from.
    _size: usize,
    /// Parent of each vertex; `None` for the root.
    par: Vec<Option<usize>>,
    /// Head (topmost vertex) of the heavy path each vertex lies on: a first
    /// (heavy) child inherits its parent's head, every other child starts a
    /// new path and is its own head.
    head: Vec<usize>,
    /// Preorder index assigned to each vertex.
    id: Vec<usize>,
    /// Inverse of `id`: `rid[i]` is the vertex whose preorder index is `i`.
    rid: Vec<usize>,
    /// Heavy child of each vertex (the child with the largest subtree), or
    /// `None` if the vertex has no children.
    next: Vec<Option<usize>>,
    /// Exclusive end of each vertex's subtree in preorder: the subtree of `v`
    /// occupies the ids `id[v]..end[v]`.
    end: Vec<usize>,
}
16
17impl HLD {
18    /// **Time complexity** $O(n)$
19    ///
20    /// **Space complexity** $O(n)$
21    pub fn new<E: TreeEdgeTrait>(tree: &Tree<E>, root: usize) -> Self {
22        let size = tree.len();
23        let mut ret = Self {
24            _size: size,
25            par: vec![None; size],
26            head: vec![0; size],
27            id: vec![0; size],
28            rid: vec![0; size],
29            next: vec![None; size],
30            end: vec![0; size],
31        };
32
33        let mut tr = vec![vec![]; size];
34        for (i, nodes) in tree.nodes.iter().enumerate() {
35            for e in nodes.neighbors() {
36                tr[i].push(e.to());
37            }
38        }
39
40        ret.dfs_sub(&mut tr, root, None, &mut vec![1; size]);
41        ret.dfs_build(&tr, root, &mut 0);
42        ret
43    }
44
45    fn dfs_sub(
46        &mut self,
47        tree: &mut [Vec<usize>],
48        cur: usize,
49        par: Option<usize>,
50        sub: &mut Vec<usize>,
51    ) {
52        self.par[cur] = par;
53        tree[cur].retain(|&x| Some(x) != par);
54
55        let mut t = 0;
56        let n = tree[cur].len();
57        for i in 0..n {
58            let to = tree[cur][i];
59            self.dfs_sub(tree, to, Some(cur), sub);
60            sub[cur] += sub[to];
61            if sub[to] > t {
62                t = sub[to];
63                self.next[cur] = Some(to);
64                tree[cur].swap(i, 0);
65            }
66        }
67    }
68
69    fn dfs_build(&mut self, tree: &[Vec<usize>], cur: usize, index: &mut usize) {
70        self.id[cur] = *index;
71        self.rid[*index] = cur;
72        *index += 1;
73
74        for (i, &to) in tree[cur].iter().enumerate() {
75            self.head[to] = if i == 0 { self.head[cur] } else { to };
76            self.dfs_build(tree, to, index);
77        }
78
79        self.end[cur] = *index;
80    }
81
82    /// 頂点`x`から頂点`y`へ向かうパス上の頂点についてのクエリを扱う。
83    ///
84    /// 演算は可換性を仮定する。
85    ///
86    /// **Time complexity** $O(\log n)$
87    pub fn path_query_vertex<F>(&self, mut x: usize, mut y: usize, mut f: F)
88    where
89        F: FnMut(usize, usize),
90    {
91        loop {
92            if self.id[x] > self.id[y] {
93                (x, y) = (y, x);
94            }
95            f(max(self.id[self.head[y]], self.id[x]), self.id[y] + 1);
96            if self.head[x] == self.head[y] {
97                break;
98            }
99            y = self.par[self.head[y]].unwrap();
100        }
101    }
102
103    /// 頂点`x`から頂点`y`へ向かうパス上の頂点についてのクエリを扱う。
104    ///
105    /// **Time complexity** $O(\log n)$
106    pub fn path_query_vertex_non_commutative<LFunc, RFunc>(
107        &self,
108        x: usize,
109        y: usize,
110        f: LFunc,
111        mut g: RFunc,
112    ) where
113        LFunc: FnMut(usize, usize),
114        RFunc: FnMut(usize, usize),
115    {
116        let w = self.lca(x, y);
117        self.path_query_vertex(x, w, f);
118
119        let (mut x, mut y) = (y, w);
120
121        loop {
122            if self.id[x] > self.id[y] {
123                (x, y) = (y, x);
124            }
125            g(
126                self.id[self.head[y]].max(self.id[x]).max(self.id[w] + 1),
127                self.id[y] + 1,
128            );
129            if self.head[x] == self.head[y] {
130                break;
131            }
132            y = self.par[self.head[y]].unwrap();
133        }
134    }
135
136    /// 頂点`x`から頂点`y`へ向かうパス上の辺についてのクエリを扱う。
137    ///
138    /// **Time complexity** $O(\log n)$
139    pub fn path_query_edge<F>(&self, mut x: usize, mut y: usize, mut f: F)
140    where
141        F: FnMut(usize, usize),
142    {
143        loop {
144            if self.id[x] > self.id[y] {
145                (x, y) = (y, x);
146            }
147            if self.head[x] == self.head[y] {
148                if x != y {
149                    f(self.id[x] + 1, self.id[y] + 1);
150                }
151                break;
152            }
153            f(self.id[self.head[y]], self.id[y] + 1);
154            y = self.par[self.head[y]].unwrap();
155        }
156    }
157
158    /// 頂点`x`の部分木の頂点についてのクエリを扱う。
159    ///
160    /// **Time complexity** $O(1)$
161    pub fn subtree_query_vertex<F>(&self, x: usize, f: F)
162    where
163        F: FnOnce(usize, usize),
164    {
165        f(self.id[x], self.end[x])
166    }
167
168    /// 頂点`x`の部分木の辺についてのクエリを扱う。
169    ///
170    /// **Time complexity** $O(1)$
171    pub fn subtree_query_edge<F>(&self, x: usize, f: F)
172    where
173        F: FnOnce(usize, usize),
174    {
175        f(self.id[x] + 1, self.end[x]);
176    }
177
178    /// **Time complexity** $O(1)$
179    pub fn parent(&self, x: usize) -> Option<usize> {
180        self.par[x]
181    }
182
183    /// **Time complexity** $O(1)$
184    pub fn get_id(&self, x: usize) -> usize {
185        self.id[x]
186    }
187
188    /// **Time complexity** $O(1)$
189    pub fn get_edge_id(&self, u: usize, v: usize) -> Option<usize> {
190        if self.par[u] == Some(v) {
191            Some(self.id[u])
192        } else if self.par[v] == Some(u) {
193            Some(self.id[v])
194        } else {
195            None
196        }
197    }
198
199    /// **Time complexity** $O(\log n)$
200    pub fn lca(&self, mut u: usize, mut v: usize) -> usize {
201        loop {
202            if self.id[u] > self.id[v] {
203                std::mem::swap(&mut u, &mut v);
204            }
205            if self.head[u] == self.head[v] {
206                return u;
207            }
208            v = self.par[self.head[v]].unwrap();
209        }
210    }
211}