1use crate::impl_ops;
6use crate::num::ff::*;
7
/// Factory for Montgomery-form elements modulo a fixed odd `modulo`.
///
/// Stores the constants that every element it creates carries along.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MontgomeryBuilder {
    /// The odd modulus.
    modulo: u32,
    /// `R^2 mod modulo`, used to move plain integers into Montgomery form.
    r2: u64,
    /// `-modulo^{-1} mod R`, the constant consumed by the reduction step.
    m: u64,
}
15
/// Bit width of the Montgomery radix.
const B: u32 = 32;
/// The Montgomery radix `R = 2^B`.
const R: u64 = 1u64 << B;
/// Mask selecting the low `B` bits, i.e. reduction modulo `R`.
const MASK: u64 = (1u64 << B) - 1;
19
20impl MontgomeryBuilder {
21 pub fn new(modulo: u32) -> Self {
23 assert!(modulo % 2 != 0);
24 assert!(modulo > 0);
25
26 let r = R % modulo as u64;
27 let r2 = r * r % modulo as u64;
28 let m = {
29 let mut ret = 0;
30 let mut r = R;
31 let mut i = 1;
32 let mut t = 0;
33 while r > 1 {
34 if t % 2 == 0 {
35 t += modulo;
36 ret += i;
37 }
38 t >>= 1;
39 r >>= 1;
40 i <<= 1;
41 }
42 ret
43 };
44
45 Self { modulo, r2, m }
46 }
47}
48
49fn reduce(value: u64, modulo: u64, m: u64) -> u64 {
50 let mut ret = ((((value & MASK) * m) & MASK) * modulo + value) >> B;
51 if ret >= modulo {
52 ret -= modulo;
53 }
54 ret
55}
56
57impl ZZ for MontgomeryBuilder {
58 type Element = Montgomery;
59 fn from_u64(&self, mut value: u64) -> Self::Element {
60 if value >= self.modulo as u64 {
61 value %= self.modulo as u64;
62 }
63
64 let value = reduce(value * self.r2, self.modulo as u64, self.m);
65 Montgomery::__new(value, self.modulo as u64, self.r2, self.m)
66 }
67
68 fn from_i64(&self, mut value: i64) -> Self::Element {
69 value %= self.modulo as i64;
70 if value < 0 {
71 value += self.modulo as i64;
72 }
73
74 let value = reduce(value as u64 * self.r2, self.modulo as u64, self.m);
75 Montgomery::__new(value, self.modulo as u64, self.r2, self.m)
76 }
77 fn modulo(&self) -> u32 {
78 self.modulo
79 }
80}
81
82impl FF for MontgomeryBuilder {}
83
/// A residue modulo `modulo`, stored in Montgomery form (`value = x * R mod modulo`).
///
/// Each element carries its modulus and the builder's precomputed constants. The
/// operators therefore need no external context.
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct Montgomery {
    /// Montgomery-form residue; expected to stay in `[0, modulo)`.
    value: u64,
    /// The modulus, widened to `u64`.
    modulo: u64,
    /// `R^2 mod modulo`, copied from the builder.
    r2: u64,
    /// `-modulo^{-1} mod R`, copied from the builder.
    m: u64,
}
92
93impl ZZElem for Montgomery {
94 #[inline]
95 fn value(self) -> u32 {
96 reduce(self.value, self.modulo, self.m) as u32
97 }
98
99 #[inline]
100 fn modulo(self) -> u32 {
101 self.modulo as u32
102 }
103
104 fn pow(self, mut p: u64) -> Self {
105 let mut value = reduce(self.r2, self.modulo, self.m);
106 let mut a = self.value;
107
108 while p > 0 {
109 if (p & 1) != 0 {
110 value = reduce(value * a, self.modulo, self.m);
111 }
112 a = reduce(a * a, self.modulo, self.m);
113 p >>= 1;
114 }
115
116 Self { value, ..self }
117 }
118}
119
120impl FFElem for Montgomery {}
121
122impl Montgomery {
123 fn __new(value: u64, modulo: u64, r2: u64, m: u64) -> Self {
124 Self {
125 value,
126 modulo,
127 r2,
128 m,
129 }
130 }
131}
132
// By-value binary operators: each copies the left operand, applies the matching
// compound-assignment operator (`+=`, `-=`, `*=`, `/=`), and returns the result,
// so the modular logic lives only in the `*Assign` implementations.
impl_ops!(Add for Montgomery, |mut x: Self, y| {
    x += y;
    x
});

impl_ops!(Sub for Montgomery, |mut x: Self, y| {
    x -= y;
    x
});

