haar_lib/ds/
unionfind.rs

1//! 素集合データ構造
2//!
3//! # Problems
4//! - <https://atcoder.jp/contests/abc372/tasks/abc372_e>
5use std::cell::Cell;
6
7/// 素集合の統合と所属の判定ができるデータ構造。
8pub struct UnionFind<'a, T = ()> {
9    n: usize,
10    count: usize,
11    parent: Vec<Cell<usize>>,
12    depth: Vec<usize>,
13    size: Vec<usize>,
14    values: Option<Vec<Option<T>>>,
15    merge: Option<Box<dyn 'a + Fn(T, T) -> T>>,
16}
17
18impl UnionFind<'_, ()> {
19    /// 大きさ`1`の集合を`n`個用意する。
20    pub fn new(n: usize) -> Self {
21        UnionFind {
22            n,
23            count: n,
24            parent: (0..n).map(Cell::new).collect(),
25            depth: vec![1; n],
26            size: vec![1; n],
27            values: None,
28            merge: None,
29        }
30    }
31}
32
33impl<'a, T> UnionFind<'a, T> {
34    /// 大きさ`1`の集合を`|values|`個用意する。このとき、各集合`i`に`value[i]`を割り当てる。
35    ///
36    /// `merge`は、集合を統合する際に、新しい集合に割り当てる値を返す。
37    pub fn with_values(values: Vec<T>, merge: Box<impl 'a + Fn(T, T) -> T>) -> Self {
38        let n = values.len();
39        UnionFind {
40            n,
41            count: n,
42            parent: (0..n).map(Cell::new).collect(),
43            depth: vec![1; n],
44            size: vec![1; n],
45            values: Some(values.into_iter().map(Option::Some).collect()),
46            merge: Some(Box::new(merge)),
47        }
48    }
49
50    /// `i`の属する集合の根を返す。
51    pub fn root_of(&self, i: usize) -> usize {
52        if self.parent[i].get() == i {
53            return i;
54        }
55        let p = self.parent[i].get();
56        self.parent[i].set(self.root_of(p));
57        self.parent[i].get()
58    }
59
60    /// `i`と`j`が同じ集合に属するならば`true`を返す。
61    pub fn is_same(&self, i: usize, j: usize) -> bool {
62        self.root_of(i) == self.root_of(j)
63    }
64
65    /// `i`の属する集合と`j`の属する集合を統合する。
66    pub fn merge(&mut self, i: usize, j: usize) -> usize {
67        let i = self.root_of(i);
68        let j = self.root_of(j);
69
70        if i == j {
71            return i;
72        }
73
74        let (p, c) = if self.depth[i] < self.depth[j] {
75            (j, i)
76        } else {
77            (i, j)
78        };
79
80        self.count -= 1;
81
82        self.parent[c].set(p);
83        self.size[p] += self.size[c];
84        if self.depth[p] == self.depth[c] {
85            self.depth[p] += 1;
86        }
87
88        if let Some(f) = self.merge.as_ref() {
89            let t = f(
90                self.values.as_mut().unwrap()[p].take().unwrap(),
91                self.values.as_mut().unwrap()[c].take().unwrap(),
92            );
93            self.values.as_mut().unwrap()[p] = Some(t);
94        }
95
96        p
97    }
98
99    /// `i`の属する集合の大きさを返す。
100    pub fn size_of(&self, i: usize) -> usize {
101        let i = self.root_of(i);
102        self.size[i]
103    }
104
105    /// 素集合の個数を返す。
106    pub fn count_groups(&self) -> usize {
107        self.count
108    }
109
110    /// `i`の属する集合のもつ値を返す。
111    pub fn value_of(&self, i: usize) -> Option<&T> {
112        let i = self.root_of(i);
113        self.values.as_ref()?[i].as_ref()
114    }
115
116    /// 素集合をすべて列挙する。
117    pub fn get_groups(&self) -> Vec<Vec<usize>> {
118        let mut ret = vec![vec![]; self.n];
119
120        for i in 0..self.n {
121            ret[self.root_of(i)].push(i);
122        }
123
124        ret.into_iter().filter(|x| !x.is_empty()).collect()
125    }
126}
127
128#[cfg(test)]
129mod tests {
130    use super::*;
131    use crate::btreeset;
132    use rand::Rng;
133    use std::collections::BTreeSet;
134    use std::iter::FromIterator;
135
136    #[test]
137    fn test() {
138        let n = 100;
139        let q = 50;
140        let mut rng = rand::thread_rng();
141
142        let mut uf = UnionFind::new(n);
143        let mut a = (0..n).map(|i| btreeset![i]).collect::<BTreeSet<_>>();
144
145        for _ in 0..q {
146            let i = rng.gen_range(0..n);
147            let j = rng.gen_range(0..n);
148
149            uf.merge(i, j);
150
151            let mut ai = a.iter().find(|s| s.contains(&i)).unwrap().clone();
152            let aj = a.iter().find(|s| s.contains(&j)).unwrap().clone();
153
154            if ai != aj {
155                a.remove(&ai);
156                a.remove(&aj);
157                ai.extend(aj);
158                a.insert(ai);
159            }
160        }
161
162        for _ in 0..q {
163            let i = rng.gen_range(0..n);
164            let j = rng.gen_range(0..n);
165
166            let ai = a.iter().find(|s| s.contains(&i)).unwrap();
167
168            assert_eq!(uf.is_same(i, j), ai.contains(&j));
169        }
170
171        assert_eq!(
172            BTreeSet::from_iter(
173                uf.get_groups()
174                    .into_iter()
175                    .map(|s| BTreeSet::from_iter(s.into_iter()))
176            ),
177            a
178        );
179    }
180}