haar_lib/math/convolution/
conv_mul_modp.rs

1//! 素数$P$に対して、$c_k = \sum_{i \times j = k \pmod P} a_i b_j$を満たす$c$を求める。
2use std::iter::successors;
3
4use crate::sort_with;
5use crate::{
6    math::{ntt::NTT, primitive_root::primitive_root},
7    num::const_modint::ConstModInt,
8};
9
10/// 素数$P$に対して、$c_k = \sum_{i \times j = k \pmod P} a_i b_j$を満たす$c$を求める。
11pub fn convolution_mul_modp<const M: u32, const PR: u32>(
12    mut a: Vec<ConstModInt<M>>,
13    mut b: Vec<ConstModInt<M>>,
14    ntt: &NTT<M, PR>,
15) -> Vec<ConstModInt<M>> {
16    assert_eq!(a.len(), b.len());
17    let p = a.len();
18    let p_root = primitive_root(p as u64).unwrap() as usize;
19
20    let mut index = vec![0; p];
21    successors(Some(1), |&s| Some(s * p_root % p))
22        .take(p)
23        .enumerate()
24        .for_each(|(i, s)| index[s] = i);
25
26    let mut zero = a[0] * b[0];
27    for i in 1..p {
28        zero += a[0] * b[i] + a[i] * b[0];
29    }
30
31    a[0] = 0.into();
32    b[0] = 0.into();
33
34    sort_with!(|&i, &j| index[i].cmp(&index[j]), a, b);
35
36    let c = ntt.convolve(a, b);
37
38    let mut ret = vec![0.into(); p];
39
40    successors(Some(1), |&s| Some(s * p_root % p))
41        .zip(c)
42        .for_each(|(s, x)| ret[s] += x);
43
44    ret[0] = zero;
45
46    ret
47}
48
#[cfg(test)]
mod tests {
    use crate::{
        iter::collect::CollectVec,
        math::ntt::NTT998244353,
        num::{const_modint::ConstModIntBuilder, ff::FF},
    };

    use super::*;
    use rand::Rng;

    /// Checks `convolution_mul_modp` against a direct double-loop evaluation of the
    /// definition on random inputs of length 1009.
    #[test]
    fn test() {
        let p = 1009;
        let modulo = ConstModIntBuilder::<998244353>;
        let mut rng = rand::thread_rng();
        let ntt = NTT998244353::new();

        // Draws `p` random coefficients; called once for `a`, then once for `b`.
        let mut random_coeffs = || {
            (0..p)
                .map(|_| modulo.from_u64(rng.gen::<u64>()))
                .collect_vec()
        };
        let a = random_coeffs();
        let b = random_coeffs();

        // Reference answer: accumulate every pair (i, j) into slot i * j mod p.
        let mut expected = vec![modulo.from_u64(0); p];
        for (i, &x) in a.iter().enumerate() {
            for (j, &y) in b.iter().enumerate() {
                expected[i * j % p] += x * y;
            }
        }

        let actual = convolution_mul_modp(a, b, &ntt);

        assert_eq!(expected, actual);
    }
}