haar_lib/linalg/
const_matrix.rs

1//! 大きさがコンパイル時固定の行列
2use std::ops::{Add, Mul, Neg, Sub};
3
4use crate::algebra::semiring::*;
5pub use crate::linalg::traits::*;
6
7/// $R \times C$の行列
8#[derive(Clone, Debug, Eq, PartialEq)]
9pub struct ConstMatrix<S: Semiring, const R: usize, const C: usize> {
10    sr: S,
11    data: [[S::Element; C]; R],
12}
13
14impl<S: Semiring, const R: usize, const C: usize> ConstMatrix<S, R, C>
15where
16    S::Element: Copy,
17{
18    /// ゼロ行列を返す。
19    pub fn zero(semiring: S) -> Self {
20        let data = [[semiring.zero(); C]; R];
21        Self { sr: semiring, data }
22    }
23}
24
25impl<S: Semiring, const N: usize> ConstMatrix<S, N, N>
26where
27    S::Element: Copy,
28{
29    /// 単位行列を返す。
30    pub fn unit(semiring: S) -> Self {
31        let mut data = [[semiring.zero(); N]; N];
32        for (i, r) in data.iter_mut().enumerate() {
33            r[i] = semiring.one();
34        }
35        Self { sr: semiring, data }
36    }
37}
38
39impl<S: Semiring, const R: usize, const C: usize> Matrix for ConstMatrix<S, R, C> {
40    fn width(&self) -> usize {
41        C
42    }
43    fn height(&self) -> usize {
44        R
45    }
46}
47
48impl<S: Semiring, const R: usize, const C: usize> ConstMatrix<S, R, C> {
49    /// `i`行`j`列の要素への参照を返す。
50    pub fn get(&self, i: usize, j: usize) -> Option<&S::Element> {
51        let a = self.data.get(i)?;
52        a.get(j)
53    }
54
55    /// `i`行`j`列の要素への可変参照を返す。
56    pub fn get_mut(&mut self, i: usize, j: usize) -> Option<&mut S::Element> {
57        let a = self.data.get_mut(i)?;
58        a.get_mut(j)
59    }
60}
61
62impl<S: Semiring, const R: usize, const C: usize> MatrixTranspose for ConstMatrix<S, R, C>
63where
64    S::Element: Copy,
65{
66    type Output = ConstMatrix<S, C, R>;
67    fn transpose(self) -> Self::Output {
68        let mut ret = ConstMatrix::<S, C, R>::zero(self.sr);
69        for i in 0..R {
70            for j in 0..C {
71                ret.data[j][i] = self.data[i][j];
72            }
73        }
74        ret
75    }
76}
77
78impl<S: Semiring, const R: usize, const C: usize> Add for ConstMatrix<S, R, C>
79where
80    S::Element: Copy,
81{
82    type Output = Self;
83    fn add(mut self, other: Self) -> Self {
84        for (a, b) in self.data.iter_mut().zip(other.data) {
85            for (x, y) in a.iter_mut().zip(b) {
86                *x = self.sr.add(*x, y);
87            }
88        }
89        self
90    }
91}
92
93impl<S: Ring, const R: usize, const C: usize> Sub for ConstMatrix<S, R, C>
94where
95    S::Element: Copy,
96{
97    type Output = Self;
98    fn sub(mut self, other: Self) -> Self {
99        for (a, b) in self.data.iter_mut().zip(other.data) {
100            for (x, y) in a.iter_mut().zip(b) {
101                *x = self.sr.sub(*x, y);
102            }
103        }
104        self
105    }
106}
107
108impl<S: Ring, const R: usize, const C: usize> Neg for ConstMatrix<S, R, C>
109where
110    S::Element: Copy,
111{
112    type Output = Self;
113    fn neg(mut self) -> Self {
114        for a in self.data.iter_mut() {
115            for x in a.iter_mut() {
116                *x = self.sr.neg(*x);
117            }
118        }
119        self
120    }
121}
122
123impl<S: Semiring + Clone, const R: usize, const C: usize, const C2: usize>
124    Mul<ConstMatrix<S, C, C2>> for ConstMatrix<S, R, C>
125where
126    S::Element: Copy,
127{
128    type Output = ConstMatrix<S, R, C2>;
129    fn mul(self, other: ConstMatrix<S, C, C2>) -> Self::Output {
130        let b = other.transpose();
131        let mut ret = ConstMatrix::zero(self.sr.clone());
132
133        for (r, r2) in ret.data.iter_mut().zip(self.data) {
134            for (x, c) in r.iter_mut().zip(b.data.iter()) {
135                for (y, z) in r2.iter().zip(c.iter()) {
136                    *x = self.sr.add(*x, self.sr.mul(*y, *z));
137                }
138            }
139        }
140
141        ret
142    }
143}
144
#[cfg(test)]
mod tests {
    use crate::algebra::semiring::add_mul::AddMul;

    use super::*;

    /// Builds a matrix over `AddMul<u32>` whose entries are copied from `rows`.
    fn from_rows<const R: usize, const C: usize>(
        rows: [[u32; C]; R],
    ) -> ConstMatrix<AddMul<u32>, R, C> {
        let mut m = ConstMatrix::zero(AddMul::<u32>::new());
        for (i, row) in rows.iter().enumerate() {
            for (j, &v) in row.iter().enumerate() {
                *m.get_mut(i, j).unwrap() = v;
            }
        }
        m
    }

    /// Reads every entry of `m` back into a plain nested array for comparison.
    fn to_rows<const R: usize, const C: usize>(
        m: &ConstMatrix<AddMul<u32>, R, C>,
    ) -> [[u32; C]; R] {
        let mut out = [[0; C]; R];
        for (i, row) in out.iter_mut().enumerate() {
            for (j, v) in row.iter_mut().enumerate() {
                *v = *m.get(i, j).unwrap();
            }
        }
        out
    }

    #[test]
    fn unit_times_zero_is_zero() {
        let r = AddMul::<u32>::new();
        let a = ConstMatrix::<_, 5, 5>::unit(r);
        let b = ConstMatrix::<_, 5, 3>::zero(r);
        assert_eq!(to_rows(&(a * b)), [[0; 3]; 5]);
    }

    #[test]
    fn unit_is_left_identity() {
        let expected = [[1, 2, 3], [4, 5, 6]];
        let m = from_rows(expected);
        let id = ConstMatrix::<_, 2, 2>::unit(AddMul::<u32>::new());
        assert_eq!(to_rows(&(id * m)), expected);
    }

    #[test]
    fn rectangular_product() {
        let a = from_rows([[1, 2, 3], [4, 5, 6]]);
        let b = from_rows([[7, 8], [9, 10], [11, 12]]);
        assert_eq!(to_rows(&(a * b)), [[58, 64], [139, 154]]);
    }

    #[test]
    fn transpose_swaps_indices() {
        let a = from_rows([[1, 2, 3], [4, 5, 6]]);
        assert_eq!(to_rows(&a.transpose()), [[1, 4], [2, 5], [3, 6]]);
    }

    #[test]
    fn get_out_of_range_is_none() {
        let a = ConstMatrix::<_, 2, 3>::zero(AddMul::<u32>::new());
        assert_eq!(a.get(1, 2), Some(&0));
        assert_eq!(a.get(2, 0), None);
        assert_eq!(a.get(0, 3), None);
    }
}