haar_lib/math/
multipoint_eval.rs1use crate::math::polynomial::{Polynomial, PolynomialOperator};
4use crate::num::const_modint::ConstModInt;
5
/// Evaluation of a single polynomial at many points in one call.
pub trait MultipointEval {
    /// Polynomial type taken by [`MultipointEval::multipoint_eval`].
    type Poly;
    /// Scalar type of both the evaluation points and the returned values.
    type Value;

    /// Evaluates the polynomial `a` at every point of `p`.
    ///
    /// The result has one entry per element of `p`, in the same order:
    /// entry `i` is the value of `a` at `p[i]`.
    fn multipoint_eval(&self, a: Self::Poly, p: Vec<Self::Value>) -> Vec<Self::Value>;
}
18
19impl<const P: u32, const PR: u32> MultipointEval for PolynomialOperator<'_, P, PR> {
20 type Poly = Polynomial<P>;
21 type Value = ConstModInt<P>;
22
23 fn multipoint_eval(&self, a: Self::Poly, p: Vec<Self::Value>) -> Vec<Self::Value> {
24 let m = p.len();
25
26 let mut k = 1;
27 while k < m {
28 k *= 2;
29 }
30
31 let mut f = vec![Polynomial::constant(ConstModInt::new(1)); k * 2];
32 for i in 0..m {
33 f[i + k] = Polynomial::from(vec![-p[i], ConstModInt::new(1)]);
34 }
35 for i in (1..k).rev() {
36 f[i] = self.mul(f[i << 1].clone(), f[(i << 1) | 1].clone());
37 }
38
39 f[1] = self.divmod(a, f[1].clone()).1;
40
41 for i in 2..k + m {
42 f[i] = self.divmod(f[i >> 1].clone(), f[i].clone()).1;
43 }
44
45 f.into_iter()
46 .skip(k)
47 .take(m)
48 .map(|v| v.coeff_of(0))
49 .collect()
50 }
51}
52
#[cfg(test)]
mod tests {
    use super::*;
    use crate::math::{ntt::*, polynomial::*};
    use crate::num::const_modint::*;
    use rand::Rng;

    const M: u32 = 998244353;

    /// Builds a random polynomial with `n` coefficients and `m` random points.
    /// Checks that `multipoint_eval` agrees with evaluating each point
    /// separately via `eval`.
    fn check_random(n: usize, m: usize) {
        let ff = ConstModIntBuilder::<M>;
        let ntt = NTT::<M, 3>::new();
        let po = PolynomialOperator::new(&ntt);

        let mut rng = rand::thread_rng();
        let mut random_values = |len: usize| {
            (0..len)
                .map(|_| ff.from_u64(rng.gen_range(0..M) as u64))
                .collect::<Vec<_>>()
        };

        let a = Polynomial::from(random_values(n));
        let p = random_values(m);

        let ans = p.iter().map(|p| a.eval(*p)).collect::<Vec<_>>();
        let res = po.multipoint_eval(a, p);

        assert_eq!(res, ans, "n = {}, m = {}", n, m);
    }

    #[test]
    fn test() {
        check_random(100, 100);
    }

    #[test]
    fn test_edge_sizes() {
        // No points: the result must be empty.
        check_random(10, 0);
        // A single point: the tree has one leaf, and that leaf is also the root.
        check_random(10, 1);
        // Point count that is not a power of two, with fewer points than coefficients.
        check_random(50, 7);
        // A constant polynomial evaluated at many points.
        check_random(1, 33);
    }
}