haar_lib/math/convolution/
conv_mul_modp.rs1use std::iter::successors;
3
4use crate::math::prime_mod::PrimeMod;
5use crate::sort_with;
6use crate::{
7 math::{convolution::ntt::NTT, primitive_root::primitive_root},
8 num::const_modint::ConstModInt,
9};
10
11pub fn convolution_mul_modp<P: PrimeMod>(
16 mut a: Vec<ConstModInt<P>>,
17 mut b: Vec<ConstModInt<P>>,
18) -> Vec<ConstModInt<P>> {
19 assert_eq!(a.len(), b.len());
20 let p = a.len();
21 let p_root = primitive_root(p as u32) as usize;
22
23 let ntt = NTT::<P>::new();
24
25 let mut index = vec![0; p];
26 successors(Some(1), |&s| Some(s * p_root % p))
27 .take(p)
28 .enumerate()
29 .for_each(|(i, s)| index[s] = i);
30
31 let mut zero = a[0] * b[0];
32 for i in 1..p {
33 zero += a[0] * b[i] + a[i] * b[0];
34 }
35
36 a[0] = 0.into();
37 b[0] = 0.into();
38
39 sort_with!(|&i, &j| index[i].cmp(&index[j]), a, b);
40
41 let c = ntt.convolve(a, b);
42
43 let mut ret = vec![0.into(); p];
44
45 successors(Some(1), |&s| Some(s * p_root % p))
46 .zip(c)
47 .for_each(|(s, x)| ret[s] += x);
48
49 ret[0] = zero;
50
51 ret
52}
53
#[cfg(test)]
mod tests {
    use crate::{
        iter::collect::CollectVec,
        math::prime_mod::Prime,
        num::{const_modint::ConstModIntBuilder, ff::*},
    };

    use super::*;
    use rand::Rng;

    type P = Prime<998244353>;

    /// Compares `convolution_mul_modp` against the O(p^2) definition
    /// `ans[i * j % p] += a[i] * b[j]` on random inputs for several primes `p`.
    #[test]
    fn test() {
        let modulo = ConstModIntBuilder::<P>::new();
        let mut rng = rand::thread_rng();

        // The small primes exercise the wrap-around g^(p-1) ≡ 1 of the
        // discrete-log table on short inputs; 1009 exercises a longer NTT.
        for &p in &[3usize, 5, 7, 11, 1009] {
            let a = std::iter::repeat_with(|| modulo.from_u64(rng.gen::<u64>()))
                .take(p)
                .collect_vec();
            let b = std::iter::repeat_with(|| modulo.from_u64(rng.gen::<u64>()))
                .take(p)
                .collect_vec();

            let mut ans = vec![modulo.from_u64(0); p];
            for i in 0..p {
                for j in 0..p {
                    ans[i * j % p] += a[i] * b[j];
                }
            }

            let res = convolution_mul_modp::<P>(a, b);

            assert_eq!(ans, res, "mismatch for p = {}", p);
        }
    }
}