haar_lib/linalg/
matrix.rs

1//! 半環上の行列
2pub use crate::linalg::traits::*;
3use crate::{algebra::semiring::*, impl_ops};
4use std::ops::{Index, Neg};
5
6/// 半環上の行列
7#[derive(Clone, Debug, PartialEq, Eq)]
8pub struct MatrixOnSemiring<R: Semiring> {
9    h: usize,
10    w: usize,
11    semiring: R,
12    data: Vec<Vec<R::Element>>,
13}
14
15impl<R: Semiring> Matrix for MatrixOnSemiring<R> {
16    fn width(&self) -> usize {
17        self.w
18    }
19    fn height(&self) -> usize {
20        self.h
21    }
22}
23
24impl<R: Semiring> MatrixTranspose for MatrixOnSemiring<R>
25where
26    R::Element: Copy,
27{
28    type Output = Self;
29    fn transpose(self) -> Self::Output {
30        let mut ret = Self::zero(self.semiring, self.w, self.h);
31        for i in 0..self.h {
32            for j in 0..self.w {
33                ret.data[j][i] = self.data[i][j];
34            }
35        }
36        ret
37    }
38}
39
40impl<R: Semiring> MatrixOnSemiring<R>
41where
42    R::Element: Copy,
43{
44    /// `h`×`w`の零行列を作る。
45    pub fn zero(semiring: R, h: usize, w: usize) -> Self {
46        Self {
47            h,
48            w,
49            data: vec![vec![semiring.zero(); w]; h],
50            semiring,
51        }
52    }
53
54    /// `size`×`size`の単位行列を作る。
55    pub fn unit(semiring: R, size: usize) -> Self {
56        let one = semiring.one();
57        let mut ret = Self::zero(semiring, size, size);
58        for i in 0..size {
59            ret.data[i][i] = one;
60        }
61        ret
62    }
63
64    /// `Vec<Vec<T>>`から`MatrixOnRing`を作る。
65    pub fn from_vec<T>(semiring: R, a: Vec<Vec<T>>) -> Self
66    where
67        T: Into<R::Element>,
68    {
69        let h = a.len();
70        assert!(h > 0);
71        let w = a[0].len();
72        assert!(a.iter().all(|r| r.len() == w));
73
74        let data = a
75            .into_iter()
76            .map(|r| r.into_iter().map(T::into).collect())
77            .collect();
78
79        Self {
80            semiring,
81            data,
82            h,
83            w,
84        }
85    }
86
87    /// `self`を`n`回足した行列を求める。
88    pub fn times(mut self, n: u64) -> Self {
89        self.data
90            .iter_mut()
91            .for_each(|r| r.iter_mut().for_each(|a| *a = self.semiring.times(*a, n)));
92        self
93    }
94
95    /// `i`行`j`列の要素への参照を返す。
96    pub fn get(&self, i: usize, j: usize) -> Option<&R::Element> {
97        let a = self.data.get(i)?;
98        a.get(j)
99    }
100
101    /// `i`行`j`列の要素への可変参照を返す。
102    pub fn get_mut(&mut self, i: usize, j: usize) -> Option<&mut R::Element> {
103        let a = self.data.get_mut(i)?;
104        a.get_mut(j)
105    }
106}
107
108impl<R: Semiring + Clone> MatrixOnSemiring<R>
109where
110    R::Element: Copy,
111{
112    /// 行列の`p`乗を求める。
113    pub fn pow(self, mut p: u64) -> Option<Self> {
114        self.is_square().then(|| {
115            let size = self.w;
116            let mut ret = Self::unit(self.semiring.clone(), size);
117            let mut a = self;
118
119            while p > 0 {
120                if p & 1 != 0 {
121                    ret *= a.clone();
122                }
123                a *= a.clone();
124
125                p >>= 1;
126            }
127
128            ret
129        })
130    }
131
132    /// 愚直に行列積を求める。
133    ///
134    /// **Time complexity** $O(n^3)$
135    fn straight_mul(self, rhs: Self) -> Self {
136        assert_eq!(self.w, rhs.h);
137
138        let n = self.h;
139        let l = rhs.w;
140        let rhs = rhs.transpose();
141        let mut ret = Self::zero(self.semiring.clone(), n, l);
142        let s = &self.semiring;
143
144        for (r, r2) in ret.data.iter_mut().zip(self.data.iter()) {
145            for (x, c) in r.iter_mut().zip(rhs.data.iter()) {
146                for (y, z) in r2.iter().zip(c.iter()) {
147                    *x = s.add(*x, s.mul(*y, *z));
148                }
149            }
150        }
151
152        ret
153    }
154}
155
156impl<R: Ring + Clone> MatrixOnSemiring<R>
157where
158    R::Element: Copy,
159{
160    /// Strassenのアルゴリズムによる行列乗算
161    pub fn strassen_mul(self, b: Self) -> Self {
162        let mut a = self;
163
164        assert!(a.is_square() && b.is_square() && a.size() == b.size());
165
166        let n = a.width();
167
168        if n <= 256 {
169            return Self::straight_mul(a, b);
170        }
171
172        let m = n.div_ceil(2);
173
174        let mut a11 = Self::zero(a.semiring.clone(), m, m);
175        let mut a12 = Self::zero(a.semiring.clone(), m, m);
176        let mut a21 = Self::zero(a.semiring.clone(), m, m);
177        let mut a22 = Self::zero(a.semiring.clone(), m, m);
178        let mut b11 = Self::zero(a.semiring.clone(), m, m);
179        let mut b12 = Self::zero(a.semiring.clone(), m, m);
180        let mut b21 = Self::zero(a.semiring.clone(), m, m);
181        let mut b22 = Self::zero(a.semiring.clone(), m, m);
182
183        for i in 0..m {
184            for j in 0..m {
185                a11.data[i][j] = a[i][j];
186                b11.data[i][j] = b[i][j];
187
188                if j + m < n {
189                    a12.data[i][j] = a[i][j + m];
190                    b12.data[i][j] = b[i][j + m];
191                }
192
193                if i + m < n {
194                    a21.data[i][j] = a[i + m][j];
195                    b21.data[i][j] = b[i + m][j];
196                }
197
198                if i + m < n && j + m < n {
199                    a22.data[i][j] = a[i + m][j + m];
200                    b22.data[i][j] = b[i + m][j + m];
201                }
202            }
203        }
204
205        let p1 = Self::strassen_mul(a11.clone() + a22.clone(), b11.clone() + b22.clone());
206        let p2 = Self::strassen_mul(a21.clone() + a22.clone(), b11.clone());
207        let p3 = Self::strassen_mul(a11.clone(), b12.clone() - b22.clone());
208        let p4 = Self::strassen_mul(a22.clone(), b21.clone() - b11.clone());
209        let p5 = Self::strassen_mul(a11.clone() + a12.clone(), b22.clone());
210        let p6 = Self::strassen_mul(a21 - a11, b11 + b12);
211        let p7 = Self::strassen_mul(a12 - a22, b21 + b22);
212
213        let c11 = p1.clone() + p4.clone() - p5.clone() + p7;
214        let c12 = p3.clone() + p5;
215        let c21 = p2.clone() + p4;
216        let c22 = p1 + p3 - p2 + p6;
217
218        for i in 0..m {
219            for j in 0..m {
220                a.data[i][j] = c11[i][j];
221                if j + m < n {
222                    a.data[i][j + m] = c12[i][j];
223                }
224                if i + m < n {
225                    a.data[i + m][j] = c21[i][j];
226                }
227                if i + m < n && j + m < n {
228                    a.data[i + m][j + m] = c22[i][j];
229                }
230            }
231        }
232
233        a
234    }
235}
236
237impl<R: Semiring> TryAdd for MatrixOnSemiring<R>
238where
239    R::Element: Copy,
240{
241    type Output = Self;
242    fn try_add(mut self, rhs: Self) -> Option<Self::Output> {
243        (self.size() == rhs.size()).then(|| {
244            for i in 0..self.h {
245                for j in 0..self.w {
246                    self.data[i][j] = self.semiring.add(self.data[i][j], rhs.data[i][j]);
247                }
248            }
249            self
250        })
251    }
252}
253
254impl<R: Ring> TrySub for MatrixOnSemiring<R>
255where
256    R::Element: Copy,
257{
258    type Output = Self;
259    fn try_sub(mut self, rhs: Self) -> Option<Self::Output> {
260        (self.size() == rhs.size()).then(|| {
261            for i in 0..self.h {
262                for j in 0..self.w {
263                    self.data[i][j] = self.semiring.sub(self.data[i][j], rhs.data[i][j]);
264                }
265            }
266            self
267        })
268    }
269}
270
271impl<R: Semiring + Clone> TryMul for MatrixOnSemiring<R>
272where
273    R::Element: Copy,
274{
275    type Output = Self;
276    fn try_mul(self, rhs: Self) -> Option<Self::Output> {
277        if self.w != rhs.h {
278            None
279        } else {
280            Some(self.straight_mul(rhs))
281        }
282    }
283}
284
285impl_ops!({R: Semiring + Clone} AddAssign for MatrixOnSemiring<R> where {R::Element: Copy}, |x: &mut Self, y: Self| *x = x.clone().try_add(y).unwrap());
286impl_ops!({R: Ring + Clone} SubAssign for MatrixOnSemiring<R> where {R::Element: Copy}, |x: &mut Self, y: Self| *x = x.clone().try_sub(y).unwrap());
287impl_ops!({R: Semiring + Clone} MulAssign for MatrixOnSemiring<R> where {R::Element: Copy}, |x: &mut Self, y: Self| *x = x.clone().try_mul(y).unwrap());
288
289impl_ops!({R: Semiring + Clone} Add for MatrixOnSemiring<R> where {R::Element: Copy}, |x: Self, y| x.try_add(y).unwrap());
290impl_ops!({R: Ring + Clone} Sub for MatrixOnSemiring<R> where {R::Element: Copy}, |x: Self, y| x.try_sub(y).unwrap());
291impl_ops!({R: Semiring + Clone} Mul for MatrixOnSemiring<R> where {R::Element: Copy}, |x: Self, y| x.try_mul(y).unwrap());
292
293impl<R: Ring> Neg for MatrixOnSemiring<R>
294where
295    R::Element: Copy,
296{
297    type Output = Self;
298    fn neg(mut self) -> Self {
299        self.data.iter_mut().for_each(|r| {
300            r.iter_mut().for_each(|x| {
301                *x = self.semiring.neg(*x);
302            })
303        });
304        self
305    }
306}
307
308impl<R: Semiring> Index<usize> for MatrixOnSemiring<R> {
309    type Output = [R::Element];
310    fn index(&self, i: usize) -> &Self::Output {
311        &self.data[i]
312    }
313}
314
315impl<R: Semiring> From<MatrixOnSemiring<R>> for Vec<Vec<R::Element>> {
316    fn from(value: MatrixOnSemiring<R>) -> Self {
317        value.data
318    }
319}
320
321impl<R: Semiring> AsRef<[Vec<R::Element>]> for MatrixOnSemiring<R> {
322    fn as_ref(&self) -> &[Vec<R::Element>] {
323        &self.data
324    }
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;

    use crate::{
        algebra::semiring::{add_mul_mod::AddMulMod, xor_and::XorAnd},
        math::prime_mod::Prime,
        num::const_modint::*,
    };

    /// Strassen and schoolbook multiplication must agree on random 300×300
    /// matrices over the integers modulo 1000000007.
    #[test]
    fn test() {
        let mut rng = rand::thread_rng();
        let modulo = ConstModIntBuilder::<Prime<1000000007>>::new();
        let ring = AddMulMod(modulo);

        let size = 300;

        let mut a = MatrixOnSemiring::zero(ring, size, size);
        let mut b = MatrixOnSemiring::zero(ring, size, size);

        // Fill both matrices in row-major order, drawing the entry of `a`
        // before the entry of `b` for each cell.
        for (i, j) in (0..size).flat_map(|i| (0..size).map(move |j| (i, j))) {
            *a.get_mut(i, j).unwrap() = modulo.from_u64(rng.gen::<u32>() as u64);
            *b.get_mut(i, j).unwrap() = modulo.from_u64(rng.gen::<u32>() as u64);
        }

        assert!(a.clone().straight_mul(b.clone()) == a.strassen_mul(b));
    }

    /// Same agreement check on random 300×300 matrices over the `XorAnd`
    /// ring on `u64`.
    #[test]
    fn test_xor_and() {
        let mut rng = rand::thread_rng();
        let ring = XorAnd::<u64>::new();

        let size = 300;

        let mut a = MatrixOnSemiring::zero(ring, size, size);
        let mut b = MatrixOnSemiring::zero(ring, size, size);

        for (i, j) in (0..size).flat_map(|i| (0..size).map(move |j| (i, j))) {
            *a.get_mut(i, j).unwrap() = rng.gen::<u64>();
            *b.get_mut(i, j).unwrap() = rng.gen::<u64>();
        }

        assert!(a.clone().straight_mul(b.clone()) == a.strassen_mul(b));
    }

    /// Times both multiplication routines for several sizes and prints the
    /// measurements. Ignored by default.
    #[test]
    #[ignore]
    fn benchmark() {
        use crate::get_time;

        let mut rng = rand::thread_rng();
        let modulo = ConstModIntBuilder::<Prime<1000000007>>::new();
        let ring = AddMulMod(modulo);

        let mut straight = vec![];
        let mut strassen = vec![];

        for size in [1, 10, 100, 300, 500] {
            let mut a = MatrixOnSemiring::zero(ring, size, size);
            let mut b = MatrixOnSemiring::zero(ring, size, size);

            for (i, j) in (0..size).flat_map(|i| (0..size).map(move |j| (i, j))) {
                *a.get_mut(i, j).unwrap() = modulo.from_u64(rng.gen::<u32>() as u64);
                *b.get_mut(i, j).unwrap() = modulo.from_u64(rng.gen::<u32>() as u64);
            }

            straight.push(get_time!({
                a.clone().straight_mul(b.clone());
            }));

            strassen.push(get_time!({
                a.clone().strassen_mul(b.clone());
            }));
        }

        dbg!(straight, strassen);
    }
}