haar_lib/math/convolution/
ntt.rs1use std::marker::PhantomData;
3
4use crate::math::prime_mod::*;
5use crate::num::const_modint::*;
6
7#[derive(Clone)]
11pub struct NTT<P: PrimeMod> {
12 _phantom: PhantomData<P>,
13}
14
15impl<P: PrimeMod> NTT<P> {
16 const MAX_POWER: usize = (P::PRIME_NUM as usize - 1).trailing_zeros() as usize;
17 const MAX_SIZE: usize = 1 << Self::MAX_POWER;
18 const BASE: [ConstModInt<P>; 32] = {
19 let mut base = [ConstModInt::<P>::new(0); 32];
20 let mut t = ConstModInt::<P>::new(P::PRIM_ROOT)
21 ._pow((P::PRIME_NUM as u64 - 1) >> (Self::MAX_POWER));
22
23 let mut i = Self::MAX_POWER;
24 base[i] = t;
25 while i > 0 {
26 t = t._mul(t);
27 base[i - 1] = t;
28 i -= 1;
29 }
30
31 base
32 };
33
34 const INV_BASE: [ConstModInt<P>; 32] = {
35 let mut inv_base = [ConstModInt::<P>::new(0); 32];
36 let t = ConstModInt::<P>::new(P::PRIM_ROOT)
37 ._pow((P::PRIME_NUM as u64 - 1) >> (Self::MAX_POWER));
38 let mut s = t._inv();
39
40 let mut i = Self::MAX_POWER;
41 inv_base[i] = s;
42 while i > 0 {
43 s = s._mul(s);
44 inv_base[i - 1] = s;
45 i -= 1;
46 }
47
48 inv_base
49 };
50
51 pub const fn new() -> Self {
53 Self {
54 _phantom: PhantomData,
55 }
56 }
57
58 pub fn ntt(&self, f: &mut [ConstModInt<P>]) {
60 let n = f.len();
61 assert!(n.is_power_of_two() && n <= Self::MAX_SIZE);
62
63 let mut width = n;
64 let mut k = n.trailing_zeros() as usize;
65 while width > 1 {
66 let dw = Self::BASE[k];
67
68 let mut ws = vec![ConstModInt::new(1); width / 2];
69 for i in 1..width / 2 {
70 ws[i] = ws[i - 1] * dw;
71 }
72
73 for a in f.chunks_exact_mut(width) {
74 let (x, y) = a.split_at_mut(width / 2);
75
76 for ((s, t), &w) in x.iter_mut().zip(y.iter_mut()).zip(ws.iter()) {
77 let p = *s + *t;
78 let q = (*s - *t) * w;
79
80 *s = p;
81 *t = q;
82 }
83 }
84
85 k -= 1;
86 width >>= 1;
87 }
88
89 }
97
98 pub fn intt(&self, f: &mut [ConstModInt<P>]) {
100 let n = f.len();
101 assert!(n.is_power_of_two() && n <= Self::MAX_SIZE);
102
103 let mut width = 2;
112 let mut k = 1;
113 while width <= n {
114 let dw = Self::INV_BASE[k];
115
116 let mut ws = vec![ConstModInt::new(1); width / 2];
117 for i in 1..width / 2 {
118 ws[i] = ws[i - 1] * dw;
119 }
120
121 for a in f.chunks_exact_mut(width) {
122 let (x, y) = a.split_at_mut(width / 2);
123
124 for ((s, t), &w) in x.iter_mut().zip(y.iter_mut()).zip(ws.iter()) {
125 let p = *s + *t * w;
126 let q = *s - *t * w;
127
128 *s = p;
129 *t = q;
130 }
131 }
132
133 k += 1;
134 width <<= 1;
135 }
136
137 let t = ConstModInt::new(n as u32).inv();
138 for x in f.iter_mut() {
139 *x *= t;
140 }
141 }
142
143 pub fn convolve<T>(&self, f: Vec<T>, g: Vec<T>) -> Vec<ConstModInt<P>>
147 where
148 T: Into<ConstModInt<P>>,
149 {
150 if f.is_empty() || g.is_empty() {
151 return vec![];
152 }
153
154 let m = f.len() + g.len() - 1;
155 let n = m.next_power_of_two();
156
157 let mut f: Vec<_> = f.into_iter().map(Into::into).collect();
158 let mut g: Vec<_> = g.into_iter().map(Into::into).collect();
159
160 f.resize(n, ConstModInt::new(0));
161 self.ntt(&mut f);
162
163 g.resize(n, ConstModInt::new(0));
164 self.ntt(&mut g);
165
166 for (f, g) in f.iter_mut().zip(g.into_iter()) {
167 *f *= g;
168 }
169 self.intt(&mut f);
170
171 f
172 }
173
174 pub fn convolve_same(&self, mut f: Vec<ConstModInt<P>>) -> Vec<ConstModInt<P>> {
176 if f.is_empty() {
177 return vec![];
178 }
179
180 let n = (f.len() * 2 - 1).next_power_of_two();
181 f.resize(n, ConstModInt::new(0));
182
183 self.ntt(&mut f);
184
185 for x in f.iter_mut() {
186 *x *= *x;
187 }
188
189 self.intt(&mut f);
190 f
191 }
192
193 pub fn max_size(&self) -> usize {
195 Self::MAX_SIZE
196 }
197}
198
199impl<P: PrimeMod> Default for NTT<P> {
200 fn default() -> Self {
201 Self::new()
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;

    type P = Prime<998244353>;

    // Schoolbook O(n*m) convolution used as the reference result.
    fn naive(a: &[ConstModInt<P>], b: &[ConstModInt<P>]) -> Vec<ConstModInt<P>> {
        if a.is_empty() || b.is_empty() {
            return vec![];
        }
        let mut ans = vec![ConstModInt::new(0); a.len() + b.len() - 1];
        for (i, &x) in a.iter().enumerate() {
            for (j, &y) in b.iter().enumerate() {
                ans[i + j] += x * y;
            }
        }
        ans
    }

    // `len` uniformly random field elements.
    fn random_vec(rng: &mut impl Rng, len: usize) -> Vec<ConstModInt<P>> {
        let ff = ConstModIntBuilder::<P>::new();
        std::iter::repeat_with(|| ff.from_u64(rng.gen_range(0..P::PRIME_NUM) as u64))
            .take(len)
            .collect()
    }

    #[test]
    fn test() {
        let ntt = NTT::<P>::new();
        let mut rng = rand::thread_rng();

        let n = rng.gen_range(1..1000);
        let m = rng.gen_range(1..1000);

        let a = random_vec(&mut rng, n);
        let b = random_vec(&mut rng, m);

        let res = ntt.convolve(a.clone(), b.clone());

        assert_eq!(&res[..n + m - 1], &naive(&a, &b)[..]);
    }

    #[test]
    fn convolve_pads_with_zeros_to_power_of_two() {
        let ntt = NTT::<P>::new();
        let mut rng = rand::thread_rng();
        let a = random_vec(&mut rng, 5);
        let b = random_vec(&mut rng, 7);

        let res = ntt.convolve(a.clone(), b.clone());

        // 5 + 7 - 1 = 11 meaningful entries, padded up to 16.
        assert_eq!(res.len(), 16);
        assert_eq!(&res[..11], &naive(&a, &b)[..]);
        assert!(res[11..].iter().all(|&x| x == ConstModInt::new(0)));
    }

    #[test]
    fn convolve_edge_cases() {
        let ntt = NTT::<P>::new();
        let one = vec![ConstModInt::<P>::new(1)];

        assert!(ntt.convolve(Vec::<ConstModInt<P>>::new(), one.clone()).is_empty());
        assert!(ntt.convolve(one, Vec::new()).is_empty());
        assert!(ntt.convolve_same(Vec::new()).is_empty());

        let res = ntt.convolve(vec![ConstModInt::<P>::new(3)], vec![ConstModInt::<P>::new(5)]);
        assert_eq!(res, vec![ConstModInt::new(15)]);
    }

    #[test]
    fn ntt_intt_roundtrip() {
        let ntt = NTT::<P>::new();
        let mut rng = rand::thread_rng();
        for log in 0..=10 {
            let a = random_vec(&mut rng, 1 << log);
            let mut b = a.clone();
            ntt.ntt(&mut b);
            ntt.intt(&mut b);
            assert_eq!(a, b);
        }
    }

    #[test]
    fn convolve_same_matches_convolve() {
        let ntt = NTT::<P>::new();

        // (1 + x)^2 = 1 + 2x + x^2, padded to length 4.
        let sq = ntt.convolve_same(vec![ConstModInt::new(1), ConstModInt::new(1)]);
        assert_eq!(
            sq,
            vec![
                ConstModInt::new(1),
                ConstModInt::new(2),
                ConstModInt::new(1),
                ConstModInt::new(0)
            ]
        );

        let mut rng = rand::thread_rng();
        for len in 1..50 {
            let a = random_vec(&mut rng, len);
            assert_eq!(ntt.convolve_same(a.clone()), ntt.convolve(a.clone(), a));
        }
    }

    #[test]
    #[should_panic]
    fn ntt_rejects_non_power_of_two_length() {
        let ntt = NTT::<P>::new();
        let mut f = vec![ConstModInt::<P>::new(0); 3];
        ntt.ntt(&mut f);
    }
}