haar_lib/linalg/mod_p/
inverse.rs

//! Matrix inverse modulo a prime (逆行列 (mod 素数)).
2use crate::num::{ff::FFElem, one_zero::*};
3
4/// 素数mod p上での逆行列を求める。
5///
6/// **Time complexity** $O(n^3)$
7pub fn inverse<T>(mut b: Vec<Vec<T>>) -> Option<Vec<Vec<T>>>
8where
9    T: FFElem + Copy + Zero + One,
10{
11    let n = b.len();
12
13    assert!(b.iter().all(|r| r.len() == n));
14
15    for (i, bi) in b.iter_mut().enumerate() {
16        bi.resize(2 * n, T::zero());
17        bi[i + n] = T::one();
18    }
19
20    for i in 0..n {
21        let q = (i..n).find(|&j| b[j][i].value() != 0)?;
22
23        b.swap(i, q);
24
25        let d = b[i][i].inv();
26
27        for x in b[i].iter_mut() {
28            *x *= d;
29        }
30
31        let d = b[i][i].inv();
32
33        let bi = b.swap_remove(i);
34
35        for bj in b.iter_mut() {
36            let d = bj[i] * d;
37
38            for (x, y) in bj.iter_mut().zip(bi.iter()) {
39                *x -= *y * d;
40            }
41        }
42
43        b.push(bi);
44        b.swap(i, n - 1);
45    }
46
47    let ret = b.into_iter().map(|a| a[n..].to_vec()).collect();
48
49    Some(ret)
50}
51
#[cfg(test)]
mod tests {
    use super::*;
    use crate::num::const_modint::*;

    /// Converts every entry of a matrix from `T` to `U` via `From`.
    fn convert<U, T>(a: Vec<Vec<T>>) -> Vec<Vec<U>>
    where
        U: From<T>,
    {
        let mut out = Vec::with_capacity(a.len());
        for row in a {
            out.push(row.into_iter().map(U::from).collect());
        }
        out
    }

    #[test]
    fn test() {
        const P: u32 = 998244353;

        // Each case: (input matrix, expected inverse as plain integers).
        let cases = vec![
            // A regular 3x3 matrix.
            (
                vec![vec![3, 1, 4], vec![1, 5, 9], vec![2, 6, 5]],
                Some(vec![
                    vec![188557267, 255106890, 587855008],
                    vec![122007643, 987152749, 321656514],
                    vec![576763404, 310564910, 976061145],
                ]),
            ),
            // Singular (rank 2): no inverse exists.
            (vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]], None),
            // Zero on the diagonal forces a row swap; a transposition
            // permutation matrix is its own inverse.
            (
                vec![vec![0, 1], vec![1, 0]],
                Some(vec![vec![0, 1], vec![1, 0]]),
            ),
        ];

        for (input, expected) in cases {
            let modular = convert::<ConstModInt<P>, _>(input);
            let actual = inverse(modular).map(convert::<u32, _>);
            assert_eq!(actual, expected);
        }
    }
}