haar_lib/ds/
unionfind.rs

1//! 素集合データ構造
2//!
3//! # Problems
4//! - <https://atcoder.jp/contests/abc372/tasks/abc372_e>
5use std::cell::Cell;
6
7/// 素集合の統合と所属の判定ができるデータ構造。
8pub struct UnionFind<'a, T = ()> {
9    n: usize,
10    count: usize,
11    parent: Vec<Cell<usize>>,
12    depth: Vec<usize>,
13    size: Vec<usize>,
14    values: Option<Vec<Option<T>>>,
15    merge: Option<Box<dyn 'a + Fn(T, T) -> T>>,
16}
17
18impl UnionFind<'_, ()> {
19    /// 大きさ`1`の集合を`n`個用意する。
20    pub fn new(n: usize) -> Self {
21        UnionFind {
22            n,
23            count: n,
24            parent: (0..n).map(Cell::new).collect(),
25            depth: vec![1; n],
26            size: vec![1; n],
27            values: None,
28            merge: None,
29        }
30    }
31
32    /// 大きさが`1`の集合を1つ追加する。
33    pub fn extend_one(&mut self) {
34        let k = self.n;
35        self.n += 1;
36        self.count += 1;
37        self.parent.push(Cell::new(k));
38        self.depth.push(1);
39        self.size.push(1);
40    }
41}
42
43impl<'a, T> UnionFind<'a, T> {
44    /// 大きさ`1`の集合を`|values|`個用意する。このとき、各集合`i`に`value[i]`を割り当てる。
45    ///
46    /// `merge`は、集合を統合する際に、新しい集合に割り当てる値を返す。
47    pub fn with_values(values: Vec<T>, merge: Box<impl 'a + Fn(T, T) -> T>) -> Self {
48        let n = values.len();
49        UnionFind {
50            n,
51            count: n,
52            parent: (0..n).map(Cell::new).collect(),
53            depth: vec![1; n],
54            size: vec![1; n],
55            values: Some(values.into_iter().map(Option::Some).collect()),
56            merge: Some(Box::new(merge)),
57        }
58    }
59
60    /// 値`value`を割り当てられた、大きさが`1`の集合を1つ追加する。
61    pub fn extend_one_with_value(&mut self, value: T) {
62        let k = self.n;
63        self.n += 1;
64        self.count += 1;
65        self.parent.push(Cell::new(k));
66        self.depth.push(1);
67        self.size.push(1);
68        self.values.as_mut().unwrap().push(Some(value));
69    }
70
71    /// `i`の属する集合の根を返す。
72    pub fn root_of(&self, i: usize) -> usize {
73        if self.parent[i].get() == i {
74            return i;
75        }
76        let p = self.parent[i].get();
77        self.parent[i].set(self.root_of(p));
78        self.parent[i].get()
79    }
80
81    /// `i`と`j`が同じ集合に属するならば`true`を返す。
82    pub fn is_same(&self, i: usize, j: usize) -> bool {
83        self.root_of(i) == self.root_of(j)
84    }
85
86    /// `i`の属する集合と`j`の属する集合を統合する。
87    pub fn merge(&mut self, i: usize, j: usize) -> usize {
88        let i = self.root_of(i);
89        let j = self.root_of(j);
90
91        if i == j {
92            return i;
93        }
94
95        let (p, c) = if self.depth[i] < self.depth[j] {
96            (j, i)
97        } else {
98            (i, j)
99        };
100
101        self.count -= 1;
102
103        self.parent[c].set(p);
104        self.size[p] += self.size[c];
105        if self.depth[p] == self.depth[c] {
106            self.depth[p] += 1;
107        }
108
109        if let Some(f) = self.merge.as_ref() {
110            let t = f(
111                self.values.as_mut().unwrap()[i].take().unwrap(),
112                self.values.as_mut().unwrap()[j].take().unwrap(),
113            );
114            self.values.as_mut().unwrap()[p] = Some(t);
115        }
116
117        p
118    }
119
120    /// UnionFindの要素数を返す。
121    pub fn len(&self) -> usize {
122        self.n
123    }
124
125    /// UnionFindが要素を持たないとき、`true`を返す。
126    pub fn is_empty(&self) -> bool {
127        self.n == 0
128    }
129
130    /// `i`の属する集合の大きさを返す。
131    pub fn size_of(&self, i: usize) -> usize {
132        let i = self.root_of(i);
133        self.size[i]
134    }
135
136    /// 素集合の個数を返す。
137    pub fn count_groups(&self) -> usize {
138        self.count
139    }
140
141    /// `i`の属する集合のもつ値を返す。
142    pub fn value_of(&self, i: usize) -> Option<&T> {
143        let i = self.root_of(i);
144        self.values.as_ref()?[i].as_ref()
145    }
146
147    /// `i`の属する集合のもつ値への可変参照を返す。
148    pub fn value_mut_of(&mut self, i: usize) -> Option<&mut T> {
149        let i = self.root_of(i);
150        self.values.as_mut()?[i].as_mut()
151    }
152
153    /// 素集合をすべて列挙する。
154    pub fn get_groups(&self) -> Vec<Vec<usize>> {
155        let mut ret = vec![vec![]; self.n];
156
157        for i in 0..self.n {
158            ret[self.root_of(i)].push(i);
159        }
160
161        ret.into_iter().filter(|x| !x.is_empty()).collect()
162    }
163}
164
165#[cfg(test)]
166mod tests {
167    use super::*;
168    use crate::btreeset;
169    use rand::Rng;
170    use std::collections::BTreeSet;
171    use std::iter::FromIterator;
172
173    #[test]
174    fn test() {
175        let n = 100;
176        let q = 50;
177        let mut rng = rand::thread_rng();
178
179        let mut uf = UnionFind::new(n);
180        let mut a = (0..n).map(|i| btreeset![i]).collect::<BTreeSet<_>>();
181
182        for _ in 0..q {
183            let i = rng.gen_range(0..n);
184            let j = rng.gen_range(0..n);
185
186            uf.merge(i, j);
187
188            let mut ai = a.iter().find(|s| s.contains(&i)).unwrap().clone();
189            let aj = a.iter().find(|s| s.contains(&j)).unwrap().clone();
190
191            if ai != aj {
192                a.remove(&ai);
193                a.remove(&aj);
194                ai.extend(aj);
195                a.insert(ai);
196            }
197        }
198
199        for _ in 0..q {
200            let i = rng.gen_range(0..n);
201            let j = rng.gen_range(0..n);
202
203            let ai = a.iter().find(|s| s.contains(&i)).unwrap();
204
205            assert_eq!(uf.is_same(i, j), ai.contains(&j));
206        }
207
208        assert_eq!(
209            BTreeSet::from_iter(
210                uf.get_groups()
211                    .into_iter()
212                    .map(|s| BTreeSet::from_iter(s.into_iter()))
213            ),
214            a
215        );
216    }
217}