haar_lib/linalg/
const_matrix.rs

//! 大きさがコンパイル時固定の行列 — matrices whose dimensions are fixed at compile time.
2use std::ops::{Add, Mul, Neg, Sub};
3
4use crate::algebra::{prod::*, sum::*};
5use crate::impl_algebra;
6use crate::num::one_zero::{One, Zero};
7
/// An `R`×`C` matrix (`R` rows, `C` columns) whose size is fixed at compile time.
///
/// Elements are stored row-major: `data[i][j]` is the element in row `i`, column `j`.
///
/// `Copy` and `Hash` are derived (bounded on `T: Copy` / `T: Hash`) so that matrices of
/// `Copy` elements can be reused after being passed by value to the consuming operators
/// (`+`, `-`, `*`, unary `-`), and can be used as hash-map / hash-set keys. Note that a
/// copy duplicates all `R * C` elements, so large matrices are still best passed by
/// reference where possible.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct Matrix<T, const R: usize, const C: usize> {
    data: [[T; C]; R],
}
13
14impl<T: Copy + Zero, const R: usize, const C: usize> Matrix<T, R, C> {
15    /// ゼロ行列を返す。
16    pub fn new() -> Self {
17        let data = [[T::zero(); C]; R];
18        Self { data }
19    }
20}
21
22impl<T: Copy + Zero, const R: usize, const C: usize> Default for Matrix<T, R, C> {
23    fn default() -> Self {
24        Self::new()
25    }
26}
27
28impl<T: Copy + Zero + One, const N: usize> Matrix<T, N, N> {
29    /// 単位行列を返す。
30    pub fn unit() -> Self {
31        let mut data = [[T::zero(); N]; N];
32        for (i, r) in data.iter_mut().enumerate() {
33            r[i] = T::one();
34        }
35        Self { data }
36    }
37}
38
39impl<T, const R: usize, const C: usize> Matrix<T, R, C> {
40    /// `i`行`j`列の要素への参照を返す。
41    pub fn get(&self, i: usize, j: usize) -> Option<&T> {
42        let a = self.data.get(i)?;
43        a.get(j)
44    }
45
46    /// `i`行`j`列の要素への可変参照を返す。
47    pub fn get_mut(&mut self, i: usize, j: usize) -> Option<&mut T> {
48        let a = self.data.get_mut(i)?;
49        a.get_mut(j)
50    }
51}
52
53impl<T: Copy + Zero, const R: usize, const C: usize> Matrix<T, R, C> {
54    /// 転置行列を返す。
55    pub fn transpose(self) -> Matrix<T, C, R> {
56        let mut ret = Matrix::<T, C, R>::new();
57        for i in 0..R {
58            for j in 0..C {
59                ret.data[j][i] = self.data[i][j];
60            }
61        }
62        ret
63    }
64}
65
66impl<T: Copy + Add<Output = T>, const R: usize, const C: usize> Add for Matrix<T, R, C> {
67    type Output = Self;
68    fn add(mut self, other: Self) -> Self {
69        for (a, b) in self.data.iter_mut().zip(other.data) {
70            for (x, y) in a.iter_mut().zip(b) {
71                *x = *x + y;
72            }
73        }
74        self
75    }
76}
77
78impl<T: Copy + Sub<Output = T>, const R: usize, const C: usize> Sub for Matrix<T, R, C> {
79    type Output = Self;
80    fn sub(mut self, other: Self) -> Self {
81        for (a, b) in self.data.iter_mut().zip(other.data) {
82            for (x, y) in a.iter_mut().zip(b) {
83                *x = *x - y;
84            }
85        }
86        self
87    }
88}
89
90impl<T: Copy + Neg<Output = T>, const R: usize, const C: usize> Neg for Matrix<T, R, C> {
91    type Output = Self;
92    fn neg(mut self) -> Self {
93        for a in self.data.iter_mut() {
94            for x in a.iter_mut() {
95                *x = -*x;
96            }
97        }
98        self
99    }
100}
101
102impl<T, const R: usize, const C: usize, const C2: usize> Mul<Matrix<T, C, C2>> for Matrix<T, R, C>
103where
104    T: Copy + Zero + Add<Output = T> + Mul<Output = T>,
105{
106    type Output = Matrix<T, R, C2>;
107    fn mul(self, other: Matrix<T, C, C2>) -> Self::Output {
108        let b = other.transpose();
109        let mut ret = Matrix::new();
110
111        for (r, r2) in ret.data.iter_mut().zip(self.data) {
112            for (x, c) in r.iter_mut().zip(b.data.iter()) {
113                for (y, z) in r2.iter().zip(c.iter()) {
114                    *x = *x + *y * *z;
115                }
116            }
117        }
118
119        ret
120    }
121}
122
123impl_algebra!(
124    [T: Copy + One + Zero + Add<Output = T> + Mul<Output = T>, const N: usize];
125    Prod<Matrix<T, N, N>>;
126    op: |a: Self, b: Self| Self(a.0 * b.0);
127    id: Self(Matrix::unit());
128    assoc;
129);
130
131impl_algebra!(
132    [T: Copy + Zero + Add<Output = T> + Neg<Output = T>, const R: usize, const C: usize];
133    Sum<Matrix<T, R, C>>;
134    op: |a: Self, b: Self| Self(a.0 + b.0);
135    id: Self(Matrix::new());
136    inv: |a: Self| Self(-a.0);
137    assoc;
138    commu;
139);
140
#[cfg(test)]
mod tests {
    use super::*;

