haar_lib/math/convolution/
conv_mul_mod2n.rs1use crate::math::convolution::ntt::*;
6use crate::math::prime_mod::*;
7use crate::num::const_modint::*;
8
9pub fn convolution_mul_mod2n<P: PrimeMod>(
14 a: Vec<ConstModInt<P>>,
15 b: Vec<ConstModInt<P>>,
16) -> Vec<ConstModInt<P>> {
17 let len = a.len();
18 assert_eq!(a.len(), b.len());
19 assert!(len.is_power_of_two());
20
21 let n = len.trailing_zeros() as usize;
22
23 if n <= 1 {
24 let mut ret = vec![0.into(); len];
25
26 for i in 0..len {
27 for j in 0..len {
28 ret[i * j % len] += a[i] * b[j];
29 }
30 }
31
32 return ret;
33 }
34
35 let ntt = NTT::<P>::new();
36 let mask = (1 << n) - 1;
37 let cycle = std::iter::successors(Some(1), |n| Some((n * 5) & mask))
38 .take(len / 4)
39 .collect::<Vec<_>>();
40
41 let mut s: Vec<Vec<Vec<ConstModInt<P>>>> = vec![vec![vec![]; n - 1]; 2];
42 let mut t: Vec<Vec<Vec<ConstModInt<P>>>> = vec![vec![vec![]; n - 1]; 2];
43
44 for i in 0..n - 1 {
45 let k = n - i;
46
47 s[0][i].resize(1 << (k - 2), 0.into());
48 s[1][i].resize(1 << (k - 2), 0.into());
49 t[0][i].resize(1 << (k - 2), 0.into());
50 t[1][i].resize(1 << (k - 2), 0.into());
51
52 let mask2 = (1 << k) - 1;
53
54 for (j, c) in cycle.iter().enumerate().take(1 << (k - 2)) {
55 let r = c & mask2;
56 s[0][i][j] = a[r << i];
57 t[0][i][j] = b[r << i];
58 s[1][i][j] = a[((len - r) & mask2) << i];
59 t[1][i][j] = b[((len - r) & mask2) << i];
60 }
61 }
62
63 let mut ret = vec![0.into(); 1 << n];
64
65 let mut tt0 = vec![vec![]; n];
66 let mut tt1 = vec![vec![]; n];
67
68 for i in (0..n - 1).rev() {
69 let mut s0 = s[0][i].clone();
70 let mut s1 = s[1][i].clone();
71 let slen = s[0][i].len();
72 ntt.ntt(&mut s0);
73 ntt.ntt(&mut s1);
74
75 for j in (0..n - 1).rev() {
76 let tlen = t[0][j].len();
77
78 if slen <= 1 || tlen <= 1 {
79 for x in 0..slen {
80 for y in 0..tlen {
81 let g = cycle[(x + y) % cycle.len()] << (i + j);
82 ret[g & mask] += s[0][i][x] * t[0][j][y] + s[1][i][x] * t[1][j][y];
83 let mg = if g == 0 { 0 } else { len - g };
84 ret[mg & mask] += s[1][i][x] * t[0][j][y] + s[0][i][x] * t[1][j][y];
85 }
86 }
87
88 continue;
89 }
90
91 let w = std::cmp::max(slen, tlen);
92
93 if w > slen {
94 s0 = s[0][i].clone();
95 s1 = s[1][i].clone();
96 s0.resize(w, 0.into());
97 s1.resize(w, 0.into());
98 ntt.ntt(&mut s0);
99 ntt.ntt(&mut s1);
100 }
101
102 if w > tt0[j].len() {
103 tt0[j] = t[0][j].clone();
104 tt0[j].resize(w, 0.into());
105 tt1[j] = t[1][j].clone();
106 tt1[j].resize(w, 0.into());
107
108 ntt.ntt(&mut tt0[j]);
109 ntt.ntt(&mut tt1[j]);
110 }
111
112 let t0 = &tt0[j];
113 let t1 = &tt1[j];
114
115 let mut c = (0..w)
116 .map(|k| s0[k] * t0[k] + s1[k] * t1[k])
117 .collect::<Vec<_>>();
118 ntt.intt(&mut c);
119
120 c.into_iter().zip(cycle.iter()).for_each(|(x, r)| {
121 let index = r << (i + j);
122 ret[index & mask] += x;
123 });
124
125 let mut c = (0..w)
126 .map(|k| s1[k] * t0[k] + s0[k] * t1[k])
127 .collect::<Vec<_>>();
128 ntt.intt(&mut c);
129
130 c.into_iter().zip(cycle.iter()).for_each(|(x, r)| {
131 let index = r << (i + j);
132 let index = if index == 0 { 0 } else { len - index };
133 ret[index & mask] += x;
134 });
135 }
136 }
137
138 ret[0] += a[0] * b[0];
139 for i in 1..len {
140 ret[0] += a[i] * b[0] + a[0] * b[i];
141 ret[len / 2 * (i % 2)] += a[i] * b[len / 2];
142 ret[len / 2 * (i % 2)] += a[len / 2] * b[i];
143 }
144 ret[0] -= a[len / 2] * b[len / 2];
145
146 ret
147}