haar_lib/linalg/mod_m/
square_matrix.rs1use crate::num::ff::*;
3use std::ops::{Add, AddAssign, Index, Mul, MulAssign, Neg, Sub, SubAssign};
4
5#[derive(Clone, PartialEq, Eq)]
7pub struct SquareMatrix<Modulo: FF>
8where
9 Modulo::Element: FFElem,
10{
11 size: usize,
12 modulo: Modulo,
13 data: Vec<Vec<Modulo::Element>>,
14}
15
16impl<Modulo: FF> SquareMatrix<Modulo>
17where
18 Modulo::Element: FFElem + Copy,
19{
20 pub fn new(size: usize, modulo: Modulo) -> Self {
22 Self {
23 size,
24 data: vec![vec![modulo.from_u64(0); size]; size],
25 modulo,
26 }
27 }
28
29 pub fn unit(size: usize, modulo: Modulo) -> Self {
31 let mut ret = Self::new(size, modulo.clone());
32 for i in 0..size {
33 ret.data[i][i] = modulo.from_u64(1);
34 }
35 ret
36 }
37
38 pub fn from_vec_vec_u32(other: Vec<Vec<u32>>, modulo: Modulo) -> Self {
40 let size = other.len();
41 assert!(size > 0);
42 assert!(other.iter().all(|r| r.len() == size));
43
44 let other = other
45 .into_iter()
46 .map(|a| {
47 a.into_iter()
48 .map(|x| modulo.from_u64(x as u64))
49 .collect::<Vec<_>>()
50 })
51 .collect();
52
53 Self {
54 size,
55 data: other,
56 modulo,
57 }
58 }
59
60 pub fn size(&self) -> usize {
62 self.size
63 }
64
65 pub fn transpose(self) -> Self {
67 let mut ret = Self::new(self.size, self.modulo);
68 for i in 0..self.size {
69 for j in 0..self.size {
70 ret.data[j][i] = self.data[i][j];
71 }
72 }
73 ret
74 }
75
76 pub fn pow(self, mut p: u64) -> Self {
78 let mut ret = Self::unit(self.size, self.modulo.clone());
79 let mut a = self;
80
81 while p > 0 {
82 if p & 1 != 0 {
83 ret *= a.clone();
84 }
85 a *= a.clone();
86
87 p >>= 1;
88 }
89
90 ret
91 }
92
93 pub fn get_mut(&mut self, i: usize, j: usize) -> Option<&mut Modulo::Element> {
95 let a = self.data.get_mut(i)?;
96 a.get_mut(j)
97 }
98
99 pub fn straight_mul(self, b: Self) -> Self {
103 assert_eq!(self.size, b.size);
104
105 let b = b.transpose();
106 let mut ret = Self::new(self.size, self.modulo);
107
108 for (r, r2) in ret.data.iter_mut().zip(self.data.iter()) {
109 for (x, c) in r.iter_mut().zip(b.data.iter()) {
110 for (y, z) in r2.iter().zip(c.iter()) {
111 *x += *y * *z;
112 }
113 }
114 }
115
116 ret
117 }
118
119 pub fn strassen_mul(self, b: Self) -> Self {
121 let mut a = self;
122 let n = a.size();
123
124 if n <= 256 {
125 return Self::straight_mul(a, b);
126 }
127
128 let m = n.div_ceil(2);
129
130 let mut a11 = Self::new(m, a.modulo.clone());
131 let mut a12 = Self::new(m, a.modulo.clone());
132 let mut a21 = Self::new(m, a.modulo.clone());
133 let mut a22 = Self::new(m, a.modulo.clone());
134
135 let mut b11 = Self::new(m, a.modulo.clone());
136 let mut b12 = Self::new(m, a.modulo.clone());
137 let mut b21 = Self::new(m, a.modulo.clone());
138 let mut b22 = Self::new(m, a.modulo.clone());
139
140 for i in 0..m {
141 for j in 0..m {
142 a11.data[i][j] = a[i][j];
143 b11.data[i][j] = b[i][j];
144
145 if j + m < n {
146 a12.data[i][j] = a[i][j + m];
147 b12.data[i][j] = b[i][j + m];
148 }
149
150 if i + m < n {
151 a21.data[i][j] = a[i + m][j];
152 b21.data[i][j] = b[i + m][j];
153 }
154
155 if i + m < n && j + m < n {
156 a22.data[i][j] = a[i + m][j + m];
157 b22.data[i][j] = b[i + m][j + m];
158 }
159 }
160 }
161
162 let p1 = Self::strassen_mul(a11.clone() + a22.clone(), b11.clone() + b22.clone());
163 let p2 = Self::strassen_mul(a21.clone() + a22.clone(), b11.clone());
164 let p3 = Self::strassen_mul(a11.clone(), b12.clone() - b22.clone());
165 let p4 = Self::strassen_mul(a22.clone(), b21.clone() - b11.clone());
166 let p5 = Self::strassen_mul(a11.clone() + a12.clone(), b22.clone());
167 let p6 = Self::strassen_mul(a21 - a11, b11 + b12);
168 let p7 = Self::strassen_mul(a12 - a22, b21 + b22);
169
170 let c11 = p1.clone() + p4.clone() - p5.clone() + p7;
171 let c12 = p3.clone() + p5;
172 let c21 = p2.clone() + p4;
173 let c22 = p1 + p3 - p2 + p6;
174
175 for i in 0..m {
176 for j in 0..m {
177 a.data[i][j] = c11[i][j];
178 if j + m < n {
179 a.data[i][j + m] = c12[i][j];
180 }
181 if i + m < n {
182 a.data[i + m][j] = c21[i][j];
183 }
184 if i + m < n && j + m < n {
185 a.data[i + m][j + m] = c22[i][j];
186 }
187 }
188 }
189
190 a
191 }
192}
193
194impl<Modulo: FF> From<SquareMatrix<Modulo>> for Vec<Vec<Modulo::Element>> {
195 fn from(value: SquareMatrix<Modulo>) -> Self {
196 value.data
197 }
198}
199
200impl<Modulo: FF> AsRef<[Vec<Modulo::Element>]> for SquareMatrix<Modulo> {
201 fn as_ref(&self) -> &[Vec<Modulo::Element>] {
202 &self.data
203 }
204}
205
206impl<Modulo: FF> Add for SquareMatrix<Modulo>
207where
208 Modulo::Element: FFElem,
209{
210 type Output = Self;
211 fn add(mut self, other: Self) -> Self {
212 assert_eq!(self.size, other.size);
213 for (a, b) in self.data.iter_mut().zip(other.data.into_iter()) {
214 for (x, y) in a.iter_mut().zip(b.into_iter()) {
215 *x += y;
216 }
217 }
218 self
219 }
220}
221
222impl<Modulo: FF> Sub for SquareMatrix<Modulo>
223where
224 Modulo::Element: FFElem,
225{
226 type Output = Self;
227 fn sub(mut self, other: Self) -> Self {
228 assert_eq!(self.size, other.size);
229 for (a, b) in self.data.iter_mut().zip(other.data.into_iter()) {
230 for (x, y) in a.iter_mut().zip(b.into_iter()) {
231 *x -= y;
232 }
233 }
234 self
235 }
236}
237
238impl<Modulo: FF> Mul for SquareMatrix<Modulo>
239where
240 Modulo::Element: FFElem + Copy,
241{
242 type Output = Self;
243 fn mul(self, other: Self) -> Self {
244 self.strassen_mul(other)
245 }
246}
247
248impl<Modulo: FF> AddAssign for SquareMatrix<Modulo>
249where
250 Modulo::Element: FFElem + Copy,
251{
252 fn add_assign(&mut self, other: Self) {
253 *self = self.clone() + other;
254 }
255}
256
257impl<Modulo: FF> SubAssign for SquareMatrix<Modulo>
258where
259 Modulo::Element: FFElem + Copy,
260{
261 fn sub_assign(&mut self, other: Self) {
262 *self = self.clone() - other;
263 }
264}
265
266impl<Modulo: FF> MulAssign for SquareMatrix<Modulo>
267where
268 Modulo::Element: FFElem + Copy,
269{
270 fn mul_assign(&mut self, other: Self) {
271 *self = self.clone() * other;
272 }
273}
274
275impl<Modulo: FF> Neg for SquareMatrix<Modulo>
276where
277 Modulo::Element: FFElem + Copy,
278{
279 type Output = Self;
280 fn neg(mut self) -> Self {
281 self.data.iter_mut().for_each(|r| {
282 r.iter_mut().for_each(|x| {
283 *x = -*x;
284 })
285 });
286 self
287 }
288}
289
290impl<Modulo: FF> Index<usize> for SquareMatrix<Modulo>
291where
292 Modulo::Element: FFElem,
293{
294 type Output = [Modulo::Element];
295 fn index(&self, i: usize) -> &Self::Output {
296 &self.data[i]
297 }
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;

    use crate::num::const_modint::*;

    /// Returns a `size × size` grid of random `u32` values.
    fn random_entries(rng: &mut impl Rng, size: usize) -> Vec<Vec<u32>> {
        (0..size)
            .map(|_| (0..size).map(|_| rng.gen::<u32>()).collect())
            .collect()
    }

    #[test]
    fn test() {
        let mut rng = rand::thread_rng();
        let modulo = ConstModIntBuilder::<1000000007>;

        // 257 is odd and just above the Strassen cutoff (256), so it exercises
        // the zero-padded quadrant path; 300 exercises an even split.
        for &size in &[257, 300] {
            let a = SquareMatrix::from_vec_vec_u32(random_entries(&mut rng, size), modulo);
            let b = SquareMatrix::from_vec_vec_u32(random_entries(&mut rng, size), modulo);

            assert!(a.clone().straight_mul(b.clone()) == a.strassen_mul(b));
        }
    }

    #[test]
    fn test_pow() {
        let mut rng = rand::thread_rng();
        let modulo = ConstModIntBuilder::<1000000007>;
        let size = 5;

        let a = SquareMatrix::from_vec_vec_u32(random_entries(&mut rng, size), modulo);

        // `pow(p)` must equal `p` successive multiplications starting from the identity.
        let mut expected = SquareMatrix::unit(size, modulo);
        for p in 0..8 {
            assert!(a.clone().pow(p) == expected);
            expected = expected.straight_mul(a.clone());
        }
    }

    #[test]
    #[ignore]
    fn benchmark() {
        use crate::get_time;

        let mut rng = rand::thread_rng();
        let modulo = ConstModIntBuilder::<1000000007>;

        let mut straight = vec![];
        let mut strassen = vec![];

        for &size in &[1, 10, 100, 300, 500] {
            let a = SquareMatrix::from_vec_vec_u32(random_entries(&mut rng, size), modulo);
            let b = SquareMatrix::from_vec_vec_u32(random_entries(&mut rng, size), modulo);

            straight.push(get_time!({
                a.clone().straight_mul(b.clone());
            }));

            strassen.push(get_time!({
                a.clone().strassen_mul(b.clone());
            }));
        }

        dbg!(straight, strassen);
    }
}
366}