haar_lib/math/polynomial/polynomial_interpolation.rs

1//! 多項式補間
2use crate::math::polynomial::{multipoint_eval::MultipointEval, Polynomial};
3use crate::math::prime_mod::PrimeMod;
4use crate::num::const_modint::ConstModInt;
5
6/// $y_0 = f(x_0), \dots, y_{n-1} = f(x_{n-1})$を満たす多項式$f(x) = c_0 x^0 + c_1 x^1 + \dots + c_{n-1} x^{n-1}$を求める。
7pub fn polynomial_interpolation<P: PrimeMod>(
8    xs: Vec<impl Into<ConstModInt<P>>>,
9    ys: Vec<impl Into<ConstModInt<P>>>,
10) -> Polynomial<P> {
11    assert_eq!(xs.len(), ys.len());
12
13    let n = xs.len();
14    let xs = xs.into_iter().map(Into::into).collect::<Vec<_>>();
15    let ys = ys.into_iter().map(Into::into).collect::<Vec<_>>();
16
17    let g = rec_g(0, n, &xs);
18
19    let mut gd = g.clone();
20    gd.differentiate();
21    let gd = gd.multipoint_eval(xs.clone());
22
23    let (a, b) = rec_frac(0, n, &xs, &ys, &gd);
24
25    let t = a * g;
26    t / b
27}
28
29fn rec_g<P: PrimeMod>(l: usize, r: usize, xs: &[ConstModInt<P>]) -> Polynomial<P> {
30    if r - l == 1 {
31        return vec![-xs[l], 1.into()].into();
32    }
33
34    let m = (l + r) / 2;
35    rec_g(l, m, xs) * rec_g(m, r, xs)
36}
37
38fn rec_frac<P: PrimeMod>(
39    l: usize,
40    r: usize,
41    xs: &[ConstModInt<P>],
42    ys: &[ConstModInt<P>],
43    gs: &[ConstModInt<P>],
44) -> (Polynomial<P>, Polynomial<P>) {
45    if r - l == 1 {
46        return (vec![ys[l]].into(), vec![-xs[l] * gs[l], gs[l]].into());
47    }
48
49    let m = (l + r) / 2;
50
51    let (la, lb) = rec_frac(l, m, xs, ys, gs);
52    let (ra, rb) = rec_frac(m, r, xs, ys, gs);
53
54    let deno = lb.clone() * rb.clone();
55    let nume = la * rb + ra * lb;
56
57    (nume, deno)
58}
59
#[cfg(test)]
mod tests {
    use super::*;
    use crate::math::prime_mod::Prime;

    type P = Prime<998244353>;

    #[test]
    fn test() {
        // Samples of f(x) = 1 + 2x + 3x^2 + 4x^3 at x = 5, 6, 7, 8, 9.
        // Five points give an interpolant with five coefficients, the last of which is 0.
        let xs: Vec<i32> = (5..=9).collect();
        let ys = vec![586, 985, 1534, 2257, 3178];
        let expected = vec![1, 2, 3, 4, 0];

        let interpolated = polynomial_interpolation::<P>(xs, ys);

        assert_eq!(interpolated, Polynomial::from(expected));
    }
}