impl_ops!(Mul for Montgomery, |mut x: Self, y| {
    x *= y;
    x
});

impl_ops!(Div for Montgomery, |mut x: Self, y| {
    x /= y;
    x
});
152
153impl_ops!(AddAssign for Montgomery, |x: &mut Self, y: Self| {
154 x.value += y.value;
155 if x.value >= x.modulo {
156 x.value -= x.modulo;
157 }
158});
159
160impl_ops!(SubAssign for Montgomery, |x: &mut Self, y: Self| {
161 if x.value < y.value {
162 x.value += x.modulo;
163 }
164 x.value -= y.value;
165});
166
167impl_ops!(MulAssign for Montgomery, |x: &mut Self, y: Self| x.value =
168 reduce(x.value * y.value, x.modulo, x.m));
169
170impl_ops!(DivAssign for Montgomery, |x: &mut Self, y: Self| *x *= y.inv());
171
172impl_ops!(Neg for Montgomery, |mut x: Self| {
173 if x.value != 0 {
174 x.value = x.modulo - x.value;
175 }
176 x
177});
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::iter::collect::CollectVec;
    use crate::math::prime_mod::Prime;
    use crate::num::{const_modint::*, modint::*};
    use crate::timer;
    use rand::Rng;

    /// One randomly generated operation. The operand is a raw integer that each
    /// implementation converts with its own builder.
    #[derive(Clone, Copy, Debug)]
    enum Ops {
        Add(u64),
        Sub(u64),
        Mul(u64),
        Div(u64),
        Neg,
    }

    /// Differential test and rough benchmark.
    ///
    /// Applies the same sequence of 1,000,000 random operations modulo 998244353 to
    /// elements from `ConstModIntBuilder`, `ModIntBuilder` and `MontgomeryBuilder`, then
    /// asserts that all three end with the same value. Each run is wrapped in the crate's
    /// `timer!` macro (presumably to report elapsed time — defined elsewhere).
    #[test]
    fn test() {
        const MOD: u32 = 998244353;

        // NOTE(review): unseeded RNG, so a failing sequence cannot be reproduced;
        // consider a fixed or logged seed.
        let mut rng = rand::thread_rng();

        let constmodint = ConstModIntBuilder::<Prime<MOD>>::new();
        let modint = ModIntBuilder::new(MOD);
        let montgomery = MontgomeryBuilder::new(MOD);

        let mut ans = constmodint.from_u64(1);
        let mut ans2 = modint.from_u64(1);
        let mut res = montgomery.from_u64(1);

        // Operands are drawn from 1..MOD, so `Div` never divides by zero.
        let ops = std::iter::repeat_with(|| {
            let x = rng.gen_range(1..MOD) as u64;

            let op = rng.gen_range(0..5);
            match op {
                0 => Ops::Add(x),
                1 => Ops::Sub(x),
                2 => Ops::Mul(x),
                3 => Ops::Div(x),
                4 => Ops::Neg,
                _ => unreachable!(),
            }
        })
        .take(1000000)
        .collect_vec();

        // Reference run with the const-modulus implementation.
        timer! {{
            for &op in &ops {
                match op {
                    Ops::Add(x) => ans += constmodint.from_u64(x),
                    Ops::Sub(x) => ans -= constmodint.from_u64(x),
                    Ops::Mul(x) => ans *= constmodint.from_u64(x),
                    Ops::Div(x) => ans /= constmodint.from_u64(x),
                    Ops::Neg => ans = -ans
                }
            }
        }};

        // Same sequence with the runtime-modulus implementation.
        timer! {{
            for &op in &ops {
                match op {
                    Ops::Add(x) => ans2 += modint.from_u64(x),
                    Ops::Sub(x) => ans2 -= modint.from_u64(x),
                    Ops::Mul(x) => ans2 *= modint.from_u64(x),
                    Ops::Div(x) => ans2 /= modint.from_u64(x),
                    Ops::Neg => ans2 = -ans2
                }
            }
        }};

        // Same sequence with the Montgomery implementation under test.
        timer! {{
            for &op in &ops {
                match op {
                    Ops::Add(x) => res += montgomery.from_u64(x),
                    Ops::Sub(x) => res -= montgomery.from_u64(x),
                    Ops::Mul(x) => res *= montgomery.from_u64(x),
                    Ops::Div(x) => res /= montgomery.from_u64(x),
                    Ops::Neg => res = -res
                }
            }
        }};

        // Debug output of the three final values (printed to stderr by `dbg!`).
        dbg!(ans.value());
        dbg!(ans2.value());
        dbg!(res.value());

        assert_eq!(ans.value(), ans2.value());
        assert_eq!(ans.value(), res.value());
    }
}