haar_lib/math/
multipoint_eval.rs1use crate::math::polynomial::{Polynomial, PolynomialOperator};
4use crate::num::const_modint::ConstModInt;
5
/// Evaluation of one polynomial at many points at once.
pub trait MultipointEval {
    /// Polynomial type that is evaluated.
    type Poly;
    /// Scalar type of both the evaluation points and the results.
    type Value;

    /// Evaluates `a` at every point of `p`.
    ///
    /// The returned vector has the same length and order as `p`:
    /// entry `i` is `a(p[i])`.
    fn multipoint_eval(&self, a: Self::Poly, p: Vec<Self::Value>) -> Vec<Self::Value>;
}
18
19impl<const P: u32, const PR: u32> MultipointEval for PolynomialOperator<'_, P, PR> {
20 type Poly = Polynomial<P>;
21 type Value = ConstModInt<P>;
22
23 fn multipoint_eval(&self, a: Self::Poly, p: Vec<Self::Value>) -> Vec<Self::Value> {
24 let m = p.len();
25
26 let k = m.next_power_of_two();
27
28 let mut f = vec![Polynomial::constant(1.into()); k * 2];
29 for i in 0..m {
30 f[i + k] = vec![-p[i], 1.into()].into();
31 }
32 for i in (1..k).rev() {
33 f[i] = self.mul(f[i << 1].clone(), f[(i << 1) | 1].clone());
34 }
35
36 f[1] = self.rem(a, f[1].clone());
37
38 for i in 2..k + m {
39 f[i] = self.rem(f[i >> 1].clone(), std::mem::take(&mut f[i]));
40 }
41
42 f.into_iter()
43 .skip(k)
44 .take(m)
45 .map(|v| v.coeff_of(0))
46 .collect()
47 }
48}
49
#[cfg(test)]
mod tests {
    use super::*;
    use crate::math::{ntt::*, polynomial::*};
    use crate::num::const_modint::*;
    use rand::Rng;

    #[test]
    fn test() {
        const M: u32 = 998244353;

        let ff = ConstModIntBuilder::<M>;
        let ntt = NTT::<M, 3>::new();
        let po = PolynomialOperator::new(&ntt);

        let mut rng = rand::thread_rng();

        // Compares `multipoint_eval` against naive per-point evaluation, for a random
        // polynomial with `n` coefficients and `m` random evaluation points.
        let mut check = |n: usize, m: usize| {
            let a = (0..n)
                .map(|_| ff.from_u64(rng.gen_range(0..M) as u64))
                .collect::<Vec<_>>();
            let a = Polynomial::from(a);

            let p = (0..m)
                .map(|_| ff.from_u64(rng.gen_range(0..M) as u64))
                .collect::<Vec<_>>();

            let ans = p.iter().map(|p| a.eval(*p)).collect::<Vec<_>>();
            let res = po.multipoint_eval(a, p);

            assert_eq!(res, ans, "n = {}, m = {}", n, m);
        };

        check(100, 100); // m is not a power of two, so padding leaves are exercised
        check(10, 3); // more coefficients than points
        check(3, 10); // fewer coefficients than points
        check(1, 1); // single point: the root is also the only leaf
        check(7, 0); // no evaluation points

        // Known values: 1 + 2x + 3x^2 at x = 0, 1, 2 is 1, 6, 17.
        let a = Polynomial::from(vec![ff.from_u64(1), ff.from_u64(2), ff.from_u64(3)]);
        let p = vec![ff.from_u64(0), ff.from_u64(1), ff.from_u64(2)];
        let expected = vec![ff.from_u64(1), ff.from_u64(6), ff.from_u64(17)];
        assert_eq!(po.multipoint_eval(a, p), expected);
    }
}
83}