haar_lib/math/
bostan_mori.rs

1//! 線形漸化式で表される数列の第`k`項目を求める。
2use crate::math::convolution::ntt::NTT;
3use crate::math::prime_mod::PrimeMod;
4use crate::num::const_modint::ConstModInt;
5
6/// $a_i = \sum_{j = 1}^d c_j a_{i-j}$を満たす数列$a$の初め$d$項と係数$c$から、数列の第`k`項$a_k$を求める。
7pub fn bostan_mori<P: PrimeMod>(
8    a: Vec<ConstModInt<P>>,
9    c: Vec<ConstModInt<P>>,
10    mut k: u64,
11) -> ConstModInt<P> {
12    assert_eq!(a.len(), c.len());
13
14    let ntt = NTT::<P>::new();
15    let d = a.len();
16
17    let mut q: Vec<ConstModInt<P>> = vec![0.into(); d + 1];
18    q[0] = 1.into();
19    for i in 0..d {
20        q[i + 1] = -c[i];
21    }
22
23    let mut p = ntt.convolve(a, q.clone());
24    p.truncate(d);
25
26    while k > 0 {
27        let mut q1 = q.clone();
28        for i in (1..q1.len()).step_by(2) {
29            q1[i] = -q1[i];
30        }
31
32        let size = (2 * d + 1).next_power_of_two();
33        let mut u = p.clone();
34        u.resize(size, 0.into());
35        ntt.ntt(&mut u);
36
37        q1.resize(size, 0.into());
38        ntt.ntt(&mut q1);
39
40        u.iter_mut().zip(q1.iter()).for_each(|(x, y)| *x *= *y);
41        ntt.intt(&mut u);
42
43        let mut a = q.clone();
44        a.resize(size, 0.into());
45        ntt.ntt(&mut a);
46
47        a.iter_mut().zip(q1).for_each(|(x, y)| *x *= y);
48        ntt.intt(&mut a);
49
50        if k % 2 == 0 {
51            for i in 0..d {
52                p[i] = u[i * 2];
53            }
54        } else {
55            for i in 0..d {
56                p[i] = u[i * 2 + 1];
57            }
58        }
59
60        for i in 0..=d {
61            q[i] = a[i * 2];
62        }
63
64        k >>= 1;
65    }
66
67    p[0]
68}