haar_lib/linalg/
semiring.rs1use std::ops::Index;
6
7use crate::algebra::semiring::Semiring;
8
9#[derive(Clone, Debug, PartialEq, Eq)]
11pub struct SemiringMatrix<T> {
12 data: Vec<Vec<T>>,
13 h: usize,
14 w: usize,
15}
16
17impl<T: Semiring + Copy> SemiringMatrix<T> {
18 pub fn zero(h: usize, w: usize) -> Self {
20 let data = vec![vec![T::zero(); w]; h];
21 Self { data, h, w }
22 }
23
24 pub fn unit(n: usize) -> Self {
26 let mut this = Self::zero(n, n);
27 for i in 0..n {
28 this.data[i][i] = T::one();
29 }
30 this
31 }
32
33 pub fn transpose(self) -> Self {
35 let a = self;
36 let mut ret = Self::zero(a.w, a.h);
37 for i in 0..a.h {
38 for j in 0..a.w {
39 ret.data[j][i] = a.data[i][j];
40 }
41 }
42 ret
43 }
44
45 pub fn try_mul(self, b: Self) -> Option<Self> {
49 let a = self;
50 if a.w != b.h {
51 return None;
52 }
53
54 let n = a.h;
55 let l = b.w;
56 let b = b.transpose();
57 let mut ret = Self::zero(n, l);
58
59 for (r, r2) in ret.data.iter_mut().zip(a.data.iter()) {
60 for (x, c) in r.iter_mut().zip(b.data.iter()) {
61 for (y, z) in r2.iter().zip(c.iter()) {
62 *x = T::add(*x, T::mul(*y, *z));
63 }
64 }
65 }
66
67 Some(ret)
68 }
69
70 pub fn pow(self, mut n: u64) -> Self {
72 let mut a = self;
73 assert_eq!(a.h, a.w);
74
75 let mut ret = Self::unit(a.h);
76
77 while n > 0 {
78 if n % 2 == 1 {
79 ret = ret.try_mul(a.clone()).unwrap();
80 }
81 a = a.clone().try_mul(a).unwrap();
82 n >>= 1;
83 }
84
85 ret
86 }
87
88 pub fn get_mut(&mut self, i: usize, j: usize) -> Option<&mut T> {
90 let a = self.data.get_mut(i)?;
91 a.get_mut(j)
92 }
93
94 pub fn get(&self, i: usize, j: usize) -> Option<&T> {
96 let a = self.data.get(i)?;
97 a.get(j)
98 }
99}
100
101impl<T> Index<usize> for SemiringMatrix<T> {
102 type Output = [T];
103 fn index(&self, i: usize) -> &Self::Output {
104 &self.data[i]
105 }
106}