haar_lib/math/
multipoint_eval.rs

1//! 多項式の多点評価
2
3use crate::math::polynomial::{Polynomial, PolynomialOperator};
4use crate::num::const_modint::ConstModInt;
5
/// Multipoint evaluation of a polynomial.
///
/// Implementors evaluate one polynomial at many points in a single call.
pub trait MultipointEval {
    /// The polynomial type accepted by [`MultipointEval::multipoint_eval`].
    type Poly;
    /// The type of the evaluation points and of the evaluated results.
    type Value;

    /// Evaluates the polynomial `a` at every point in `p`.
    ///
    /// For a polynomial $f(x)$ and points $p_0, p_1, \cdots, p_{m-1}$, the result is
    /// $f(p_0), f(p_1), \cdots, f(p_{m-1})$, listed in the same order as the points.
    fn multipoint_eval(&self, a: Self::Poly, p: Vec<Self::Value>) -> Vec<Self::Value>;
}
18
19impl<const P: u32, const PR: u32> MultipointEval for PolynomialOperator<'_, P, PR> {
20    type Poly = Polynomial<P>;
21    type Value = ConstModInt<P>;
22
23    fn multipoint_eval(&self, a: Self::Poly, p: Vec<Self::Value>) -> Vec<Self::Value> {
24        let m = p.len();
25
26        let k = m.next_power_of_two();
27
28        let mut f = vec![Polynomial::constant(1.into()); k * 2];
29        for i in 0..m {
30            f[i + k] = vec![-p[i], 1.into()].into();
31        }
32        for i in (1..k).rev() {
33            f[i] = self.mul(f[i << 1].clone(), f[(i << 1) | 1].clone());
34        }
35
36        f[1] = self.rem(a, f[1].clone());
37
38        for i in 2..k + m {
39            f[i] = self.rem(f[i >> 1].clone(), std::mem::take(&mut f[i]));
40        }
41
42        f.into_iter()
43            .skip(k)
44            .take(m)
45            .map(|v| v.coeff_of(0))
46            .collect()
47    }
48}
49
#[cfg(test)]
mod tests {
    use super::*;
    use crate::math::{ntt::*, polynomial::*};
    use crate::num::const_modint::*;
    use rand::Rng;

    /// Compares `multipoint_eval` with point-by-point `eval` on a random
    /// polynomial with 100 coefficients and 100 random points.
    #[test]
    fn test() {
        const M: u32 = 998244353;

        let ff = ConstModIntBuilder::<M>;
        let ntt = NTT::<M, 3>::new();
        let po = PolynomialOperator::new(&ntt);

        let mut rng = rand::thread_rng();
        // Draws `len` uniformly random residues in `0..M`.
        let mut random_values = |len: usize| -> Vec<_> {
            (0..len)
                .map(|_| ff.from_u64(u64::from(rng.gen_range(0..M))))
                .collect()
        };

        // Coefficients are drawn first, then the evaluation points.
        let a = Polynomial::from(random_values(100));
        let p = random_values(100);

        // Reference answer: evaluate one point at a time.
        let ans: Vec<_> = p.iter().map(|&x| a.eval(x)).collect();
        let res = po.multipoint_eval(a, p);

        assert_eq!(res, ans);
    }
}
83}