haar_lib/linalg/mod_2/
matrix.rs1use std::ops::Index;
3use std::ops::Mul;
4
5use crate::ds::bitset::Bitset;
6
7#[derive(Clone)]
9pub struct MatrixMod2 {
10 h: usize,
11 w: usize,
12 data: Vec<Bitset>,
13}
14
15impl MatrixMod2 {
16 pub fn new(h: usize, w: usize) -> Self {
18 Self {
19 h,
20 w,
21 data: vec![Bitset::new(w); h],
22 }
23 }
24
25 pub fn from_vec_bitset(other: Vec<Bitset>) -> Self {
27 let h = other.len();
28 assert!(h > 0);
29 let w = other[0].len();
30 assert!(other.iter().all(|r| r.len() == w));
31
32 Self { h, w, data: other }
33 }
34
35 pub fn transpose(self) -> Self {
37 let mut ret = Self::new(self.w, self.h);
38 for i in 0..self.h {
39 for j in 0..self.w {
40 if self.data[i].test(j) {
41 ret.data[j].flip(i);
42 }
43 }
44 }
45 ret
46 }
47
48 pub fn get(&self, i: usize, j: usize) -> Option<u32> {
50 let a = self.data.get(i)?;
51 (j < a.len()).then(|| a.test(j) as u32)
52 }
53}
54
55impl Mul for MatrixMod2 {
56 type Output = Self;
57
58 fn mul(self, rhs: Self) -> Self::Output {
59 assert_eq!(self.w, rhs.h);
60
61 let n = self.h;
62 let l = rhs.w;
63 let rhs = rhs.transpose();
64
65 let mut ret = Self::new(n, l);
66
67 for (r, r2) in ret.data.iter_mut().zip(self.data.iter()) {
68 for (i, c) in rhs.data.chunks(Bitset::B_SIZE).enumerate() {
69 let mut a = 0;
70
71 for (j, x) in c.iter().enumerate() {
72 let t = r2.and_count_ones(x) & 1;
73
74 if t != 0 {
75 a |= 1 << j;
76 }
77 }
78
79 r.data[i] = a;
80 }
81 }
82
83 ret
84 }
85}
86
87impl Index<usize> for MatrixMod2 {
88 type Output = Bitset;
89 fn index(&self, i: usize) -> &Self::Output {
90 &self.data[i]
91 }
92}