1use std::cell::Cell;
6
/// A disjoint-set (union–find) forest over the elements `0..len()`.
///
/// Merging uses union by rank and lookups use path compression.
///
/// Optionally every group carries a value of type `T`; when two groups are
/// merged their values are combined with a user-supplied function (see
/// [`UnionFind::with_values`]).
pub struct UnionFind<'a, T = ()> {
    /// Number of elements.
    n: usize,
    /// Number of disjoint groups.
    count: usize,
    /// Parent pointer of each element; a root is its own parent. Kept in
    /// `Cell`s so path compression can run behind `&self`.
    parent: Vec<Cell<usize>>,
    /// Rank of each tree (starts at 1); only meaningful at roots.
    depth: Vec<usize>,
    /// Number of elements in each group; only meaningful at roots.
    size: Vec<usize>,
    /// Per-group values (held only at roots), or `None` when the forest was
    /// created by [`UnionFind::new`].
    values: Option<Vec<Option<T>>>,
    /// Combines the values of two groups being merged.
    merge: Option<Box<dyn 'a + Fn(T, T) -> T>>,
}

impl UnionFind<'_, ()> {
    /// Creates a forest of `n` singleton groups `{0}, {1}, …, {n - 1}` that
    /// carry no values.
    pub fn new(n: usize) -> Self {
        UnionFind {
            n,
            count: n,
            parent: (0..n).map(Cell::new).collect(),
            depth: vec![1; n],
            size: vec![1; n],
            values: None,
            merge: None,
        }
    }

    /// Appends a new singleton element whose index is the previous `len()`.
    ///
    /// If the forest stores values (it was built with
    /// [`UnionFind::with_values`]), the new group's value is `()`, which keeps
    /// the value storage in step with the element count.
    pub fn extend_one(&mut self) {
        self.extend_one_with_value(());
    }
}

impl<'a, T> UnionFind<'a, T> {
    /// Creates a forest of `values.len()` singleton groups where element `i`
    /// starts with the value `values[i]`.
    ///
    /// When groups are merged via [`UnionFind::merge`]`(i, j)`, the new value is
    /// `merge(value of i's group, value of j's group)`.
    pub fn with_values(values: Vec<T>, merge: Box<impl 'a + Fn(T, T) -> T>) -> Self {
        let n = values.len();
        // Unsize the caller's box in place instead of wrapping it in a second Box.
        let merge: Box<dyn 'a + Fn(T, T) -> T> = merge;
        UnionFind {
            n,
            count: n,
            parent: (0..n).map(Cell::new).collect(),
            depth: vec![1; n],
            size: vec![1; n],
            values: Some(values.into_iter().map(Option::Some).collect()),
            merge: Some(merge),
        }
    }

    /// Appends a new singleton element whose index is the previous `len()`,
    /// giving its group the value `value`.
    ///
    /// If the forest stores no values (it was built with [`UnionFind::new`],
    /// so `T` is `()`), `value` is discarded and only the element is added.
    pub fn extend_one_with_value(&mut self, value: T) {
        let k = self.n;
        self.n += 1;
        self.count += 1;
        self.parent.push(Cell::new(k));
        self.depth.push(1);
        self.size.push(1);
        if let Some(values) = self.values.as_mut() {
            values.push(Some(value));
        }
    }

    /// Returns the root (representative) of the group containing `i`.
    ///
    /// Every element visited on the way is re-pointed directly at the root
    /// (path compression).
    ///
    /// # Panics
    ///
    /// Panics if `i >= self.len()`.
    pub fn root_of(&self, i: usize) -> usize {
        let mut root = i;
        while self.parent[root].get() != root {
            root = self.parent[root].get();
        }
        // Second pass: compress the path walked above.
        let mut cur = i;
        while cur != root {
            let next = self.parent[cur].get();
            self.parent[cur].set(root);
            cur = next;
        }
        root
    }

    /// Returns `true` if `i` and `j` belong to the same group.
    ///
    /// # Panics
    ///
    /// Panics if either index is `>= self.len()`.
    pub fn is_same(&self, i: usize, j: usize) -> bool {
        self.root_of(i) == self.root_of(j)
    }

    /// Merges the groups containing `i` and `j` and returns the root of the
    /// resulting group. If they are already in the same group, nothing
    /// changes and that group's root is returned.
    ///
    /// When values are stored, the merged group's value is
    /// `merge(value of i's group, value of j's group)`.
    ///
    /// # Panics
    ///
    /// Panics if either index is `>= self.len()`.
    pub fn merge(&mut self, i: usize, j: usize) -> usize {
        let i = self.root_of(i);
        let j = self.root_of(j);

        if i == j {
            return i;
        }

        // Union by rank: attach the lower-ranked tree under the higher one.
        let (p, c) = if self.depth[i] < self.depth[j] {
            (j, i)
        } else {
            (i, j)
        };

        self.count -= 1;

        self.parent[c].set(p);
        self.size[p] += self.size[c];
        if self.depth[p] == self.depth[c] {
            self.depth[p] += 1;
        }

        if let (Some(f), Some(values)) = (self.merge.as_ref(), self.values.as_mut()) {
            let vi = values[i].take().expect("a root always holds its group's value");
            let vj = values[j].take().expect("a root always holds its group's value");
            values[p] = Some(f(vi, vj));
        }

        p
    }

