haar_lib/math/convolution/
mod.rs

1//! 畳み込み
2pub mod mobius;
3pub mod zeta;
4
5pub mod conv_and;
6pub mod conv_gcd;
7pub mod conv_mul_modp;
8pub mod conv_or;
9pub mod conv_xor;
10pub mod subset_conv;
11
#[cfg(test)]
mod tests {
    use crate::math::gcd_lcm::GcdLcm;
    use crate::{iter::collect::CollectVec, num::const_modint::*};
    use rand::Rng;
    use std::ops::{AddAssign, Mul};

    use super::conv_and::convolution_and;
    use super::conv_gcd::convolution_gcd;
    use super::conv_or::convolution_or;
    use super::conv_xor::convolution_xor;
    use super::mobius::*;
    use super::subset_conv::subset_convolution;
    use super::zeta::*;

    const M: u32 = 998244353;

    /// Returns `true` when the bit set `a` is a subset of the bit set `b`.
    fn is_subset_of(a: usize, b: usize) -> bool {
        a | b == b
    }

    /// Draws `len` integers uniformly from `0..M` and converts each one with `lift`
    /// (in these tests, `ConstModIntBuilder::from_u64`).
    fn random_vec<T>(rng: &mut impl Rng, len: usize, lift: impl Fn(u64) -> T) -> Vec<T> {
        std::iter::repeat_with(|| lift(rng.gen_range(0..M) as u64))
            .take(len)
            .collect_vec()
    }

    /// Brute-force reference for a pairwise convolution.
    ///
    /// For every index pair `(i, j)` for which `target(i, j)` returns `Some(k)`,
    /// `f[i] * g[j]` is added to `h[k]`. Pairs mapped to `None` are skipped.
    /// The result has length `f.len()`, so every produced `k` must be `< f.len()`
    /// (otherwise this panics on the out-of-bounds index).
    fn naive_convolution<T>(
        f: &[T],
        g: &[T],
        zero: T,
        target: impl Fn(usize, usize) -> Option<usize>,
    ) -> Vec<T>
    where
        T: Copy + AddAssign + Mul<Output = T>,
    {
        let mut h = vec![zero; f.len()];
        for (i, &a) in f.iter().enumerate() {
            for (j, &b) in g.iter().enumerate() {
                if let Some(k) = target(i, j) {
                    h[k] += a * b;
                }
            }
        }
        h
    }

    /// Brute-force reference for a sum transform: `h[i]` is the sum of every
    /// `f[j]` for which `include(i, j)` holds. Runs in O(len²).
    fn naive_transform<T>(f: &[T], zero: T, include: impl Fn(usize, usize) -> bool) -> Vec<T>
    where
        T: Copy + AddAssign,
    {
        (0..f.len())
            .map(|i| {
                let mut acc = zero;
                for (j, &x) in f.iter().enumerate() {
                    if include(i, j) {
                        acc += x;
                    }
                }
                acc
            })
            .collect()
    }

    #[test]
    fn test_zeta_mobius() {
        let mut rng = rand::thread_rng();
        let ff = ConstModIntBuilder::<M>;
        let zero = ff.from_u64(0);

        let n = 1 << 10;
        let f = random_vec(&mut rng, n, |x| ff.from_u64(x));

        // Sum over subsets: h[i] = Σ_{j ⊆ i} f[j]; the möbius transform must undo it.
        let ans = naive_transform(&f, zero, |i, j| is_subset_of(j, i));
        let mut res = f.clone();
        fast_zeta_subset(&mut res);
        assert_eq!(ans, res);
        fast_mobius_subset(&mut res);
        assert_eq!(f, res);

        // Sum over supersets: h[i] = Σ_{j ⊇ i} f[j]; the möbius transform must undo it.
        let ans = naive_transform(&f, zero, |i, j| is_subset_of(i, j));
        let mut res = f.clone();
        fast_zeta_superset(&mut res);
        assert_eq!(ans, res);
        fast_mobius_superset(&mut res);
        assert_eq!(f, res);
    }

    #[test]
    fn test_conv_or() {
        let mut rng = rand::thread_rng();
        let ff = ConstModIntBuilder::<M>;

        let n = 1 << 10;
        let f = random_vec(&mut rng, n, |x| ff.from_u64(x));
        let g = random_vec(&mut rng, n, |x| ff.from_u64(x));

        let ans = naive_convolution(&f, &g, ff.from_u64(0), |i, j| Some(i | j));
        let res = convolution_or(f, g);

        assert_eq!(ans, res);
    }

    #[test]
    fn test_conv_and() {
        let mut rng = rand::thread_rng();
        let ff = ConstModIntBuilder::<M>;

        let n = 1 << 10;
        let f = random_vec(&mut rng, n, |x| ff.from_u64(x));
        let g = random_vec(&mut rng, n, |x| ff.from_u64(x));

        let ans = naive_convolution(&f, &g, ff.from_u64(0), |i, j| Some(i & j));
        let res = convolution_and(f, g);

        assert_eq!(ans, res);
    }

    #[test]
    fn test_conv_xor() {
        let mut rng = rand::thread_rng();
        let ff = ConstModIntBuilder::<M>;

        let n = 1 << 10;
        let f = random_vec(&mut rng, n, |x| ff.from_u64(x));
        let g = random_vec(&mut rng, n, |x| ff.from_u64(x));

        let ans = naive_convolution(&f, &g, ff.from_u64(0), |i, j| Some(i ^ j));
        let res = convolution_xor(f, g, ff);

        assert_eq!(ans, res);
    }

    #[test]
    fn test_conv_subset() {
        let mut rng = rand::thread_rng();
        let ff = ConstModIntBuilder::<M>;

        let n = 1 << 10;
        let f = random_vec(&mut rng, n, |x| ff.from_u64(x));
        let g = random_vec(&mut rng, n, |x| ff.from_u64(x));

        // Only disjoint pairs (i ∩ j = ∅) contribute, landing at their union.
        let ans = naive_convolution(&f, &g, ff.from_u64(0), |i, j| {
            (i & j == 0).then_some(i | j)
        });
        let res = subset_convolution(f, g);

        assert_eq!(ans, res);
    }

    #[test]
    fn test_conv_gcd() {
        let mut rng = rand::thread_rng();
        let ff = ConstModIntBuilder::<M>;

        let n = 1000;
        let f = random_vec(&mut rng, n + 1, |x| ff.from_u64(x));
        let g = random_vec(&mut rng, n + 1, |x| ff.from_u64(x));

        // Only indices 1..=n take part: pairs involving index 0 are skipped here,
        // and index 0 of the result is excluded from the comparison below.
        let ans = naive_convolution(&f, &g, ff.from_u64(0), |i, j| {
            (i != 0 && j != 0).then(|| i.gcd(j))
        });
        let res = convolution_gcd(f, g);

        assert_eq!(ans[1..], res[1..]);
    }
}