haar_lib/math/
multipoint_eval.rs

//! 多項式の多点評価
2
3use crate::math::polynomial::{Polynomial, PolynomialOperator};
4use crate::num::const_modint::ConstModInt;
5
/// Multipoint evaluation of a polynomial.
pub trait MultipointEval {
    /// The polynomial type being evaluated.
    type Poly;
    /// The coefficient type, which is also the type of the evaluation points and results.
    type Value;

    /// Evaluates one polynomial at many points.
    ///
    /// Given the polynomial $f(x)$ (`a`) and the points $p_0, p_1, \ldots, p_{m-1}$ (`p`),
    /// returns $f(p_0), f(p_1), \ldots, f(p_{m-1})$ in the same order as `p`.
    fn multipoint_eval(&self, a: Self::Poly, p: Vec<Self::Value>) -> Vec<Self::Value>;
}
18
19impl<const P: u32, const PR: u32> MultipointEval for PolynomialOperator<'_, P, PR> {
20    type Poly = Polynomial<P>;
21    type Value = ConstModInt<P>;
22
23    fn multipoint_eval(&self, a: Self::Poly, p: Vec<Self::Value>) -> Vec<Self::Value> {
24        let m = p.len();
25
26        let mut k = 1;
27        while k < m {
28            k *= 2;
29        }
30
31        let mut f = vec![Polynomial::constant(ConstModInt::new(1)); k * 2];
32        for i in 0..m {
33            f[i + k] = Polynomial::from(vec![-p[i], ConstModInt::new(1)]);
34        }
35        for i in (1..k).rev() {
36            f[i] = self.mul(f[i << 1].clone(), f[(i << 1) | 1].clone());
37        }
38
39        f[1] = self.divmod(a, f[1].clone()).1;
40
41        for i in 2..k + m {
42            f[i] = self.divmod(f[i >> 1].clone(), f[i].clone()).1;
43        }
44
45        f.into_iter()
46            .skip(k)
47            .take(m)
48            .map(|v| v.coeff_of(0))
49            .collect()
50    }
51}
52
#[cfg(test)]
mod tests {
    use super::*;
    use crate::math::{ntt::*, polynomial::*};
    use crate::num::const_modint::*;
    use rand::Rng;

    /// Compares `multipoint_eval` with naive per-point `Polynomial::eval`.
    ///
    /// The cases cover several (polynomial length `n`, point count `m`) pairs: a single
    /// point, a constant polynomial, more points than coefficients and vice versa, and point
    /// counts that are not powers of two, so the padding leaves of the tree are exercised.
    #[test]
    fn test() {
        const M: u32 = 998244353;

        let ff = ConstModIntBuilder::<M>;
        let ntt = NTT::<M, 3>::new();
        let po = PolynomialOperator::new(&ntt);

        let mut rng = rand::thread_rng();

        let cases = [
            (100, 100),
            (1, 1),
            (1, 7),
            (10, 1),
            (37, 3),
            (3, 37),
            (64, 65),
        ];

        for &(n, m) in cases.iter() {
            let a = (0..n)
                .map(|_| ff.from_u64(rng.gen_range(0..M) as u64))
                .collect::<Vec<_>>();
            let a = Polynomial::from(a);

            let p = (0..m)
                .map(|_| ff.from_u64(rng.gen_range(0..M) as u64))
                .collect::<Vec<_>>();

            // Expected values must be computed before `a` and `p` are moved below.
            let ans = p.iter().map(|p| a.eval(*p)).collect::<Vec<_>>();
            let res = po.multipoint_eval(a, p);

            assert_eq!(res, ans, "n = {}, m = {}", n, m);
        }
    }
}