haar_lib/math/
ntt.rs

//! Number Theoretic Transform (数論変換, NTT).
2use crate::num::const_modint::*;
3
4/// 素数$P$上の数論変換 (Number Theoretic Transform)
5///
6/// `PRIM_ROOT`は`P`の原始根。
7pub struct NTT<const P: u32, const PRIM_ROOT: u32> {
8    base: Vec<ConstModInt<P>>,
9    inv_base: Vec<ConstModInt<P>>,
10    max_size: usize,
11}
12
13impl<const P: u32, const PRIM_ROOT: u32> NTT<P, PRIM_ROOT> {
14    /// [`NTT<P, PRIM_ROOT>`]を作る。
15    pub fn new() -> Self {
16        let max_power = (P as usize - 1).trailing_zeros() as usize;
17        let max_size = 1 << max_power;
18
19        let mut base = vec![ConstModInt::new(0); max_power + 1];
20        let mut inv_base = vec![ConstModInt::new(0); max_power + 1];
21
22        let mut t = ConstModInt::new(PRIM_ROOT).pow((P as u64 - 1) >> (max_power));
23        let mut s = t.inv();
24
25        for i in (0..max_power).rev() {
26            t *= t;
27            s *= s;
28            base[i] = t;
29            inv_base[i] = s;
30        }
31
32        Self {
33            base,
34            inv_base,
35            max_size,
36        }
37    }
38
39    /// 数論変換を行う。
40    pub fn ntt(&self, f: &mut Vec<ConstModInt<P>>) {
41        self.run(f, false);
42    }
43
44    /// `ntt`の逆変換を行う。
45    pub fn intt(&self, f: &mut Vec<ConstModInt<P>>) {
46        self.run(f, true);
47    }
48
49    fn run(&self, f: &mut Vec<ConstModInt<P>>, inv: bool) {
50        let n = f.len();
51        assert!(n.is_power_of_two() && n < self.max_size);
52
53        let mut g = vec![ConstModInt::new(0); n];
54
55        let mut b = n >> 1;
56        let mut k = 1;
57        while b > 0 {
58            let dw = if !inv { self.base[k] } else { self.inv_base[k] };
59            let len = n / b;
60
61            let mut w = ConstModInt::new(1);
62
63            for j in 0..len / 2 {
64                for i in 0..b {
65                    let even = unsafe { *f.get_unchecked(i + 2 * j * b) };
66                    let odd = unsafe { *f.get_unchecked(i + 2 * j * b + b) };
67
68                    unsafe {
69                        *g.get_unchecked_mut(i + j * b) = even + w * odd;
70                        *g.get_unchecked_mut(i + j * b + n / 2) = even - w * odd;
71                    }
72                }
73
74                w *= dw;
75            }
76
77            k += 1;
78            b >>= 1;
79
80            std::mem::swap(&mut g, f);
81        }
82
83        if inv {
84            let t = ConstModInt::new(n as u32).inv();
85            for x in f.iter_mut() {
86                *x *= t;
87            }
88        }
89    }
90
91    /// 2つの`Vec`を畳み込む。
92    ///
93    /// $(f \ast g)(k) = \sum_{k = i + j} f(i) \times g(j)$
94    pub fn convolve(
95        &self,
96        mut f: Vec<ConstModInt<P>>,
97        mut g: Vec<ConstModInt<P>>,
98    ) -> Vec<ConstModInt<P>> {
99        if f.is_empty() || g.is_empty() {
100            return vec![];
101        }
102
103        let m = f.len() + g.len() - 1;
104        let n = m.next_power_of_two();
105
106        f.resize(n, ConstModInt::new(0));
107        self.run(&mut f, false);
108
109        g.resize(n, ConstModInt::new(0));
110        self.run(&mut g, false);
111
112        for (f, g) in f.iter_mut().zip(g.into_iter()) {
113            *f *= g;
114        }
115        self.run(&mut f, true);
116
117        f
118    }
119
120    /// `convolve(f.clone(), f)`と同等。
121    pub fn convolve_same(&self, mut f: Vec<ConstModInt<P>>) -> Vec<ConstModInt<P>> {
122        if f.is_empty() {
123            return vec![];
124        }
125
126        let n = (f.len() * 2 - 1).next_power_of_two();
127        f.resize(n, ConstModInt::new(0));
128
129        self.run(&mut f, false);
130
131        for x in f.iter_mut() {
132            *x *= *x;
133        }
134
135        self.run(&mut f, true);
136        f
137    }
138
139    /// NTTで変換可能な配列の最大長を返す。
140    pub fn max_size(&self) -> usize {
141        self.max_size
142    }
143}
144
145impl<const P: u32, const PRIM_ROOT: u32> Default for NTT<P, PRIM_ROOT> {
146    fn default() -> Self {
147        Self::new()
148    }
149}
150
151/// $\mod 998244353 (= 2^{23} * 7 * 17 + 1)$上の`NTT`
152pub type NTT998244353 = NTT<998244353, 3>;
153
#[cfg(test)]
mod tests {

