haar_lib/math/convolution/
conv_mul_modp.rs1use std::iter::successors;
3
4use crate::sort_with;
5use crate::{
6 math::{ntt::NTT, primitive_root::primitive_root},
7 num::const_modint::ConstModInt,
8};
9
10pub fn convolution_mul_modp<const M: u32, const PR: u32>(
12 mut a: Vec<ConstModInt<M>>,
13 mut b: Vec<ConstModInt<M>>,
14 ntt: &NTT<M, PR>,
15) -> Vec<ConstModInt<M>> {
16 assert_eq!(a.len(), b.len());
17 let p = a.len();
18 let p_root = primitive_root(p as u64).unwrap() as usize;
19
20 let mut index = vec![0; p];
21 successors(Some(1), |&s| Some(s * p_root % p))
22 .take(p)
23 .enumerate()
24 .for_each(|(i, s)| index[s] = i);
25
26 let mut zero = a[0] * b[0];
27 for i in 1..p {
28 zero += a[0] * b[i] + a[i] * b[0];
29 }
30
31 a[0] = 0.into();
32 b[0] = 0.into();
33
34 sort_with!(|&i, &j| index[i].cmp(&index[j]), a, b);
35
36 let c = ntt.convolve(a, b);
37
38 let mut ret = vec![0.into(); p];
39
40 successors(Some(1), |&s| Some(s * p_root % p))
41 .zip(c)
42 .for_each(|(s, x)| ret[s] += x);
43
44 ret[0] = zero;
45
46 ret
47}
48
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        iter::collect::CollectVec,
        math::ntt::NTT998244353,
        num::{const_modint::ConstModIntBuilder, ff::FF},
    };
    use rand::Rng;

    /// Checks `convolution_mul_modp` against the naive O(p^2) definition
    /// on random inputs for the prime p = 1009.
    #[test]
    fn test() {
        const P: usize = 1009;
        let ff = ConstModIntBuilder::<998244353>;
        let mut rng = rand::thread_rng();
        let ntt = NTT998244353::new();

        // Draws `a` first and `b` second, each of length P.
        let mut random_vec = || (0..P).map(|_| ff.from_u64(rng.gen::<u64>())).collect_vec();
        let a = random_vec();
        let b = random_vec();

        // Brute force: accumulate every product into the slot of its residue.
        let mut expected = vec![ff.from_u64(0); P];
        for (i, &x) in a.iter().enumerate() {
            for (j, &y) in b.iter().enumerate() {
                expected[i * j % P] += x * y;
            }
        }

        let actual = convolution_mul_modp(a, b, &ntt);

        assert_eq!(expected, actual);
    }
}