haar_lib/math/convolution/
ntt.rs

1//! 数論変換 (Number Theoretic Transform)
2use std::marker::PhantomData;
3
4use crate::math::prime_mod::*;
5use crate::num::const_modint::*;
6
7/// 素数$P$上の数論変換 (Number Theoretic Transform)
8///
9/// `PRIM_ROOT`は`P`の原始根。
10#[derive(Clone)]
11pub struct NTT<P: PrimeMod> {
12    _phantom: PhantomData<P>,
13}
14
15impl<P: PrimeMod> NTT<P> {
16    const MAX_POWER: usize = (P::PRIME_NUM as usize - 1).trailing_zeros() as usize;
17    const MAX_SIZE: usize = 1 << Self::MAX_POWER;
18    const BASE: [ConstModInt<P>; 32] = {
19        let mut base = [ConstModInt::<P>::new(0); 32];
20        let mut t = ConstModInt::<P>::new(P::PRIM_ROOT)
21            ._pow((P::PRIME_NUM as u64 - 1) >> (Self::MAX_POWER));
22
23        let mut i = Self::MAX_POWER;
24        base[i] = t;
25        while i > 0 {
26            t = t._mul(t);
27            base[i - 1] = t;
28            i -= 1;
29        }
30
31        base
32    };
33
34    const INV_BASE: [ConstModInt<P>; 32] = {
35        let mut inv_base = [ConstModInt::<P>::new(0); 32];
36        let t = ConstModInt::<P>::new(P::PRIM_ROOT)
37            ._pow((P::PRIME_NUM as u64 - 1) >> (Self::MAX_POWER));
38        let mut s = t._inv();
39
40        let mut i = Self::MAX_POWER;
41        inv_base[i] = s;
42        while i > 0 {
43            s = s._mul(s);
44            inv_base[i - 1] = s;
45            i -= 1;
46        }
47
48        inv_base
49    };
50
51    /// [`NTT<P, PRIM_ROOT>`]を作る。
52    pub const fn new() -> Self {
53        Self {
54            _phantom: PhantomData,
55        }
56    }
57
58    /// 数論変換を行う。
59    pub fn ntt(&self, f: &mut [ConstModInt<P>]) {
60        let n = f.len();
61        assert!(n.is_power_of_two() && n <= Self::MAX_SIZE);
62
63        let mut width = n;
64        let mut k = n.trailing_zeros() as usize;
65        while width > 1 {
66            let dw = Self::BASE[k];
67
68            let mut ws = vec![ConstModInt::new(1); width / 2];
69            for i in 1..width / 2 {
70                ws[i] = ws[i - 1] * dw;
71            }
72
73            for a in f.chunks_exact_mut(width) {
74                let (x, y) = a.split_at_mut(width / 2);
75
76                for ((s, t), &w) in x.iter_mut().zip(y.iter_mut()).zip(ws.iter()) {
77                    let p = *s + *t;
78                    let q = (*s - *t) * w;
79
80                    *s = p;
81                    *t = q;
82                }
83            }
84
85            k -= 1;
86            width >>= 1;
87        }
88
89        // let p = size_of::<usize>() * 8 - n.trailing_zeros() as usize;
90        // let mut g = vec![ConstModInt::new(0); n];
91        // for i in 0..n {
92        //     let j = i.reverse_bits() >> p;
93        //     g[j] = f[i];
94        // }
95        // std::mem::swap(f, &mut g);
96    }
97
98    /// `ntt`の逆変換を行う。
99    pub fn intt(&self, f: &mut [ConstModInt<P>]) {
100        let n = f.len();
101        assert!(n.is_power_of_two() && n <= Self::MAX_SIZE);
102
103        // let p = size_of::<usize>() * 8 - n.trailing_zeros() as usize;
104        // let mut g = vec![ConstModInt::new(0); n];
105        // for i in 0..n {
106        //     let j = i.reverse_bits() >> p;
107        //     g[j] = f[i];
108        // }
109        // std::mem::swap(f, &mut g);
110
111        let mut width = 2;
112        let mut k = 1;
113        while width <= n {
114            let dw = Self::INV_BASE[k];
115
116            let mut ws = vec![ConstModInt::new(1); width / 2];
117            for i in 1..width / 2 {
118                ws[i] = ws[i - 1] * dw;
119            }
120
121            for a in f.chunks_exact_mut(width) {
122                let (x, y) = a.split_at_mut(width / 2);
123
124                for ((s, t), &w) in x.iter_mut().zip(y.iter_mut()).zip(ws.iter()) {
125                    let p = *s + *t * w;
126                    let q = *s - *t * w;
127
128                    *s = p;
129                    *t = q;
130                }
131            }
132
133            k += 1;
134            width <<= 1;
135        }
136
137        let t = ConstModInt::new(n as u32).inv();
138        for x in f.iter_mut() {
139            *x *= t;
140        }
141    }
142
143    /// 2つの`Vec`を畳み込む。
144    ///
145    /// $(f \ast g)(k) = \sum_{k = i + j} f(i) \times g(j)$
146    pub fn convolve<T>(&self, f: Vec<T>, g: Vec<T>) -> Vec<ConstModInt<P>>
147    where
148        T: Into<ConstModInt<P>>,
149    {
150        if f.is_empty() || g.is_empty() {
151            return vec![];
152        }
153
154        let m = f.len() + g.len() - 1;
155        let n = m.next_power_of_two();
156
157        let mut f: Vec<_> = f.into_iter().map(Into::into).collect();
158        let mut g: Vec<_> = g.into_iter().map(Into::into).collect();
159
160        f.resize(n, ConstModInt::new(0));
161        self.ntt(&mut f);
162
163        g.resize(n, ConstModInt::new(0));
164        self.ntt(&mut g);
165
166        for (f, g) in f.iter_mut().zip(g.into_iter()) {
167            *f *= g;
168        }
169        self.intt(&mut f);
170
171        f
172    }
173
174    /// `convolve(f.clone(), f)`と同等。
175    pub fn convolve_same(&self, mut f: Vec<ConstModInt<P>>) -> Vec<ConstModInt<P>> {
176        if f.is_empty() {
177            return vec![];
178        }
179
180        let n = (f.len() * 2 - 1).next_power_of_two();
181        f.resize(n, ConstModInt::new(0));
182
183        self.ntt(&mut f);
184
185        for x in f.iter_mut() {
186            *x *= *x;
187        }
188
189        self.intt(&mut f);
190        f
191    }
192
193    /// NTTで変換可能な配列の最大長を返す。
194    pub fn max_size(&self) -> usize {
195        Self::MAX_SIZE
196    }
197}
198
199impl<P: PrimeMod> Default for NTT<P> {
200    fn default() -> Self {
201        Self::new()
202    }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;

    /// Compares `convolve` with the quadratic schoolbook product on random inputs.
    #[test]
    fn test() {
        type P = Prime<998244353>;

        let ntt = NTT::<P>::new();
        let ff = ConstModIntBuilder::<P>::new();

        let mut rng = rand::thread_rng();

        let n: usize = rng.gen_range(1..1000);
        let m: usize = rng.gen_range(1..1000);

        let a: Vec<_> = (0..n)
            .map(|_| ff.from_u64(rng.gen_range(0..P::PRIME_NUM) as u64))
            .collect();
        let b: Vec<_> = (0..m)
            .map(|_| ff.from_u64(rng.gen_range(0..P::PRIME_NUM) as u64))
            .collect();

        let res = ntt.convolve(a.clone(), b.clone());

        // Reference result: every pair of coefficients contributes to index `i + j`.
        let len = n + m - 1;
        let mut ans = vec![ConstModInt::new(0); len];
        for (i, &x) in a.iter().enumerate() {
            for (j, &y) in b.iter().enumerate() {
                ans[i + j] += x * y;
            }
        }

        // `convolve` pads its output to a power of two; only the prefix is compared.
        assert_eq!(&res[..len], ans.as_slice());
    }
}