haar_lib/math/
polynomial_interpolation.rs

1//! 多項式補間
2use crate::math::multipoint_eval::MultipointEval;
3use crate::math::ntt::*;
4use crate::math::polynomial::{Polynomial, PolynomialOperator};
5use crate::num::const_modint::ConstModInt;
6
7/// $y_0 = f(x_0), \dots, y_{n-1} = f(x_{n-1})$を満たす多項式$f(x) = c_0 x^0 + c_1 x^1 + \dots + c_{n-1} x^{n-1}$を求める。
8pub fn polynomial_interpolation<const P: u32, const PR: u32>(
9    xs: Vec<impl Into<ConstModInt<P>>>,
10    ys: Vec<impl Into<ConstModInt<P>>>,
11    ntt: &NTT<P, PR>,
12) -> Polynomial<P> {
13    assert_eq!(xs.len(), ys.len());
14    let n = xs.len();
15    let xs = xs.into_iter().map(Into::into).collect::<Vec<_>>();
16    let ys = ys.into_iter().map(Into::into).collect::<Vec<_>>();
17
18    let po = PolynomialOperator::new(ntt);
19
20    let g = rec_g(0, n, &xs, ntt);
21
22    let mut gd = g.clone();
23    gd.differentiate();
24    let gd = po.multipoint_eval(gd, xs.clone());
25
26    let (a, b) = rec_frac(0, n, &xs, &ys, &gd, ntt);
27
28    let t = po.mul(a, g);
29    po.div(t, b)
30}
31
32fn rec_g<const P: u32, const PR: u32>(
33    l: usize,
34    r: usize,
35    xs: &[ConstModInt<P>],
36    ntt: &NTT<P, PR>,
37) -> Polynomial<P> {
38    if r - l == 1 {
39        return vec![-xs[l], 1.into()].into();
40    }
41
42    let po = PolynomialOperator::new(ntt);
43    let m = (l + r) / 2;
44    po.mul(rec_g(l, m, xs, ntt), rec_g(m, r, xs, ntt))
45}
46
47fn rec_frac<const P: u32, const PR: u32>(
48    l: usize,
49    r: usize,
50    xs: &[ConstModInt<P>],
51    ys: &[ConstModInt<P>],
52    gs: &[ConstModInt<P>],
53    ntt: &NTT<P, PR>,
54) -> (Polynomial<P>, Polynomial<P>) {
55    if r - l == 1 {
56        return (vec![ys[l]].into(), vec![-xs[l] * gs[l], gs[l]].into());
57    }
58
59    let m = (l + r) / 2;
60
61    let (la, lb) = rec_frac(l, m, xs, ys, gs, ntt);
62    let (ra, rb) = rec_frac(m, r, xs, ys, gs, ntt);
63
64    let po = PolynomialOperator::new(ntt);
65
66    let deno = po.mul(lb.clone(), rb.clone());
67    let nume = po.mul(la, rb) + po.mul(ra, lb);
68
69    (nume, deno)
70}