    // The unit matrix must be a two-sided identity for the product. This replaces
    // a `dbg!`-only smoke test that could never fail.
    #[test]
    fn unit_is_identity() {
        let id5 = Matrix::<u32, 5, 5>::unit();
        let zero = Matrix::<u32, 5, 3>::new();
        assert_eq!(id5 * zero.clone(), zero);

        let m = Matrix::<u32, 2, 3> {
            data: [[1, 2, 3], [4, 5, 6]],
        };
        assert_eq!(Matrix::<u32, 2, 2>::unit() * m.clone(), m);
        assert_eq!(m.clone() * Matrix::<u32, 3, 3>::unit(), m);
    }

    #[test]
    fn rectangular_product() {
        let a = Matrix::<u32, 2, 3> {
            data: [[1, 2, 3], [4, 5, 6]],
        };
        let b = Matrix::<u32, 3, 2> {
            data: [[7, 8], [9, 10], [11, 12]],
        };
        let expected = Matrix::<u32, 2, 2> {
            data: [[58, 64], [139, 154]],
        };
        assert_eq!(a * b, expected);
    }

    #[test]
    fn transpose_swaps_indices() {
        let a = Matrix::<u32, 2, 3> {
            data: [[1, 2, 3], [4, 5, 6]],
        };
        let expected = Matrix::<u32, 3, 2> {
            data: [[1, 4], [2, 5], [3, 6]],
        };
        assert_eq!(a.transpose(), expected);
    }

    #[test]
    fn elementwise_ops() {
        let a = Matrix::<i64, 2, 2> {
            data: [[1, -2], [3, 4]],
        };
        let b = Matrix::<i64, 2, 2> {
            data: [[5, 6], [-7, 8]],
        };
        assert_eq!(
            a.clone() + b.clone(),
            Matrix {
                data: [[6, 4], [-4, 12]]
            }
        );
        assert_eq!(
            a.clone() - b,
            Matrix {
                data: [[-4, -8], [10, -4]]
            }
        );
        assert_eq!(
            -a,
            Matrix {
                data: [[-1, 2], [-3, -4]]
            }
        );
    }

    #[test]
    fn checked_access_and_default() {
        let mut m = Matrix::<u32, 2, 3>::new();
        assert_eq!(m, Matrix::default());
        *m.get_mut(1, 2).unwrap() = 9;
        assert_eq!(m.get(1, 2), Some(&9));
        assert_eq!(m.get(0, 0), Some(&0));
        assert_eq!(m.get(2, 0), None);
        assert_eq!(m.get(0, 3), None);
        assert!(m.get_mut(2, 0).is_none());
    }
}