    use super::*;
    use rand::Rng;

    const MOD: u32 = 998244353;

    /// Returns `len` uniformly random residues modulo `MOD`.
    fn random_vec(rng: &mut impl Rng, len: usize) -> Vec<ConstModInt<MOD>> {
        let ff = ConstModIntBuilder::<MOD>;
        (0..len)
            .map(|_| ff.from_u64(rng.gen_range(0..MOD) as u64))
            .collect()
    }

    /// Schoolbook O(|a| * |b|) convolution, used as the reference result.
    fn naive_convolve(a: &[ConstModInt<MOD>], b: &[ConstModInt<MOD>]) -> Vec<ConstModInt<MOD>> {
        if a.is_empty() || b.is_empty() {
            return vec![];
        }
        let mut ans = vec![ConstModInt::new(0); a.len() + b.len() - 1];
        for (i, &x) in a.iter().enumerate() {
            for (j, &y) in b.iter().enumerate() {
                ans[i + j] += x * y;
            }
        }
        ans
    }

    #[test]
    fn test() {
        let ntt = NTT998244353::new();
        let mut rng = rand::thread_rng();

        let n: usize = rng.gen_range(1..1000);
        let m: usize = rng.gen_range(1..1000);

        let a = random_vec(&mut rng, n);
        let b = random_vec(&mut rng, m);

        let res = ntt.convolve(a.clone(), b.clone());
        let ans = naive_convolve(&a, &b);

        assert_eq!(&res[..n + m - 1], &ans[..]);
        // Everything past the true convolution length is zero padding.
        assert!(res[n + m - 1..].iter().all(|&x| x == ConstModInt::new(0)));
    }

    #[test]
    fn convolve_same_matches_convolve() {
        let ntt = NTT998244353::new();
        let mut rng = rand::thread_rng();

        let n: usize = rng.gen_range(1..1000);
        let a = random_vec(&mut rng, n);

        let res = ntt.convolve_same(a.clone());
        assert_eq!(res, ntt.convolve(a.clone(), a.clone()));
        assert_eq!(&res[..2 * n - 1], &naive_convolve(&a, &a)[..]);
    }

    #[test]
    fn ntt_intt_roundtrip() {
        let ntt = NTT998244353::new();
        let mut rng = rand::thread_rng();

        for log in 0..=10 {
            let a = random_vec(&mut rng, 1 << log);
            let mut b = a.clone();
            ntt.ntt(&mut b);
            ntt.intt(&mut b);
            assert_eq!(a, b);
        }
    }

    #[test]
    fn empty_input() {
        let ntt = NTT998244353::new();
        assert!(ntt.convolve(vec![], vec![ConstModInt::new(1)]).is_empty());
        assert!(ntt.convolve(vec![ConstModInt::new(1)], vec![]).is_empty());
        assert!(ntt.convolve_same(vec![]).is_empty());
    }

    #[test]
    fn max_size_of_998244353() {
        // 998244353 - 1 = 2^23 * 119
        assert_eq!(NTT998244353::new().max_size(), 1 << 23);
    }
}