haar_lib/math/convolution/
conv_mul_mod2n.rs

//! 添字積$\pmod {2^N}$畳み込み
//!
//! # Problems
//! - <https://judge.yosupo.jp/problem/mul_mod2n_convolution>
5use crate::math::convolution::ntt::*;
6use crate::math::prime_mod::*;
7use crate::num::const_modint::*;
8
9/// $c_k = \sum_{i \times j = k \pmod {2^N}} a_i b_j$を満たす$c$を求める。
10///
11/// # Requirements
12/// `a.len()` = `b.len()` = $2^N$
13pub fn convolution_mul_mod2n<P: PrimeMod>(
14    a: Vec<ConstModInt<P>>,
15    b: Vec<ConstModInt<P>>,
16) -> Vec<ConstModInt<P>> {
17    let len = a.len();
18    assert_eq!(a.len(), b.len());
19    assert!(len.is_power_of_two());
20
21    let n = len.trailing_zeros() as usize;
22
23    if n <= 1 {
24        let mut ret = vec![0.into(); len];
25
26        for i in 0..len {
27            for j in 0..len {
28                ret[i * j % len] += a[i] * b[j];
29            }
30        }
31
32        return ret;
33    }
34
35    let ntt = NTT::<P>::new();
36    let mask = (1 << n) - 1;
37    let cycle = std::iter::successors(Some(1), |n| Some((n * 5) & mask))
38        .take(len / 4)
39        .collect::<Vec<_>>();
40
41    let mut s: Vec<Vec<Vec<ConstModInt<P>>>> = vec![vec![vec![]; n - 1]; 2];
42    let mut t: Vec<Vec<Vec<ConstModInt<P>>>> = vec![vec![vec![]; n - 1]; 2];
43
44    for i in 0..n - 1 {
45        let k = n - i;
46
47        s[0][i].resize(1 << (k - 2), 0.into());
48        s[1][i].resize(1 << (k - 2), 0.into());
49        t[0][i].resize(1 << (k - 2), 0.into());
50        t[1][i].resize(1 << (k - 2), 0.into());
51
52        let mask2 = (1 << k) - 1;
53
54        for (j, c) in cycle.iter().enumerate().take(1 << (k - 2)) {
55            let r = c & mask2;
56            s[0][i][j] = a[r << i];
57            t[0][i][j] = b[r << i];
58            s[1][i][j] = a[((len - r) & mask2) << i];
59            t[1][i][j] = b[((len - r) & mask2) << i];
60        }
61    }
62
63    let mut ret = vec![0.into(); 1 << n];
64
65    let mut tt0 = vec![vec![]; n];
66    let mut tt1 = vec![vec![]; n];
67
68    for i in (0..n - 1).rev() {
69        let mut s0 = s[0][i].clone();
70        let mut s1 = s[1][i].clone();
71        let slen = s[0][i].len();
72        ntt.ntt(&mut s0);
73        ntt.ntt(&mut s1);
74
75        for j in (0..n - 1).rev() {
76            let tlen = t[0][j].len();
77
78            if slen <= 1 || tlen <= 1 {
79                for x in 0..slen {
80                    for y in 0..tlen {
81                        let g = cycle[(x + y) % cycle.len()] << (i + j);
82                        ret[g & mask] += s[0][i][x] * t[0][j][y] + s[1][i][x] * t[1][j][y];
83                        let mg = if g == 0 { 0 } else { len - g };
84                        ret[mg & mask] += s[1][i][x] * t[0][j][y] + s[0][i][x] * t[1][j][y];
85                    }
86                }
87
88                continue;
89            }
90
91            let w = std::cmp::max(slen, tlen);
92
93            if w > slen {
94                s0 = s[0][i].clone();
95                s1 = s[1][i].clone();
96                s0.resize(w, 0.into());
97                s1.resize(w, 0.into());
98                ntt.ntt(&mut s0);
99                ntt.ntt(&mut s1);
100            }
101
102            if w > tt0[j].len() {
103                tt0[j] = t[0][j].clone();
104                tt0[j].resize(w, 0.into());
105                tt1[j] = t[1][j].clone();
106                tt1[j].resize(w, 0.into());
107
108                ntt.ntt(&mut tt0[j]);
109                ntt.ntt(&mut tt1[j]);
110            }
111
112            let t0 = &tt0[j];
113            let t1 = &tt1[j];
114
115            let mut c = (0..w)
116                .map(|k| s0[k] * t0[k] + s1[k] * t1[k])
117                .collect::<Vec<_>>();
118            ntt.intt(&mut c);
119
120            c.into_iter().zip(cycle.iter()).for_each(|(x, r)| {
121                let index = r << (i + j);
122                ret[index & mask] += x;
123            });
124
125            let mut c = (0..w)
126                .map(|k| s1[k] * t0[k] + s0[k] * t1[k])
127                .collect::<Vec<_>>();
128            ntt.intt(&mut c);
129
130            c.into_iter().zip(cycle.iter()).for_each(|(x, r)| {
131                let index = r << (i + j);
132                let index = if index == 0 { 0 } else { len - index };
133                ret[index & mask] += x;
134            });
135        }
136    }
137
138    ret[0] += a[0] * b[0];
139    for i in 1..len {
140        ret[0] += a[i] * b[0] + a[0] * b[i];
141        ret[len / 2 * (i % 2)] += a[i] * b[len / 2];
142        ret[len / 2 * (i % 2)] += a[len / 2] * b[i];
143    }
144    ret[0] -= a[len / 2] * b[len / 2];
145
146    ret
147}