haar_lib/math/
berlekamp_massey.rs

//! 線形漸化式を求める。
//! (Finds the shortest linear recurrence of a sequence with the Berlekamp–Massey algorithm.)
//!
//! # Problems
//! - <https://judge.yosupo.jp/problem/find_linear_recurrence>
5use crate::num::ff::*;
6use std::iter::zip;
7use std::ops::Add;
8
9/// $N$項の数列$a_0, a_1, \ldots, a_{N-1}$から、
10/// 最短の線形漸化式$a_i = c_1 a_{i-1} + c_2 a_{i-2} + \dots + c_d a_{i-d}$の係数$c_i$を求める。
11pub fn berlekamp_massey<Modulo: FF>(
12    mut s: Vec<Modulo::Element>,
13    modulo: Modulo,
14) -> Vec<Modulo::Element>
15where
16    Modulo::Element: FFElem + Copy,
17{
18    let zero = modulo.from_u64(0);
19    let one = modulo.from_u64(1);
20    let len = s.len();
21    let mut c = vec![one];
22    let mut p = vec![one];
23    let mut l = 0;
24    let mut m = 1;
25    let mut b = one;
26
27    c.reserve(len);
28    s.reverse();
29
30    for n in 0..len {
31        let d = s[len - n - 1]
32            + zip(c.iter().skip(1), s.iter().skip(len - n))
33                .map(|(&c, &s)| c * s)
34                .fold(zero, Add::add);
35
36        if d == zero {
37            m += 1;
38        } else if 2 * l <= n {
39            let temp = c.clone();
40            if c.len() < p.len() + m {
41                c.resize(p.len() + m, zero);
42            }
43            let t = d / b;
44
45            for (c, p) in c.iter_mut().skip(m).zip(p.iter()) {
46                *c -= t * *p;
47            }
48
49            l = n + 1 - l;
50            p = temp;
51            b = d;
52            m = 1;
53        } else {
54            if c.len() < p.len() + m {
55                c.resize(p.len() + m, zero);
56            }
57            let t = d / b;
58
59            for (c, p) in c.iter_mut().skip(m).zip(p.iter()) {
60                *c -= t * *p;
61            }
62
63            m += 1;
64        }
65    }
66
67    c.into_iter().skip(1).take(l).map(|x| -x).collect()
68}
69
#[cfg(test)]
mod tests {
    use std::ops::{Add, Mul};

    use crate::{
        iter::collect::CollectVec,
        num::{const_modint::ConstModIntBuilder, one_zero::Zero},
    };

    use super::*;

    /// Extends `prefix` (of length $d$) by $d$ more terms using
    /// $a_i = \sum_{j=1}^{d} c_j a_{i-j}$ with `coeffs` = $[c_1, \ldots, c_d]$.
    /// The result has the $2d$ terms needed to recover an order-$d$ recurrence.
    fn generate<T>(prefix: &[T], coeffs: &[T]) -> Vec<T>
    where
        T: Copy + Add<Output = T> + Mul<Output = T> + Zero,
    {
        assert_eq!(prefix.len(), coeffs.len());
        let n = prefix.len();

        let mut ret = prefix.to_vec();

        for _ in 0..n {
            let a = ret
                .iter()
                .rev()
                .zip(coeffs.iter())
                .map(|(&a, &c)| a * c)
                .fold(T::zero(), std::ops::Add::add);

            ret.push(a);
        }

        ret
    }

    /// Runs `berlekamp_massey` on `seq` (mod 998244353) and asserts that the
    /// returned coefficients equal `expected`.
    fn check(seq: &[u64], expected: &[u64]) {
        let ff = ConstModIntBuilder::<998244353>;
        let seq = seq.iter().map(|&x| ff.from_u64(x)).collect_vec();
        let expected = expected.iter().map(|&x| ff.from_u64(x)).collect_vec();
        assert_eq!(berlekamp_massey(seq, ff), expected);
    }

    #[test]
    fn test() {
        let ff = ConstModIntBuilder::<998244353>;

        let a = vec![1, 2, 3, 4, 5];
        let a = a.into_iter().map(|x| ff.from_u64(x)).collect_vec();
        let c = vec![2, 3, 2, 8, 5];
        let c = c.into_iter().map(|x| ff.from_u64(x)).collect_vec();

        let a = generate(&a, &c);
        let res = berlekamp_massey(a, ff);

        assert_eq!(res, c);
    }

    #[test]
    fn test_edge_cases() {
        // Empty input: the empty recurrence.
        check(&[], &[]);
        // All-zero sequence: satisfied by the empty recurrence.
        check(&[0, 0, 0, 0], &[]);
        // Geometric sequence: a_i = 2 a_{i-1}.
        check(&[1, 2, 4, 8, 16], &[2]);
        // Fibonacci: the minimal order (2) is much shorter than the input.
        check(&[0, 1, 1, 2, 3, 5, 8, 13], &[1, 1]);
    }
}