haar_lib/math/convolution/
conv_xor.rs

//! 添字XOR畳み込み (index-XOR convolution)
2use std::ops::{Add, Sub};
3
4use crate::num::ff::*;
5
6/// $h_{i \oplus j} = \sum f_i g_j$を満たす$h$を求める。
7///
8/// # Requirements
9/// `f.len()` = `g.len()`
10pub fn convolution_xor<Modulo: FF>(
11    mut f: Vec<Modulo::Element>,
12    mut g: Vec<Modulo::Element>,
13    modulo: Modulo,
14) -> Vec<Modulo::Element>
15where
16    Modulo::Element: Copy + FFElem,
17{
18    assert_eq!(f.len(), g.len());
19
20    fwt(&mut f);
21    fwt(&mut g);
22
23    for (x, y) in f.iter_mut().zip(g.into_iter()) {
24        *x *= y;
25    }
26
27    fwt(&mut f);
28
29    let t = modulo.frac(1, f.len() as i64);
30
31    for x in f.iter_mut() {
32        *x *= t;
33    }
34
35    f
36}
37
/// In-place, unnormalised fast Walsh–Hadamard transform.
///
/// After the call, `f[k]` holds $\sum_j (-1)^{\mathrm{popcount}(j \land k)} f_j$
/// computed from the input values. Applying the transform twice multiplies
/// every element by `f.len()`.
///
/// # Panics
/// Panics if `f.len()` is not a power of two (this includes an empty slice).
fn fwt<T>(f: &mut [T])
where
    T: Copy + Add<Output = T> + Sub<Output = T>,
{
    let n = f.len();
    assert!(n.is_power_of_two());

    let mut half = 1;
    while half < n {
        // Each chunk of `2 * half` elements contains `half` butterfly pairs.
        // In each pair the two indices differ only in the bit `half`.
        for chunk in f.chunks_exact_mut(half * 2) {
            let (lo, hi) = chunk.split_at_mut(half);
            for (a, b) in lo.iter_mut().zip(hi.iter_mut()) {
                let (sum, diff) = (*a + *b, *a - *b);
                *a = sum;
                *b = diff;
            }
        }
        half <<= 1;
    }
}