haar_lib/linalg/mod_m/
matrix.rs

1//! `h`×`w`行列
2use crate::num::ff::*;
3use std::ops::{Add, AddAssign, Index, Mul, MulAssign, Neg, Sub, SubAssign};
4
5/// `h`×`w`行列
6#[derive(Clone, PartialEq, Eq)]
7pub struct Matrix<Modulo: FF> {
8    h: usize,
9    w: usize,
10    modulo: Modulo,
11    data: Vec<Vec<Modulo::Element>>,
12}
13
14impl<Modulo: FF> Matrix<Modulo>
15where
16    Modulo::Element: FFElem + Copy,
17{
18    /// `h`×`w`の零行列を作る。
19    pub fn new(h: usize, w: usize, modulo: Modulo) -> Self {
20        Self {
21            h,
22            w,
23            data: vec![vec![modulo.from_u64(0); w]; h],
24            modulo,
25        }
26    }
27
28    /// [`Vec<Vec<u32>>`]から[`Matrix<Modulo>`]を作る。
29    pub fn from_vec_2d(other: Vec<Vec<u32>>, modulo: Modulo) -> Self {
30        let h = other.len();
31        assert!(h > 0);
32        let w = other[0].len();
33        assert!(other.iter().all(|r| r.len() == w));
34
35        let other = other
36            .into_iter()
37            .map(|a| {
38                a.into_iter()
39                    .map(|x| modulo.from_u64(x as u64))
40                    .collect::<Vec<_>>()
41            })
42            .collect();
43
44        Self {
45            h,
46            w,
47            data: other,
48            modulo,
49        }
50    }
51
52    // pub fn to_vec(&self) -> Vec<Vec<T>> {
53    //     self.data.clone()
54    // }
55
56    /// 行列の行数を返す。
57    pub fn height(&self) -> usize {
58        self.h
59    }
60
61    /// 行列の列数を返す。
62    pub fn width(&self) -> usize {
63        self.w
64    }
65
66    /// `w`×`h`の転置行列を作る。
67    pub fn transpose(self) -> Self {
68        let mut ret = Self::new(self.w, self.h, self.modulo);
69        for i in 0..self.h {
70            for j in 0..self.w {
71                ret.data[j][i] = self.data[i][j];
72            }
73        }
74        ret
75    }
76
77    /// `i`行`j`列の要素への可変参照を返す。
78    pub fn get_mut(&mut self, i: usize, j: usize) -> Option<&mut Modulo::Element> {
79        let a = self.data.get_mut(i)?;
80        a.get_mut(j)
81    }
82}
83
84impl<Modulo: FF> AddAssign for Matrix<Modulo>
85where
86    Modulo::Element: FFElem + Copy,
87{
88    fn add_assign(&mut self, other: Self) {
89        assert!(self.h == other.h && self.w == other.h);
90        for i in 0..self.h {
91            for j in 0..self.w {
92                self.data[i][j] = self.data[i][j] + other.data[i][j];
93            }
94        }
95    }
96}
97
98impl<Modulo: FF> SubAssign for Matrix<Modulo>
99where
100    Modulo::Element: FFElem + Copy,
101{
102    fn sub_assign(&mut self, other: Self) {
103        assert!(self.h == other.h && self.w == other.h);
104        for i in 0..self.h {
105            for j in 0..self.w {
106                self.data[i][j] = self.data[i][j] - other.data[i][j];
107            }
108        }
109    }
110}
111
112impl<Modulo: FF> MulAssign for Matrix<Modulo>
113where
114    Modulo::Element: FFElem + Copy,
115{
116    fn mul_assign(&mut self, other: Self) {
117        *self = self.clone() * other;
118    }
119}
120
121impl<Modulo: FF> Add for Matrix<Modulo>
122where
123    Modulo::Element: FFElem + Copy,
124{
125    type Output = Self;
126    fn add(mut self, other: Self) -> Self {
127        self += other;
128        self
129    }
130}
131
132impl<Modulo: FF> Sub for Matrix<Modulo>
133where
134    Modulo::Element: FFElem + Copy,
135{
136    type Output = Self;
137    fn sub(mut self, other: Self) -> Self {
138        self -= other;
139        self
140    }
141}
142
143impl<Modulo: FF> Mul for Matrix<Modulo>
144where
145    Modulo::Element: FFElem + Copy,
146{
147    type Output = Self;
148    fn mul(self, other: Self) -> Self {
149        assert!(self.w == other.h);
150
151        let n = self.h;
152        let l = other.w;
153        let other = other.transpose();
154        let mut ret = Self::new(n, l, self.modulo);
155
156        for (r, r2) in ret.data.iter_mut().zip(self.data.iter()) {
157            for (x, c) in r.iter_mut().zip(other.data.iter()) {
158                for (y, z) in r2.iter().zip(c.iter()) {
159                    *x += *y * *z;
160                }
161            }
162        }
163
164        ret
165    }
166}
167
168impl<Modulo: FF> Neg for Matrix<Modulo>
169where
170    Modulo::Element: FFElem + Copy,
171{
172    type Output = Self;
173    fn neg(mut self) -> Self {
174        self.data.iter_mut().for_each(|r| {
175            r.iter_mut().for_each(|x| {
176                *x = -*x;
177            })
178        });
179        self
180    }
181}
182
183impl<Modulo: FF> Index<usize> for Matrix<Modulo>
184where
185    Modulo::Element: FFElem + Copy,
186{
187    type Output = [Modulo::Element];
188    fn index(&self, i: usize) -> &Self::Output {
189        &self.data[i]
190    }
191}