1pub use crate::linalg::traits::*;
3use crate::{algebra::semiring::*, impl_ops};
4use std::ops::{Index, Neg};
5
6#[derive(Clone, Debug, PartialEq, Eq)]
8pub struct MatrixOnSemiring<R: Semiring> {
9 h: usize,
10 w: usize,
11 semiring: R,
12 data: Vec<Vec<R::Element>>,
13}
14
15impl<R: Semiring> Matrix for MatrixOnSemiring<R> {
16 fn width(&self) -> usize {
17 self.w
18 }
19 fn height(&self) -> usize {
20 self.h
21 }
22}
23
24impl<R: Semiring> MatrixTranspose for MatrixOnSemiring<R>
25where
26 R::Element: Copy,
27{
28 type Output = Self;
29 fn transpose(self) -> Self::Output {
30 let mut ret = Self::zero(self.semiring, self.w, self.h);
31 for i in 0..self.h {
32 for j in 0..self.w {
33 ret.data[j][i] = self.data[i][j];
34 }
35 }
36 ret
37 }
38}
39
40impl<R: Semiring> MatrixOnSemiring<R>
41where
42 R::Element: Copy,
43{
44 pub fn zero(semiring: R, h: usize, w: usize) -> Self {
46 Self {
47 h,
48 w,
49 data: vec![vec![semiring.zero(); w]; h],
50 semiring,
51 }
52 }
53
54 pub fn unit(semiring: R, size: usize) -> Self {
56 let one = semiring.one();
57 let mut ret = Self::zero(semiring, size, size);
58 for i in 0..size {
59 ret.data[i][i] = one;
60 }
61 ret
62 }
63
64 pub fn from_vec<T>(semiring: R, a: Vec<Vec<T>>) -> Self
66 where
67 T: Into<R::Element>,
68 {
69 let h = a.len();
70 assert!(h > 0);
71 let w = a[0].len();
72 assert!(a.iter().all(|r| r.len() == w));
73
74 let data = a
75 .into_iter()
76 .map(|r| r.into_iter().map(T::into).collect())
77 .collect();
78
79 Self {
80 semiring,
81 data,
82 h,
83 w,
84 }
85 }
86
87 pub fn times(mut self, n: u64) -> Self {
89 self.data
90 .iter_mut()
91 .for_each(|r| r.iter_mut().for_each(|a| *a = self.semiring.times(*a, n)));
92 self
93 }
94
95 pub fn get(&self, i: usize, j: usize) -> Option<&R::Element> {
97 let a = self.data.get(i)?;
98 a.get(j)
99 }
100
101 pub fn get_mut(&mut self, i: usize, j: usize) -> Option<&mut R::Element> {
103 let a = self.data.get_mut(i)?;
104 a.get_mut(j)
105 }
106}
107
108impl<R: Semiring + Clone> MatrixOnSemiring<R>
109where
110 R::Element: Copy,
111{
112 pub fn pow(self, mut p: u64) -> Option<Self> {
114 self.is_square().then(|| {
115 let size = self.w;
116 let mut ret = Self::unit(self.semiring.clone(), size);
117 let mut a = self;
118
119 while p > 0 {
120 if p & 1 != 0 {
121 ret *= a.clone();
122 }
123 a *= a.clone();
124
125 p >>= 1;
126 }
127
128 ret
129 })
130 }
131
132 fn straight_mul(self, rhs: Self) -> Self {
136 assert_eq!(self.w, rhs.h);
137
138 let n = self.h;
139 let l = rhs.w;
140 let rhs = rhs.transpose();
141 let mut ret = Self::zero(self.semiring.clone(), n, l);
142 let s = &self.semiring;
143
144 for (r, r2) in ret.data.iter_mut().zip(self.data.iter()) {
145 for (x, c) in r.iter_mut().zip(rhs.data.iter()) {
146 for (y, z) in r2.iter().zip(c.iter()) {
147 *x = s.add(*x, s.mul(*y, *z));
148 }
149 }
150 }
151
152 ret
153 }
154}
155
156impl<R: Ring + Clone> MatrixOnSemiring<R>
157where
158 R::Element: Copy,
159{
160 pub fn strassen_mul(self, b: Self) -> Self {
162 let mut a = self;
163
164 assert!(a.is_square() && b.is_square() && a.size() == b.size());
165
166 let n = a.width();
167
168 if n <= 256 {
169 return Self::straight_mul(a, b);
170 }
171
172 let m = n.div_ceil(2);
173
174 let mut a11 = Self::zero(a.semiring.clone(), m, m);
175 let mut a12 = Self::zero(a.semiring.clone(), m, m);
176 let mut a21 = Self::zero(a.semiring.clone(), m, m);
177 let mut a22 = Self::zero(a.semiring.clone(), m, m);
178 let mut b11 = Self::zero(a.semiring.clone(), m, m);
179 let mut b12 = Self::zero(a.semiring.clone(), m, m);
180 let mut b21 = Self::zero(a.semiring.clone(), m, m);
181 let mut b22 = Self::zero(a.semiring.clone(), m, m);
182
183 for i in 0..m {
184 for j in 0..m {
185 a11.data[i][j] = a[i][j];
186 b11.data[i][j] = b[i][j];
187
188 if j + m < n {
189 a12.data[i][j] = a[i][j + m];
190 b12.data[i][j] = b[i][j + m];
191 }
192
193 if i + m < n {
194 a21.data[i][j] = a[i + m][j];
195 b21.data[i][j] = b[i + m][j];
196 }
197
198 if i + m < n && j + m < n {
199 a22.data[i][j] = a[i + m][j + m];
200 b22.data[i][j] = b[i + m][j + m];
201 }
202 }
203 }
204
205 let p1 = Self::strassen_mul(a11.clone() + a22.clone(), b11.clone() + b22.clone());
206 let p2 = Self::strassen_mul(a21.clone() + a22.clone(), b11.clone());
207 let p3 = Self::strassen_mul(a11.clone(), b12.clone() - b22.clone());
208 let p4 = Self::strassen_mul(a22.clone(), b21.clone() - b11.clone());
209 let p5 = Self::strassen_mul(a11.clone() + a12.clone(), b22.clone());
210 let p6 = Self::strassen_mul(a21 - a11, b11 + b12);
211 let p7 = Self::strassen_mul(a12 - a22, b21 + b22);
212
213 let c11 = p1.clone() + p4.clone() - p5.clone() + p7;
214 let c12 = p3.clone() + p5;
215 let c21 = p2.clone() + p4;
216 let c22 = p1 + p3 - p2 + p6;
217
218 for i in 0..m {
219 for j in 0..m {
220 a.data[i][j] = c11[i][j];
221 if j + m < n {
222 a.data[i][j + m] = c12[i][j];
223 }
224 if i + m < n {
225 a.data[i + m][j] = c21[i][j];
226 }
227 if i + m < n && j + m < n {
228 a.data[i + m][j + m] = c22[i][j];
229 }
230 }
231 }
232
233 a
234 }
235}
236
237impl<R: Semiring> TryAdd for MatrixOnSemiring<R>
238where
239 R::Element: Copy,
240{
241 type Output = Self;
242 fn try_add(mut self, rhs: Self) -> Option<Self::Output> {
243 (self.size() == rhs.size()).then(|| {
244 for i in 0..self.h {
245 for j in 0..self.w {
246 self.data[i][j] = self.semiring.add(self.data[i][j], rhs.data[i][j]);
247 }
248 }
249 self
250 })
251 }
252}
253
254impl<R: Ring> TrySub for MatrixOnSemiring<R>
255where
256 R::Element: Copy,
257{
258 type Output = Self;
259 fn try_sub(mut self, rhs: Self) -> Option<Self::Output> {
260 (self.size() == rhs.size()).then(|| {
261 for i in 0..self.h {
262 for j in 0..self.w {
263 self.data[i][j] = self.semiring.sub(self.data[i][j], rhs.data[i][j]);
264 }
265 }
266 self
267 })
268 }
269}
270
271impl<R: Semiring + Clone> TryMul for MatrixOnSemiring<R>
272where
273 R::Element: Copy,
274{
275 type Output = Self;
276 fn try_mul(self, rhs: Self) -> Option<Self::Output> {
277 if self.w != rhs.h {
278 None
279 } else {
280 Some(self.straight_mul(rhs))
281 }
282 }
283}
284
285impl_ops!({R: Semiring + Clone} AddAssign for MatrixOnSemiring<R> where {R::Element: Copy}, |x: &mut Self, y: Self| *x = x.clone().try_add(y).unwrap());
286impl_ops!({R: Ring + Clone} SubAssign for MatrixOnSemiring<R> where {R::Element: Copy}, |x: &mut Self, y: Self| *x = x.clone().try_sub(y).unwrap());
287impl_ops!({R: Semiring + Clone} MulAssign for MatrixOnSemiring<R> where {R::Element: Copy}, |x: &mut Self, y: Self| *x = x.clone().try_mul(y).unwrap());
288
289impl_ops!({R: Semiring + Clone} Add for MatrixOnSemiring<R> where {R::Element: Copy}, |x: Self, y| x.try_add(y).unwrap());
290impl_ops!({R: Ring + Clone} Sub for MatrixOnSemiring<R> where {R::Element: Copy}, |x: Self, y| x.try_sub(y).unwrap());
291impl_ops!({R: Semiring + Clone} Mul for MatrixOnSemiring<R> where {R::Element: Copy}, |x: Self, y| x.try_mul(y).unwrap());
292
293impl<R: Ring> Neg for MatrixOnSemiring<R>
294where
295 R::Element: Copy,
296{
297 type Output = Self;
298 fn neg(mut self) -> Self {
299 self.data.iter_mut().for_each(|r| {
300 r.iter_mut().for_each(|x| {
301 *x = self.semiring.neg(*x);
302 })
303 });
304 self
305 }
306}
307
308impl<R: Semiring> Index<usize> for MatrixOnSemiring<R> {
309 type Output = [R::Element];
310 fn index(&self, i: usize) -> &Self::Output {
311 &self.data[i]
312 }
313}
314
315impl<R: Semiring> From<MatrixOnSemiring<R>> for Vec<Vec<R::Element>> {
316 fn from(value: MatrixOnSemiring<R>) -> Self {
317 value.data
318 }
319}
320
321impl<R: Semiring> AsRef<[Vec<R::Element>]> for MatrixOnSemiring<R> {
322 fn as_ref(&self) -> &[Vec<R::Element>] {
323 &self.data
324 }
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;

    use crate::{
        algebra::semiring::{add_mul_mod::AddMulMod, xor_and::XorAnd},
        math::prime_mod::Prime,
        num::const_modint::*,
    };

    // Strassen and schoolbook multiplication must agree on random matrices
    // with entries taken modulo 1000000007.
    #[test]
    fn test() {
        let mut rng = rand::thread_rng();
        let modulo = ConstModIntBuilder::<Prime<1000000007>>::new();
        let ring = AddMulMod(modulo);

        // 300 is above the 256 cut-off, so `strassen_mul` actually recurses.
        let size = 300;

        let mut a = MatrixOnSemiring::zero(ring, size, size);
        let mut b = MatrixOnSemiring::zero(ring, size, size);

        // Fill both matrices cell by cell in row-major order.
        let cells = a
            .data
            .iter_mut()
            .flatten()
            .zip(b.data.iter_mut().flatten());
        for (x, y) in cells {
            *x = modulo.from_u64(rng.gen::<u32>() as u64);
            *y = modulo.from_u64(rng.gen::<u32>() as u64);
        }

        assert!(a.clone().straight_mul(b.clone()) == a.strassen_mul(b));
    }

    // Same agreement check over the `XorAnd` ring on `u64` values.
    #[test]
    fn test_xor_and() {
        let mut rng = rand::thread_rng();
        let ring = XorAnd::<u64>::new();

        let size = 300;

        let mut a = MatrixOnSemiring::zero(ring, size, size);
        let mut b = MatrixOnSemiring::zero(ring, size, size);

        let cells = a
            .data
            .iter_mut()
            .flatten()
            .zip(b.data.iter_mut().flatten());
        for (x, y) in cells {
            *x = rng.gen::<u64>();
            *y = rng.gen::<u64>();
        }

        assert!(a.clone().straight_mul(b.clone()) == a.strassen_mul(b));
    }

    // Manual timing comparison of the two multiplication routines; ignored
    // by default and run explicitly when needed.
    #[test]
    #[ignore]
    fn benchmark() {
        use crate::get_time;

        let mut rng = rand::thread_rng();
        let modulo = ConstModIntBuilder::<Prime<1000000007>>::new();
        let ring = AddMulMod(modulo);

        let mut straight = vec![];
        let mut strassen = vec![];

        for size in [1, 10, 100, 300, 500] {
            let mut a = MatrixOnSemiring::zero(ring, size, size);
            let mut b = MatrixOnSemiring::zero(ring, size, size);

            let cells = a
                .data
                .iter_mut()
                .flatten()
                .zip(b.data.iter_mut().flatten());
            for (x, y) in cells {
                *x = modulo.from_u64(rng.gen::<u32>() as u64);
                *y = modulo.from_u64(rng.gen::<u32>() as u64);
            }

            straight.push(get_time!({
                a.clone().straight_mul(b.clone());
            }));

            strassen.push(get_time!({
                a.clone().strassen_mul(b.clone());
            }));
        }

        dbg!(straight, strassen);
    }
}