haar_lib/math/polynomial/
shift_sampling_points.rs

//! Shift of sampling points of a polynomial (多項式の標本点シフト)
//!
//! # References
//! - <https://suisen-cp.github.io/cp-library-cpp/library/polynomial/shift_of_sampling_points.hpp.html>
//!
//! # Problems
//! - <https://judge.yosupo.jp/problem/shift_of_sampling_points_of_polynomial>
8
9use crate::{
10    math::convolution::ntt::NTT,
11    math::factorial::FactorialTable,
12    math::prime_mod::PrimeMod,
13    num::const_modint::{ConstModInt, ConstModIntBuilder},
14};
15
16/// $N$次未満の多項式$f(x)$について、$f(0), f(1), \dots, f(N-1)$から$f(c), f(c + 1), \dots, f(c + M - 1)$を求める。
17pub fn shift_sampling_points<P: PrimeMod>(
18    f: Vec<impl Into<ConstModInt<P>>>,
19    c: u32,
20    m: usize,
21) -> Vec<ConstModInt<P>> {
22    let f = f.into_iter().map(Into::into).collect::<Vec<_>>();
23
24    let n = f.len();
25    let ntt = NTT::<P>::new();
26    let ft = FactorialTable::new(n.max(m), ConstModIntBuilder::<P>::new());
27
28    let a = {
29        let f = f
30            .into_iter()
31            .enumerate()
32            .map(|(i, x)| x * ft.inv_facto(i))
33            .collect();
34        let g = (0..n)
35            .map(|i| {
36                if i % 2 == 0 {
37                    ft.inv_facto(i)
38                } else {
39                    -ft.inv_facto(i)
40                }
41            })
42            .collect();
43        ntt.convolve(f, g)[..n].to_vec()
44    };
45
46    let b = {
47        let f = a
48            .into_iter()
49            .enumerate()
50            .rev()
51            .map(|(i, x)| x * ft.facto(i))
52            .collect();
53        let mut p = ConstModInt::new(1);
54        let g = (0..n)
55            .map(|i| {
56                let ret = p * ft.inv_facto(i);
57                p *= (c as i64 - i as i64).into();
58                ret
59            })
60            .collect();
61        ntt.convolve(f, g)[..n].to_vec()
62    };
63
64    let mut ret = {
65        let f = b
66            .into_iter()
67            .rev()
68            .enumerate()
69            .map(|(i, x)| x * ft.inv_facto(i))
70            .collect();
71        let g = (0..m).map(|i| ft.inv_facto(i)).collect();
72        ntt.convolve(f, g)[..m].to_vec()
73    };
74
75    ret.iter_mut()
76        .enumerate()
77        .for_each(|(i, x)| *x *= ft.facto(i));
78    ret
79}