haar_lib/math/
bostan_mori.rs

1//! 線形漸化式で表される数列の第`k`項目を求める。
2use crate::math::ntt::NTT;
3use crate::num::const_modint::ConstModInt;
4
5/// $a_i = \sum_{j = 1}^d c_j a_{i-j}$を満たす数列$a$の初め$d$項と係数$c$から、数列の第`k`項$a_k$を求める。
6pub fn bostan_mori<const P: u32, const PR: u32>(
7    a: Vec<ConstModInt<P>>,
8    c: Vec<ConstModInt<P>>,
9    mut k: u64,
10    ntt: &NTT<P, PR>,
11) -> ConstModInt<P> {
12    assert_eq!(a.len(), c.len());
13
14    let d = a.len();
15
16    let mut q: Vec<ConstModInt<P>> = vec![0.into(); d + 1];
17    q[0] = 1.into();
18    for i in 0..d {
19        q[i + 1] = -c[i];
20    }
21
22    let mut p = ntt.convolve(a, q.clone());
23    p.truncate(d);
24
25    while k > 0 {
26        let mut q1 = q.clone();
27        for i in (1..q1.len()).step_by(2) {
28            q1[i] = -q1[i];
29        }
30
31        let size = (2 * d + 1).next_power_of_two();
32        let mut u = p.clone();
33        u.resize(size, 0.into());
34        ntt.ntt(&mut u);
35
36        q1.resize(size, 0.into());
37        ntt.ntt(&mut q1);
38
39        u.iter_mut().zip(q1.iter()).for_each(|(x, y)| *x *= *y);
40        ntt.intt(&mut u);
41
42        let mut a = q.clone();
43        a.resize(size, 0.into());
44        ntt.ntt(&mut a);
45
46        a.iter_mut().zip(q1).for_each(|(x, y)| *x *= y);
47        ntt.intt(&mut a);
48
49        if k % 2 == 0 {
50            for i in 0..d {
51                p[i] = u[i * 2];
52            }
53        } else {
54            for i in 0..d {
55                p[i] = u[i * 2 + 1];
56            }
57        }
58
59        for i in 0..=d {
60            q[i] = a[i * 2];
61        }
62
63        k >>= 1;
64    }
65
66    p[0]
67}