haar_lib/math/convolution/mod.rs

//! Convolutions.

// Transforms. `zeta` and `mobius` provide subset/superset zeta and Mobius
// transforms (e.g. `fast_zeta_subset`, `fast_mobius_superset`);
// `div_mul_transform` is presumably the divisor/multiple analogue — TODO confirm.
pub mod div_mul_transform;
pub mod mobius;
pub mod zeta;

// Convolutions; each module name indicates the index operation that combines
// the two input indices (bitwise AND/OR/XOR, disjoint-subset, gcd, lcm,
// and `conv_mul_modp` — presumably multiplication mod p, verify).
pub mod conv_and;
pub mod conv_gcd;
pub mod conv_lcm;
pub mod conv_mul_modp;
pub mod conv_or;
pub mod conv_xor;
pub mod subset_conv;
13
#[cfg(test)]
mod tests {
    //! Randomized checks of the fast transforms and convolutions in this module
    //! against naive O(n^2) reference implementations, computed over modular
    //! integers built by `ConstModIntBuilder::<M>`.

    use crate::math::gcd_lcm::GcdLcm;
    use crate::{iter::collect::CollectVec, num::const_modint::*};
    use rand::Rng;

    use super::conv_and::convolution_and;
    use super::conv_gcd::convolution_gcd;
    use super::conv_or::convolution_or;
    use super::conv_xor::convolution_xor;
    use super::mobius::*;
    use super::subset_conv::subset_convolution;
    use super::zeta::*;

    /// Modulus of the modular-integer type used by every test.
    const M: u32 = 998244353;

    /// Returns `true` when the bitmask `a` is a subset of the bitmask `b`.
    fn is_subset_of(a: usize, b: usize) -> bool {
        (a | b) == b
    }

    /// Generates `len` uniformly random residues in `0..M`, each converted to the
    /// element type by `from_u64` (e.g. `|x| ff.from_u64(x)`).
    fn random_vec<T>(rng: &mut impl Rng, len: usize, from_u64: impl Fn(u64) -> T) -> Vec<T> {
        std::iter::repeat_with(|| from_u64(u64::from(rng.gen_range(0..M))))
            .take(len)
            .collect_vec()
    }

    #[test]
    fn test_zeta_mobius() {
        // `i` is used both as an index into `ans` and as a bitmask in the
        // reference sums, so plain index loops read more clearly than the
        // enumerate-based rewrite clippy would suggest.
        #![allow(clippy::needless_range_loop)]
        let mut rng = rand::thread_rng();
        let ff = ConstModIntBuilder::<M>;

        let n = 1 << 10;
        let f = random_vec(&mut rng, n, |x| ff.from_u64(x));

        // Subset zeta reference: ans[i] = sum of f[j] over all masks j ⊆ i.
        let mut ans = vec![ff.from_u64(0); n];
        for i in 0..n {
            for j in 0..n {
                if is_subset_of(j, i) {
                    ans[i] += f[j];
                }
            }
        }

        let mut res = f.clone();
        fast_zeta_subset(&mut res);
        assert_eq!(ans, res);

        // The subset Mobius transform must undo the subset zeta transform.
        fast_mobius_subset(&mut res);
        assert_eq!(f, res);

        // Superset zeta reference: ans[i] = sum of f[j] over all masks j ⊇ i.
        let mut ans = vec![ff.from_u64(0); n];
        for i in 0..n {
            for j in 0..n {
                if is_subset_of(i, j) {
                    ans[i] += f[j];
                }
            }
        }

        let mut res = f.clone();
        fast_zeta_superset(&mut res);
        assert_eq!(ans, res);

        // The superset Mobius transform must undo the superset zeta transform.
        fast_mobius_superset(&mut res);
        assert_eq!(f, res);
    }

    #[test]
    fn test_conv_or() {
        let mut rng = rand::thread_rng();
        let ff = ConstModIntBuilder::<M>;

        let n = 1 << 10;
        let f = random_vec(&mut rng, n, |x| ff.from_u64(x));
        let g = random_vec(&mut rng, n, |x| ff.from_u64(x));

        // Reference: ans[k] = sum of f[i] * g[j] over all (i, j) with i | j == k.
        let mut ans = vec![ff.from_u64(0); n];
        for i in 0..n {
            for j in 0..n {
                ans[i | j] += f[i] * g[j];
            }
        }

        let res = convolution_or(f, g);

        assert_eq!(ans, res);
    }

    #[test]
    fn test_conv_and() {
        let mut rng = rand::thread_rng();
        let ff = ConstModIntBuilder::<M>;

        let n = 1 << 10;
        let f = random_vec(&mut rng, n, |x| ff.from_u64(x));
        let g = random_vec(&mut rng, n, |x| ff.from_u64(x));

        // Reference: ans[k] = sum of f[i] * g[j] over all (i, j) with i & j == k.
        let mut ans = vec![ff.from_u64(0); n];
        for i in 0..n {
            for j in 0..n {
                ans[i & j] += f[i] * g[j];
            }
        }

        let res = convolution_and(f, g);

        assert_eq!(ans, res);
    }

    #[test]
    fn test_conv_xor() {
        let mut rng = rand::thread_rng();
        let ff = ConstModIntBuilder::<M>;

        let n = 1 << 10;
        let f = random_vec(&mut rng, n, |x| ff.from_u64(x));
        let g = random_vec(&mut rng, n, |x| ff.from_u64(x));

        // Reference: ans[k] = sum of f[i] * g[j] over all (i, j) with i ^ j == k.
        let mut ans = vec![ff.from_u64(0); n];
        for i in 0..n {
            for j in 0..n {
                ans[i ^ j] += f[i] * g[j];
            }
        }

        let res = convolution_xor(f, g, ff);

        assert_eq!(ans, res);
    }

    #[test]
    fn test_conv_subset() {
        let mut rng = rand::thread_rng();
        let ff = ConstModIntBuilder::<M>;

        let n = 1 << 10;
        let f = random_vec(&mut rng, n, |x| ff.from_u64(x));
        let g = random_vec(&mut rng, n, |x| ff.from_u64(x));

        // Reference: ans[k] = sum of f[i] * g[j] over all disjoint masks
        // (i & j == 0) whose union i | j == k.
        let mut ans = vec![ff.from_u64(0); n];
        for i in 0..n {
            for j in 0..n {
                if (i & j) == 0 {
                    ans[i | j] += f[i] * g[j];
                }
            }
        }

        let res = subset_convolution(f, g);

        assert_eq!(ans, res);
    }

    #[test]
    fn test_conv_gcd() {
        let mut rng = rand::thread_rng();
        let ff = ConstModIntBuilder::<M>;

        let n = 1000;
        let f = random_vec(&mut rng, n + 1, |x| ff.from_u64(x));
        let g = random_vec(&mut rng, n + 1, |x| ff.from_u64(x));

        // Reference: ans[k] = sum of f[i] * g[j] over 1 <= i, j <= n with gcd(i, j) == k.
        let mut ans = vec![ff.from_u64(0); n + 1];
        for i in 1..=n {
            for j in 1..=n {
                ans[i.gcd(j)] += f[i] * g[j];
            }
        }

        let res = convolution_gcd(f, g);

        // Index 0 is excluded: the reference only sums over positive indices,
        // so ans[0] stays zero, and the value convolution_gcd puts at index 0
        // is not part of what this test checks.
        assert_eq!(&ans[1..], &res[1..]);
    }
}
208        assert_eq!(ans[1..], res[1..]);
209    }
210}