haar_lib/math/polynomial/
multipoint_eval.rs

//! Multipoint evaluation of polynomials.
2
3use crate::math::polynomial::Polynomial;
4use crate::math::prime_mod::PrimeMod;
5use crate::num::const_modint::ConstModInt;
6
/// Evaluation of a polynomial at many points in a single call.
pub trait MultipointEval {
    /// Coefficient type of the polynomial; the evaluation points and the
    /// returned values use this type as well.
    type Value;

    /// Evaluates the polynomial at every point in `p`.
    ///
    /// Given a polynomial $f(x)$ and points $p_0, p_1, \cdots, p_{m-1}$, returns
    /// $f(p_0), f(p_1), \cdots, f(p_{m-1})$, one result per input point.
    fn multipoint_eval(self, p: Vec<Self::Value>) -> Vec<Self::Value>;
}
17
18impl<P: PrimeMod> MultipointEval for Polynomial<P> {
19    type Value = ConstModInt<P>;
20
21    fn multipoint_eval(self, p: Vec<Self::Value>) -> Vec<Self::Value> {
22        let m = p.len();
23
24        let k = m.next_power_of_two();
25
26        let mut f = vec![Self::constant(1.into()); k * 2];
27        for i in 0..m {
28            f[i + k] = vec![-p[i], 1.into()].into();
29        }
30        for i in (1..k).rev() {
31            f[i] = f[i << 1].clone() * f[(i << 1) | 1].clone();
32        }
33
34        f[1] = self % f[1].clone();
35
36        for i in 2..k + m {
37            f[i] = f[i >> 1].clone() % std::mem::take(&mut f[i]);
38        }
39
40        f.into_iter()
41            .skip(k)
42            .take(m)
43            .map(|v| v.coeff_of(0))
44            .collect()
45    }
46}
47
#[cfg(test)]
mod tests {
    use super::*;
    use crate::math::polynomial::*;
    use crate::math::prime_mod::Prime;
    use crate::num::const_modint::*;
    use rand::Rng;

    // Modulus of the field used by the test.
    const M: u32 = 998244353;
    type P = Prime<M>;

    /// Checks `multipoint_eval` against point-by-point `eval` on random input.
    #[test]
    fn test() {
        let builder = ConstModIntBuilder::<P>::new();
        let mut rng = rand::thread_rng();

        // Random polynomial given by 100 coefficients.
        let coeff_count = 100;
        let coeffs = (0..coeff_count)
            .map(|_| builder.from_u64(rng.gen_range(0..P::PRIME_NUM) as u64))
            .collect::<Vec<_>>();
        let poly = Polynomial::from(coeffs);

        // 100 random evaluation points.
        let point_count = 100;
        let points = (0..point_count)
            .map(|_| builder.from_u64(rng.gen_range(0..M) as u64))
            .collect::<Vec<_>>();

        // Reference values must be computed first: `multipoint_eval` consumes
        // both the polynomial and the point list.
        let expected = points.iter().map(|&x| poly.eval(x)).collect::<Vec<_>>();
        let actual = poly.multipoint_eval(points);

        assert_eq!(actual, expected);
    }
}