    /// Returns the number of elements.
    pub fn len(&self) -> usize {
        self.n
    }

    /// Returns `true` if the forest has no elements.
    pub fn is_empty(&self) -> bool {
        self.n == 0
    }

    /// Returns the number of elements in the group containing `i`.
    ///
    /// # Panics
    ///
    /// Panics if `i >= self.len()`.
    pub fn size_of(&self, i: usize) -> usize {
        let i = self.root_of(i);
        self.size[i]
    }

    /// Returns the number of disjoint groups.
    pub fn count_groups(&self) -> usize {
        self.count
    }

    /// Returns the value of the group containing `i`, or `None` if the forest
    /// stores no values.
    ///
    /// # Panics
    ///
    /// Panics if `i >= self.len()`.
    pub fn value_of(&self, i: usize) -> Option<&T> {
        let i = self.root_of(i);
        self.values.as_ref()?[i].as_ref()
    }

    /// Returns a mutable reference to the value of the group containing `i`,
    /// or `None` if the forest stores no values.
    ///
    /// # Panics
    ///
    /// Panics if `i >= self.len()`.
    pub fn value_mut_of(&mut self, i: usize) -> Option<&mut T> {
        let i = self.root_of(i);
        self.values.as_mut()?[i].as_mut()
    }

    /// Returns every group as a list of its elements.
    ///
    /// Groups are ordered by their root's index, and each group lists its
    /// elements in ascending order.
    pub fn get_groups(&self) -> Vec<Vec<usize>> {
        let mut ret = vec![vec![]; self.n];

        for i in 0..self.n {
            ret[self.root_of(i)].push(i);
        }

        ret.into_iter().filter(|x| !x.is_empty()).collect()
    }
}
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::btreeset;
    use rand::Rng;
    use std::collections::BTreeSet;
    use std::iter::FromIterator;

    /// Performs `q` random merges on `uf` and mirrors each one in the naive
    /// set-of-sets model `a`.
    fn merge_randomly<T>(
        uf: &mut UnionFind<'_, T>,
        a: &mut BTreeSet<BTreeSet<usize>>,
        q: usize,
        rng: &mut impl Rng,
    ) {
        let n = uf.len();
        for _ in 0..q {
            let i = rng.gen_range(0..n);
            let j = rng.gen_range(0..n);

            uf.merge(i, j);

            let mut ai = a.iter().find(|s| s.contains(&i)).unwrap().clone();
            let aj = a.iter().find(|s| s.contains(&j)).unwrap().clone();

            if ai != aj {
                a.remove(&ai);
                a.remove(&aj);
                ai.extend(aj);
                a.insert(ai);
            }
        }
    }

    #[test]
    fn test() {
        let n = 100;
        let q = 50;
        let mut rng = rand::thread_rng();

        let mut uf = UnionFind::new(n);
        let mut a = (0..n).map(|i| btreeset![i]).collect::<BTreeSet<_>>();

        merge_randomly(&mut uf, &mut a, q, &mut rng);

        for _ in 0..q {
            let i = rng.gen_range(0..n);
            let j = rng.gen_range(0..n);

            let ai = a.iter().find(|s| s.contains(&i)).unwrap();

            assert_eq!(uf.is_same(i, j), ai.contains(&j));
        }

        // Group bookkeeping must agree with the model as well.
        assert_eq!(uf.len(), n);
        assert_eq!(uf.count_groups(), a.len());
        for s in &a {
            for &i in s {
                assert_eq!(uf.size_of(i), s.len());
            }
        }

        assert_eq!(
            BTreeSet::from_iter(
                uf.get_groups()
                    .into_iter()
                    .map(|s| BTreeSet::from_iter(s.into_iter()))
            ),
            a
        );
    }

    #[test]
    fn test_with_values() {
        let n: usize = 100;
        let q = 50;
        let mut rng = rand::thread_rng();

        // Each element starts with its own index and merging sums the values,
        // so every group's value must equal the sum of its members.
        let mut uf =
            UnionFind::with_values((0..n).collect(), Box::new(|x: usize, y: usize| x + y));
        let mut a = (0..n).map(|i| btreeset![i]).collect::<BTreeSet<_>>();

        merge_randomly(&mut uf, &mut a, q, &mut rng);

        assert_eq!(uf.count_groups(), a.len());
        for s in &a {
            let sum: usize = s.iter().sum();
            for &i in s {
                assert_eq!(uf.value_of(i), Some(&sum));
            }
        }
    }
}