1use crate::num::const_modint::*;
3
/// Precomputed tables for a number-theoretic transform modulo `P`.
///
/// `PRIM_ROOT` is the generator from which the roots of unity are derived
/// (presumably a primitive root modulo `P`; the code does not verify this).
pub struct NTT<const P: u32, const PRIM_ROOT: u32> {
    /// `base[k]` is the twiddle-factor step the forward transform uses at stage `k`.
    base: Vec<ConstModInt<P>>,
    /// Counterpart of `base` built from the inverse root; used by the inverse transform.
    inv_base: Vec<ConstModInt<P>>,
    /// `2^e`, where `e` is the number of trailing zero bits of `P - 1`.
    max_size: usize,
}
12
13impl<const P: u32, const PRIM_ROOT: u32> NTT<P, PRIM_ROOT> {
14 pub fn new() -> Self {
16 let max_power = (P as usize - 1).trailing_zeros() as usize;
17 let max_size = 1 << max_power;
18
19 let mut base = vec![ConstModInt::new(0); max_power + 1];
20 let mut inv_base = vec![ConstModInt::new(0); max_power + 1];
21
22 let mut t = ConstModInt::new(PRIM_ROOT).pow((P as u64 - 1) >> (max_power));
23 let mut s = t.inv();
24
25 for i in (0..max_power).rev() {
26 t *= t;
27 s *= s;
28 base[i] = t;
29 inv_base[i] = s;
30 }
31
32 Self {
33 base,
34 inv_base,
35 max_size,
36 }
37 }
38
39 pub fn ntt(&self, f: &mut Vec<ConstModInt<P>>) {
41 self.run(f, false);
42 }
43
44 pub fn intt(&self, f: &mut Vec<ConstModInt<P>>) {
46 self.run(f, true);
47 }
48
49 fn run(&self, f: &mut Vec<ConstModInt<P>>, inv: bool) {
50 let n = f.len();
51 assert!(n.is_power_of_two() && n < self.max_size);
52
53 let mut g = vec![ConstModInt::new(0); n];
54
55 let mut b = n >> 1;
56 let mut k = 1;
57 while b > 0 {
58 let dw = if !inv { self.base[k] } else { self.inv_base[k] };
59 let len = n / b;
60
61 let mut w = ConstModInt::new(1);
62
63 for j in 0..len / 2 {
64 for i in 0..b {
65 let even = unsafe { *f.get_unchecked(i + 2 * j * b) };
66 let odd = unsafe { *f.get_unchecked(i + 2 * j * b + b) };
67
68 unsafe {
69 *g.get_unchecked_mut(i + j * b) = even + w * odd;
70 *g.get_unchecked_mut(i + j * b + n / 2) = even - w * odd;
71 }
72 }
73
74 w *= dw;
75 }
76
77 k += 1;
78 b >>= 1;
79
80 std::mem::swap(&mut g, f);
81 }
82
83 if inv {
84 let t = ConstModInt::new(n as u32).inv();
85 for x in f.iter_mut() {
86 *x *= t;
87 }
88 }
89 }
90
91 pub fn convolve(
95 &self,
96 mut f: Vec<ConstModInt<P>>,
97 mut g: Vec<ConstModInt<P>>,
98 ) -> Vec<ConstModInt<P>> {
99 if f.is_empty() || g.is_empty() {
100 return vec![];
101 }
102
103 let m = f.len() + g.len() - 1;
104 let n = m.next_power_of_two();
105
106 f.resize(n, ConstModInt::new(0));
107 self.run(&mut f, false);
108
109 g.resize(n, ConstModInt::new(0));
110 self.run(&mut g, false);
111
112 for (f, g) in f.iter_mut().zip(g.into_iter()) {
113 *f *= g;
114 }
115 self.run(&mut f, true);
116
117 f
118 }
119
120 pub fn convolve_same(&self, mut f: Vec<ConstModInt<P>>) -> Vec<ConstModInt<P>> {
122 if f.is_empty() {
123 return vec![];
124 }
125
126 let n = (f.len() * 2 - 1).next_power_of_two();
127 f.resize(n, ConstModInt::new(0));
128
129 self.run(&mut f, false);
130
131 for x in f.iter_mut() {
132 *x *= *x;
133 }
134
135 self.run(&mut f, true);
136 f
137 }
138
139 pub fn max_size(&self) -> usize {
141 self.max_size
142 }
143}
144
145impl<const P: u32, const PRIM_ROOT: u32> Default for NTT<P, PRIM_ROOT> {
146 fn default() -> Self {
147 Self::new()
148 }
149}
150
/// [`NTT`] specialised to the modulus `998244353`, with `3` as `PRIM_ROOT`.
pub type NTT998244353 = NTT<998244353, 3>;
153
#[cfg(test)]
mod tests {

    use super::*;
    use rand::Rng;

    /// Compares `convolve` against the quadratic schoolbook product on a pair
    /// of random polynomials with random lengths below 1000.
    #[test]
    fn test() {
        const MOD: u32 = 998244353;

        let ntt = NTT998244353::new();
        let ff = ConstModIntBuilder::<MOD>;

        let mut rng = rand::thread_rng();

        // Lengths are drawn first, then the coefficients of `a`, then of `b`.
        let n = rng.gen_range(1..1000);
        let m = rng.gen_range(1..1000);

        let mut random_poly = |len: usize| -> Vec<_> {
            (0..len)
                .map(|_| ff.from_u64(rng.gen_range(0..MOD) as u64))
                .collect()
        };
        let a = random_poly(n);
        let b = random_poly(m);

        let res = ntt.convolve(a.clone(), b.clone());

        // Reference answer via the direct O(n * m) definition.
        let mut ans = vec![ConstModInt::new(0); n + m - 1];
        for (i, &x) in a.iter().enumerate() {
            for (j, &y) in b.iter().enumerate() {
                ans[i + j] += x * y;
            }
        }

        // `convolve` pads its output to a power of two; only the prefix matters.
        assert_eq!(&res[..n + m - 1], &ans);
    }
}