// haar_lib/math/convolution/conv_mul_modp.rs

//! 添字積$\pmod p$畳み込み (index-product convolution modulo a prime $p$: $c_k = \sum_{i j \equiv k \pmod p} a_i b_j$)
2use std::iter::successors;
3
4use crate::math::prime_mod::PrimeMod;
5use crate::sort_with;
6use crate::{
7    math::{convolution::ntt::NTT, primitive_root::primitive_root},
8    num::const_modint::ConstModInt,
9};
10
11/// 素数$P$に対して、$c_k = \sum_{i \times j = k \pmod P} a_i b_j$を満たす$c$を求める。
12///
13/// # Requirements
14/// `f.len()` = `g.len()`
15pub fn convolution_mul_modp<P: PrimeMod>(
16    mut a: Vec<ConstModInt<P>>,
17    mut b: Vec<ConstModInt<P>>,
18) -> Vec<ConstModInt<P>> {
19    assert_eq!(a.len(), b.len());
20    let p = a.len();
21    let p_root = primitive_root(p as u32) as usize;
22
23    let ntt = NTT::<P>::new();
24
25    let mut index = vec![0; p];
26    successors(Some(1), |&s| Some(s * p_root % p))
27        .take(p)
28        .enumerate()
29        .for_each(|(i, s)| index[s] = i);
30
31    let mut zero = a[0] * b[0];
32    for i in 1..p {
33        zero += a[0] * b[i] + a[i] * b[0];
34    }
35
36    a[0] = 0.into();
37    b[0] = 0.into();
38
39    sort_with!(|&i, &j| index[i].cmp(&index[j]), a, b);
40
41    let c = ntt.convolve(a, b);
42
43    let mut ret = vec![0.into(); p];
44
45    successors(Some(1), |&s| Some(s * p_root % p))
46        .zip(c)
47        .for_each(|(s, x)| ret[s] += x);
48
49    ret[0] = zero;
50
51    ret
52}
53
#[cfg(test)]
mod tests {
    use crate::{
        iter::collect::CollectVec,
        math::prime_mod::Prime,
        num::{const_modint::ConstModIntBuilder, ff::*},
    };

    use super::*;
    use rand::Rng;

    type P = Prime<998244353>;

    /// Compares `convolution_mul_modp` against the O(p^2) definition
    /// $c_k = \sum_{i j \equiv k \pmod p} a_i b_j$ on random inputs.
    #[test]
    fn test() {
        let modulo = ConstModIntBuilder::<P>::new();
        let mut rng = rand::thread_rng();

        // Small primes exercise the wrap-around of the exponent indices
        // (g^(p-1) = 1); 1009 exercises a non-trivial NTT size.
        for &p in &[3usize, 5, 7, 1009] {
            let a = std::iter::repeat_with(|| modulo.from_u64(rng.gen::<u64>()))
                .take(p)
                .collect_vec();
            let b = std::iter::repeat_with(|| modulo.from_u64(rng.gen::<u64>()))
                .take(p)
                .collect_vec();

            // Brute-force reference.
            let mut ans = vec![modulo.from_u64(0); p];
            for i in 0..p {
                for j in 0..p {
                    ans[i * j % p] += a[i] * b[j];
                }
            }

            let res = convolution_mul_modp::<P>(a, b);

            assert_eq!(ans, res, "p = {}", p);
        }
    }
}