haar_lib/math/convolution/mod.rs

1//! 畳み込み
2pub mod div_mul_transform;
3pub mod mobius;
4pub mod zeta;
5
6pub mod conv_and_or;
7pub mod conv_gcd_lcm;
8pub mod conv_mul_mod2n;
9pub mod conv_mul_modp;
10pub mod conv_xor;
11pub mod min_plus_conv_convex;
12pub mod subset_conv;
13
14pub mod ntt;
15
#[cfg(test)]
mod tests {
    //! Brute-force cross-checks for the transforms and convolutions in this module.
    //!
    //! Each test draws random inputs modulo 998244353 and computes the expected
    //! answer with a direct O(n^2) double loop. It then compares that answer
    //! with the fast implementation.

    use crate::math::gcd_lcm::GcdLcm;
    use crate::math::prime_mod::{Prime, PrimeMod};
    use crate::{iter::collect::CollectVec, num::const_modint::*};
    use rand::{rngs::StdRng, Rng, SeedableRng};

    use super::conv_and_or::{convolution_and, convolution_or};
    use super::conv_gcd_lcm::convolution_gcd;
    use super::conv_xor::convolution_xor;
    use super::mobius::*;
    use super::subset_conv::subset_convolution;
    use super::zeta::*;

    type M = Prime<998244353>;

    /// Fixed RNG seed, so that a failing case can be reproduced exactly.
    /// Change it locally to explore other random inputs.
    const SEED: u64 = 0x5EED_C0DE;

    /// Returns a deterministic RNG seeded with [`SEED`].
    fn test_rng() -> StdRng {
        StdRng::seed_from_u64(SEED)
    }

    /// Draws `n` values uniformly from `0..M::PRIME_NUM` and maps each one
    /// through `lift` (typically `|x| ff.from_u64(x)`).
    ///
    /// Taking `lift` as a closure keeps this helper independent of the concrete
    /// modular-integer type produced by the builder.
    fn random_vec<T>(rng: &mut impl Rng, n: usize, lift: impl Fn(u64) -> T) -> Vec<T> {
        std::iter::repeat_with(|| lift(rng.gen_range(0..M::PRIME_NUM) as u64))
            .take(n)
            .collect_vec()
    }

    /// Returns `true` when every bit set in `a` is also set in `b` (`a ⊆ b` as bit sets).
    fn is_subset_of(a: usize, b: usize) -> bool {
        a | b == b
    }

    #[test]
    fn test_zeta_mobius() {
        // The index loops spell out the defining sums literally, which reads
        // more clearly here than an iterator chain would.
        #![allow(clippy::needless_range_loop)]
        let mut rng = test_rng();
        let ff = ConstModIntBuilder::<M>::new();

        let n = 1 << 10;
        let f = random_vec(&mut rng, n, |x| ff.from_u64(x));

        // Subset zeta reference: ans[i] = sum of f[j] over all j ⊆ i.
        let mut ans = vec![ff.from_u64(0); n];
        for i in 0..n {
            for j in 0..n {
                if is_subset_of(j, i) {
                    ans[i] += f[j];
                }
            }
        }

        let mut res = f.clone();
        fast_zeta_subset(&mut res);
        assert_eq!(ans, res);

        // Möbius must undo zeta and restore the original input.
        fast_mobius_subset(&mut res);
        assert_eq!(f, res);

        // Superset zeta reference: ans[i] = sum of f[j] over all j ⊇ i.
        let mut ans = vec![ff.from_u64(0); n];
        for i in 0..n {
            for j in 0..n {
                if is_subset_of(i, j) {
                    ans[i] += f[j];
                }
            }
        }

        let mut res = f.clone();
        fast_zeta_superset(&mut res);
        assert_eq!(ans, res);

        fast_mobius_superset(&mut res);
        assert_eq!(f, res);
    }

    #[test]
    fn test_conv_or() {
        let mut rng = test_rng();
        let ff = ConstModIntBuilder::<M>::new();

        let n = 1 << 10;
        let f = random_vec(&mut rng, n, |x| ff.from_u64(x));
        let g = random_vec(&mut rng, n, |x| ff.from_u64(x));

        // Reference: every pair (i, j) contributes f[i] * g[j] to index i | j.
        let mut ans = vec![ff.from_u64(0); n];
        for i in 0..n {
            for j in 0..n {
                ans[i | j] += f[i] * g[j];
            }
        }

        let res = convolution_or(f, g);

        assert_eq!(ans, res);
    }

    #[test]
    fn test_conv_and() {
        let mut rng = test_rng();
        let ff = ConstModIntBuilder::<M>::new();

        let n = 1 << 10;
        let f = random_vec(&mut rng, n, |x| ff.from_u64(x));
        let g = random_vec(&mut rng, n, |x| ff.from_u64(x));

        // Reference: every pair (i, j) contributes f[i] * g[j] to index i & j.
        let mut ans = vec![ff.from_u64(0); n];
        for i in 0..n {
            for j in 0..n {
                ans[i & j] += f[i] * g[j];
            }
        }

        let res = convolution_and(f, g);

        assert_eq!(ans, res);
    }

    #[test]
    fn test_conv_xor() {
        let mut rng = test_rng();
        let ff = ConstModIntBuilder::<M>::new();

        let n = 1 << 10;
        let f = random_vec(&mut rng, n, |x| ff.from_u64(x));
        let g = random_vec(&mut rng, n, |x| ff.from_u64(x));

        // Reference: every pair (i, j) contributes f[i] * g[j] to index i ^ j.
        let mut ans = vec![ff.from_u64(0); n];
        for i in 0..n {
            for j in 0..n {
                ans[i ^ j] += f[i] * g[j];
            }
        }

        let res = convolution_xor(f, g, ff);

        assert_eq!(ans, res);
    }

    #[test]
    fn test_conv_subset() {
        let mut rng = test_rng();
        let ff = ConstModIntBuilder::<M>::new();

        let n = 1 << 10;
        let f = random_vec(&mut rng, n, |x| ff.from_u64(x));
        let g = random_vec(&mut rng, n, |x| ff.from_u64(x));

        // Reference: only disjoint pairs (i & j == 0) contribute, at index i | j.
        let mut ans = vec![ff.from_u64(0); n];
        for i in 0..n {
            for j in 0..n {
                if i & j == 0 {
                    ans[i | j] += f[i] * g[j];
                }
            }
        }

        let res = subset_convolution(f, g);

        assert_eq!(ans, res);
    }

    #[test]
    fn test_conv_gcd() {
        let mut rng = test_rng();
        let ff = ConstModIntBuilder::<M>::new();

        let n = 1000;
        let f = random_vec(&mut rng, n + 1, |x| ff.from_u64(x));
        let g = random_vec(&mut rng, n + 1, |x| ff.from_u64(x));

        // Reference over indices 1..=n: pair (i, j) contributes to index gcd(i, j).
        let mut ans = vec![ff.from_u64(0); n + 1];
        for i in 1..=n {
            for j in 1..=n {
                ans[i.gcd(j)] += f[i] * g[j];
            }
        }

        let res = convolution_gcd(f, g);

        // Index 0 is skipped by the reference loops, so it is not compared.
        assert_eq!(ans[1..], res[1..]);
    }
}