// haar_lib/linalg/mod_m/square_matrix.rs

//! Square matrices (正方行列).
2use crate::num::ff::*;
3use std::ops::{Add, AddAssign, Index, Mul, MulAssign, Neg, Sub, SubAssign};
4
5/// 正方行列
6#[derive(Clone, PartialEq, Eq)]
7pub struct SquareMatrix<Modulo: FF>
8where
9    Modulo::Element: FFElem,
10{
11    size: usize,
12    modulo: Modulo,
13    data: Vec<Vec<Modulo::Element>>,
14}
15
16impl<Modulo: FF> SquareMatrix<Modulo>
17where
18    Modulo::Element: FFElem + Copy,
19{
20    /// `size`×`size`の零行列を作る。
21    pub fn new(size: usize, modulo: Modulo) -> Self {
22        Self {
23            size,
24            data: vec![vec![modulo.from_u64(0); size]; size],
25            modulo,
26        }
27    }
28
29    /// `size`×`size`の単位行列を作る。
30    pub fn unit(size: usize, modulo: Modulo) -> Self {
31        let mut ret = Self::new(size, modulo.clone());
32        for i in 0..size {
33            ret.data[i][i] = modulo.from_u64(1);
34        }
35        ret
36    }
37
38    /// [`Vec<Vec<u32>>`]から[`SquareMatrix`]を作る。
39    pub fn from_vec_vec_u32(other: Vec<Vec<u32>>, modulo: Modulo) -> Self {
40        let size = other.len();
41        assert!(size > 0);
42        assert!(other.iter().all(|r| r.len() == size));
43
44        let other = other
45            .into_iter()
46            .map(|a| {
47                a.into_iter()
48                    .map(|x| modulo.from_u64(x as u64))
49                    .collect::<Vec<_>>()
50            })
51            .collect();
52
53        Self {
54            size,
55            data: other,
56            modulo,
57        }
58    }
59
60    /// 行列の行数(列数)を返す。
61    pub fn size(&self) -> usize {
62        self.size
63    }
64
65    /// 行列の転置を求める。
66    pub fn transpose(self) -> Self {
67        let mut ret = Self::new(self.size, self.modulo);
68        for i in 0..self.size {
69            for j in 0..self.size {
70                ret.data[j][i] = self.data[i][j];
71            }
72        }
73        ret
74    }
75
76    /// 行列の`p`乗を求める。
77    pub fn pow(self, mut p: u64) -> Self {
78        let mut ret = Self::unit(self.size, self.modulo.clone());
79        let mut a = self;
80
81        while p > 0 {
82            if p & 1 != 0 {
83                ret *= a.clone();
84            }
85            a *= a.clone();
86
87            p >>= 1;
88        }
89
90        ret
91    }
92
93    /// `i`行`j`列の要素への可変参照を返す。
94    pub fn get_mut(&mut self, i: usize, j: usize) -> Option<&mut Modulo::Element> {
95        let a = self.data.get_mut(i)?;
96        a.get_mut(j)
97    }
98
99    /// 愚直に行列積を求める。
100    ///
101    /// **Time complexity** $O(n^3)$
102    pub fn straight_mul(self, b: Self) -> Self {
103        assert_eq!(self.size, b.size);
104
105        let b = b.transpose();
106        let mut ret = Self::new(self.size, self.modulo);
107
108        for (r, r2) in ret.data.iter_mut().zip(self.data.iter()) {
109            for (x, c) in r.iter_mut().zip(b.data.iter()) {
110                for (y, z) in r2.iter().zip(c.iter()) {
111                    *x += *y * *z;
112                }
113            }
114        }
115
116        ret
117    }
118
119    /// Strassenのアルゴリズムによる行列乗算
120    pub fn strassen_mul(self, b: Self) -> Self {
121        let mut a = self;
122        let n = a.size();
123
124        if n <= 256 {
125            return Self::straight_mul(a, b);
126        }
127
128        let m = n.div_ceil(2);
129
130        let mut a11 = Self::new(m, a.modulo.clone());
131        let mut a12 = Self::new(m, a.modulo.clone());
132        let mut a21 = Self::new(m, a.modulo.clone());
133        let mut a22 = Self::new(m, a.modulo.clone());
134
135        let mut b11 = Self::new(m, a.modulo.clone());
136        let mut b12 = Self::new(m, a.modulo.clone());
137        let mut b21 = Self::new(m, a.modulo.clone());
138        let mut b22 = Self::new(m, a.modulo.clone());
139
140        for i in 0..m {
141            for j in 0..m {
142                a11.data[i][j] = a[i][j];
143                b11.data[i][j] = b[i][j];
144
145                if j + m < n {
146                    a12.data[i][j] = a[i][j + m];
147                    b12.data[i][j] = b[i][j + m];
148                }
149
150                if i + m < n {
151                    a21.data[i][j] = a[i + m][j];
152                    b21.data[i][j] = b[i + m][j];
153                }
154
155                if i + m < n && j + m < n {
156                    a22.data[i][j] = a[i + m][j + m];
157                    b22.data[i][j] = b[i + m][j + m];
158                }
159            }
160        }
161
162        let p1 = Self::strassen_mul(a11.clone() + a22.clone(), b11.clone() + b22.clone());
163        let p2 = Self::strassen_mul(a21.clone() + a22.clone(), b11.clone());
164        let p3 = Self::strassen_mul(a11.clone(), b12.clone() - b22.clone());
165        let p4 = Self::strassen_mul(a22.clone(), b21.clone() - b11.clone());
166        let p5 = Self::strassen_mul(a11.clone() + a12.clone(), b22.clone());
167        let p6 = Self::strassen_mul(a21 - a11, b11 + b12);
168        let p7 = Self::strassen_mul(a12 - a22, b21 + b22);
169
170        let c11 = p1.clone() + p4.clone() - p5.clone() + p7;
171        let c12 = p3.clone() + p5;
172        let c21 = p2.clone() + p4;
173        let c22 = p1 + p3 - p2 + p6;
174
175        for i in 0..m {
176            for j in 0..m {
177                a.data[i][j] = c11[i][j];
178                if j + m < n {
179                    a.data[i][j + m] = c12[i][j];
180                }
181                if i + m < n {
182                    a.data[i + m][j] = c21[i][j];
183                }
184                if i + m < n && j + m < n {
185                    a.data[i + m][j + m] = c22[i][j];
186                }
187            }
188        }
189
190        a
191    }
192}
193
194impl<Modulo: FF> From<SquareMatrix<Modulo>> for Vec<Vec<Modulo::Element>> {
195    fn from(value: SquareMatrix<Modulo>) -> Self {
196        value.data
197    }
198}
199
200impl<Modulo: FF> AsRef<[Vec<Modulo::Element>]> for SquareMatrix<Modulo> {
201    fn as_ref(&self) -> &[Vec<Modulo::Element>] {
202        &self.data
203    }
204}
205
206impl<Modulo: FF> Add for SquareMatrix<Modulo>
207where
208    Modulo::Element: FFElem,
209{
210    type Output = Self;
211    fn add(mut self, other: Self) -> Self {
212        assert_eq!(self.size, other.size);
213        for (a, b) in self.data.iter_mut().zip(other.data.into_iter()) {
214            for (x, y) in a.iter_mut().zip(b.into_iter()) {
215                *x += y;
216            }
217        }
218        self
219    }
220}
221
222impl<Modulo: FF> Sub for SquareMatrix<Modulo>
223where
224    Modulo::Element: FFElem,
225{
226    type Output = Self;
227    fn sub(mut self, other: Self) -> Self {
228        assert_eq!(self.size, other.size);
229        for (a, b) in self.data.iter_mut().zip(other.data.into_iter()) {
230            for (x, y) in a.iter_mut().zip(b.into_iter()) {
231                *x -= y;
232            }
233        }
234        self
235    }
236}
237
238impl<Modulo: FF> Mul for SquareMatrix<Modulo>
239where
240    Modulo::Element: FFElem + Copy,
241{
242    type Output = Self;
243    fn mul(self, other: Self) -> Self {
244        self.strassen_mul(other)
245    }
246}
247
248impl<Modulo: FF> AddAssign for SquareMatrix<Modulo>
249where
250    Modulo::Element: FFElem + Copy,
251{
252    fn add_assign(&mut self, other: Self) {
253        *self = self.clone() + other;
254    }
255}
256
257impl<Modulo: FF> SubAssign for SquareMatrix<Modulo>
258where
259    Modulo::Element: FFElem + Copy,
260{
261    fn sub_assign(&mut self, other: Self) {
262        *self = self.clone() - other;
263    }
264}
265
266impl<Modulo: FF> MulAssign for SquareMatrix<Modulo>
267where
268    Modulo::Element: FFElem + Copy,
269{
270    fn mul_assign(&mut self, other: Self) {
271        *self = self.clone() * other;
272    }
273}
274
275impl<Modulo: FF> Neg for SquareMatrix<Modulo>
276where
277    Modulo::Element: FFElem + Copy,
278{
279    type Output = Self;
280    fn neg(mut self) -> Self {
281        self.data.iter_mut().for_each(|r| {
282            r.iter_mut().for_each(|x| {
283                *x = -*x;
284            })
285        });
286        self
287    }
288}
289
290impl<Modulo: FF> Index<usize> for SquareMatrix<Modulo>
291where
292    Modulo::Element: FFElem,
293{
294    type Output = [Modulo::Element];
295    fn index(&self, i: usize) -> &Self::Output {
296        &self.data[i]
297    }
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;

    use crate::num::const_modint::*;

    /// Strassen multiplication must agree with the naive product.
    #[test]
    fn test() {
        let mut rng = rand::thread_rng();
        let modulo = ConstModIntBuilder::<1000000007>;

        // 300 splits evenly (m = 150). 257 forces the zero-padded odd split (m = 129).
        for &size in &[257, 300] {
            let mut a = vec![vec![0; size]; size];
            let mut b = vec![vec![0; size]; size];

            for i in 0..size {
                for j in 0..size {
                    a[i][j] = rng.gen::<u32>();
                    b[i][j] = rng.gen::<u32>();
                }
            }

            let a = SquareMatrix::from_vec_vec_u32(a, modulo);
            let b = SquareMatrix::from_vec_vec_u32(b, modulo);

            assert!(a.clone().straight_mul(b.clone()) == a.strassen_mul(b));
        }
    }

    /// `pow` must match repeated multiplication, with `pow(0)` being the identity.
    #[test]
    fn test_pow() {
        let mut rng = rand::thread_rng();
        let modulo = ConstModIntBuilder::<1000000007>;
        let size = 5;

        let mut a = vec![vec![0; size]; size];
        for row in a.iter_mut() {
            for x in row.iter_mut() {
                *x = rng.gen::<u32>();
            }
        }
        let a = SquareMatrix::from_vec_vec_u32(a, modulo);

        assert!(a.clone().pow(0) == SquareMatrix::unit(size, modulo));
        assert!(a.clone().pow(1) == a.clone());

        let mut expected = SquareMatrix::unit(size, modulo);
        for _ in 0..13 {
            expected = expected * a.clone();
        }
        assert!(a.pow(13) == expected);
    }

    #[test]
    #[ignore]
    fn benchmark() {
        use crate::get_time;

        let mut rng = rand::thread_rng();
        let modulo = ConstModIntBuilder::<1000000007>;

        let mut straight = vec![];
        let mut strassen = vec![];

        for &size in &[1, 10, 100, 300, 500] {
            let mut a = vec![vec![0; size]; size];
            let mut b = vec![vec![0; size]; size];

            for i in 0..size {
                for j in 0..size {
                    a[i][j] = rng.gen::<u32>();
                    b[i][j] = rng.gen::<u32>();
                }
            }

            let a = SquareMatrix::from_vec_vec_u32(a, modulo);
            let b = SquareMatrix::from_vec_vec_u32(b, modulo);

            straight.push(get_time!({
                a.clone().straight_mul(b.clone());
            }));

            strassen.push(get_time!({
                a.clone().strassen_mul(b.clone());
            }));
        }

        dbg!(straight, strassen);
    }
}