haar_lib/math/convolution/
conv_xor.rs

//! XOR convolution: computes `a` satisfying $\mathtt{a_k} = \sum_{i \oplus j = k} \mathtt{f_i} * \mathtt{g_j}$.
2use std::ops::{Add, Sub};
3
4use crate::num::ff::*;
5
6/// $\mathtt{a_{i \oplus j}} = \sum \mathtt{f_{i}} * \mathtt{g_{j}}$を満たす`a`を求める。
7pub fn convolution_xor<Modulo: FF>(
8    mut f: Vec<Modulo::Element>,
9    mut g: Vec<Modulo::Element>,
10    modulo: Modulo,
11) -> Vec<Modulo::Element>
12where
13    Modulo::Element: Copy + FFElem,
14{
15    assert_eq!(f.len(), g.len());
16
17    fwt(&mut f);
18    fwt(&mut g);
19
20    for (x, y) in f.iter_mut().zip(g.into_iter()) {
21        *x *= y;
22    }
23
24    fwt(&mut f);
25
26    let t = modulo.frac(1, f.len() as i64);
27
28    for x in f.iter_mut() {
29        *x *= t;
30    }
31
32    f
33}
34
/// Applies the unnormalized Walsh–Hadamard butterfly to `f` in place.
///
/// For every bit `half = 1, 2, 4, …` below `f.len()`, each pair of positions
/// `(j, j | half)` with bit `half` clear in `j` is replaced by
/// `(f[j] + f[j | half], f[j] - f[j | half])`.
///
/// # Panics
///
/// Panics if `f.len()` is not a power of two; an empty slice therefore panics too.
fn fwt<T>(f: &mut [T])
where
    T: Copy + Add<Output = T> + Sub<Output = T>,
{
    let n = f.len();
    assert!(n.is_power_of_two());

    let mut half = 1;
    while half < n {
        // Within each block of `2 * half` elements, the first half holds the indices with
        // bit `half` clear and the second half holds their partners at the same offsets.
        for block in f.chunks_exact_mut(2 * half) {
            let (lo, hi) = block.split_at_mut(half);
            for (a, b) in lo.iter_mut().zip(hi.iter_mut()) {
                let (x, y) = (*a, *b);
                *a = x + y;
                *b = x - y;
            }
        }
        half *= 2;
    }
}