haar_lib/math/convolution/
mod.rs1pub mod mobius;
3pub mod zeta;
4
5pub mod conv_and;
6pub mod conv_gcd;
7pub mod conv_mul_modp;
8pub mod conv_or;
9pub mod conv_xor;
10pub mod subset_conv;
11
#[cfg(test)]
mod tests {
    // Randomized tests that compare each fast transform / convolution in this
    // module against a direct brute-force evaluation over integers modulo `M`.

    use crate::math::gcd_lcm::GcdLcm;
    use crate::{iter::collect::CollectVec, num::const_modint::*};
    use rand::Rng;

    use super::conv_and::convolution_and;
    use super::conv_gcd::convolution_gcd;
    use super::conv_or::convolution_or;
    use super::conv_xor::convolution_xor;
    use super::mobius::*;
    use super::subset_conv::subset_convolution;
    use super::zeta::*;

    /// Modulus used by every test in this module.
    const M: u32 = 998244353;

    /// Returns `true` iff every bit set in `a` is also set in `b`.
    ///
    /// In Rust `|` binds tighter than `==`, so this parses as `(a | b) == b`.
    fn is_subset_of(a: usize, b: usize) -> bool {
        a | b == b
    }

    /// Draws `n` random values from `0..M` and converts each one with `from_u64`.
    ///
    /// Taking the conversion as a closure (e.g. `|x| ff.from_u64(x)`) keeps this
    /// helper independent of the concrete modular-integer type.
    fn random_vec<R: Rng, T>(rng: &mut R, n: usize, from_u64: impl Fn(u64) -> T) -> Vec<T> {
        std::iter::repeat_with(|| from_u64(u64::from(rng.gen_range(0..M))))
            .take(n)
            .collect_vec()
    }

    #[test]
    // The brute force indexes `ans` by `i` and `f` by `j` while also testing the
    // pair `(i, j)`, which reads more clearly as plain range loops than as iterators.
    #[allow(clippy::needless_range_loop)]
    fn test_zeta_mobius() {
        let mut rng = rand::thread_rng();
        let ff = ConstModIntBuilder::<M>;

        let n = 1 << 10;
        let f = random_vec(&mut rng, n, |x| ff.from_u64(x));

        // Subset zeta: ans[i] = sum of f[j] over all j whose bits are a subset of i.
        let mut ans = vec![ff.from_u64(0); n];
        for i in 0..n {
            for j in 0..n {
                if is_subset_of(j, i) {
                    ans[i] += f[j];
                }
            }
        }

        let mut res = f.clone();
        fast_zeta_subset(&mut res);
        assert_eq!(ans, res);

        // The subset Möbius transform must undo the subset zeta transform.
        fast_mobius_subset(&mut res);
        assert_eq!(f, res);

        // Superset zeta: ans[i] = sum of f[j] over all j whose bits are a superset of i.
        let mut ans = vec![ff.from_u64(0); n];
        for i in 0..n {
            for j in 0..n {
                if is_subset_of(i, j) {
                    ans[i] += f[j];
                }
            }
        }

        let mut res = f.clone();
        fast_zeta_superset(&mut res);
        assert_eq!(ans, res);

        // The superset Möbius transform must undo the superset zeta transform.
        fast_mobius_superset(&mut res);
        assert_eq!(f, res);
    }

    #[test]
    fn test_conv_or() {
        let mut rng = rand::thread_rng();
        let ff = ConstModIntBuilder::<M>;

        let n = 1 << 10;
        let f = random_vec(&mut rng, n, |x| ff.from_u64(x));
        let g = random_vec(&mut rng, n, |x| ff.from_u64(x));

        // Brute force: every pair (i, j) contributes f[i] * g[j] to ans[i | j].
        let mut ans = vec![ff.from_u64(0); n];
        for i in 0..n {
            for j in 0..n {
                ans[i | j] += f[i] * g[j];
            }
        }

        let res = convolution_or(f, g);
        assert_eq!(ans, res);
    }

    #[test]
    fn test_conv_and() {
        let mut rng = rand::thread_rng();
        let ff = ConstModIntBuilder::<M>;

        let n = 1 << 10;
        let f = random_vec(&mut rng, n, |x| ff.from_u64(x));
        let g = random_vec(&mut rng, n, |x| ff.from_u64(x));

        // Brute force: every pair (i, j) contributes f[i] * g[j] to ans[i & j].
        let mut ans = vec![ff.from_u64(0); n];
        for i in 0..n {
            for j in 0..n {
                ans[i & j] += f[i] * g[j];
            }
        }

        let res = convolution_and(f, g);
        assert_eq!(ans, res);
    }

    #[test]
    fn test_conv_xor() {
        let mut rng = rand::thread_rng();
        let ff = ConstModIntBuilder::<M>;

        let n = 1 << 10;
        let f = random_vec(&mut rng, n, |x| ff.from_u64(x));
        let g = random_vec(&mut rng, n, |x| ff.from_u64(x));

        // Brute force: every pair (i, j) contributes f[i] * g[j] to ans[i ^ j].
        let mut ans = vec![ff.from_u64(0); n];
        for i in 0..n {
            for j in 0..n {
                ans[i ^ j] += f[i] * g[j];
            }
        }

        let res = convolution_xor(f, g, ff);
        assert_eq!(ans, res);
    }

    #[test]
    fn test_conv_subset() {
        let mut rng = rand::thread_rng();
        let ff = ConstModIntBuilder::<M>;

        let n = 1 << 10;
        let f = random_vec(&mut rng, n, |x| ff.from_u64(x));
        let g = random_vec(&mut rng, n, |x| ff.from_u64(x));

        // Brute force: only disjoint pairs (i & j == 0) contribute f[i] * g[j] to ans[i | j].
        let mut ans = vec![ff.from_u64(0); n];
        for i in 0..n {
            for j in 0..n {
                if i & j == 0 {
                    ans[i | j] += f[i] * g[j];
                }
            }
        }

        let res = subset_convolution(f, g);
        assert_eq!(ans, res);
    }

    #[test]
    fn test_conv_gcd() {
        let mut rng = rand::thread_rng();
        let ff = ConstModIntBuilder::<M>;

        let n = 1000;
        let f = random_vec(&mut rng, n + 1, |x| ff.from_u64(x));
        let g = random_vec(&mut rng, n + 1, |x| ff.from_u64(x));

        // Brute force over indices 1..=n: f[i] * g[j] contributes to ans[gcd(i, j)].
        let mut ans = vec![ff.from_u64(0); n + 1];
        for i in 1..=n {
            for j in 1..=n {
                ans[i.gcd(j)] += f[i] * g[j];
            }
        }

        let res = convolution_gcd(f, g);

        // The brute force never touches index 0, so only indices 1..=n are compared.
        assert_eq!(ans[1..], res[1..]);
    }
}