1use crate::impl_ops;
6use crate::num::ff::*;
7
/// Factory for [`Montgomery`] elements of `Z / modulo Z`.
///
/// Holds the constants that Montgomery arithmetic with radix `R = 2^32` needs:
/// `r2 = R^2 mod modulo` (multiplying by it and reducing converts a plain value
/// into Montgomery form) and `m = -modulo^{-1} mod R` (used by [`reduce`]).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MontgomeryBuilder {
    modulo: u32,
    r2: u64,
    m: u64,
}

/// Bit width of the Montgomery radix.
const B: u32 = 32;
/// Montgomery radix `R = 2^B`.
const R: u64 = 1 << B;
/// Low-`B`-bit mask: `x & MASK == x % R`.
const MASK: u64 = R - 1;

impl MontgomeryBuilder {
    /// Creates a builder for arithmetic modulo `modulo`.
    ///
    /// # Panics
    ///
    /// Panics if `modulo` is even (Montgomery form requires `gcd(modulo, R) == 1`),
    /// which also rejects 0, or if `modulo >= 2^31`: [`reduce`] evaluates
    /// `q * modulo + value` in a `u64`, and that sum is only guaranteed to stay
    /// below `2^64` for every admissible `value` when `modulo < 2^31`.
    pub fn new(modulo: u32) -> Self {
        // An odd value is necessarily non-zero, so this also rejects 0.
        assert!(modulo % 2 == 1, "modulo must be odd");
        assert!(modulo < 1 << 31, "modulo must be less than 2^31");

        let modulo64 = u64::from(modulo);
        let r = R % modulo64;
        let r2 = r * r % modulo64;

        // Newton–Hensel lifting of modulo^{-1} mod 2^32. Every odd x satisfies
        // x * x ≡ 1 (mod 8), so `inv = modulo` is already correct to 3 bits and
        // each step doubles the number of correct bits: 3 → 6 → 12 → 24 → 48.
        let mut inv = modulo;
        for _ in 0..4 {
            inv = inv.wrapping_mul(2u32.wrapping_sub(modulo.wrapping_mul(inv)));
        }
        debug_assert_eq!(modulo.wrapping_mul(inv), 1);
        // m = -modulo^{-1} mod 2^32, stored widened for 64-bit arithmetic.
        let m = u64::from(inv.wrapping_neg());

        Self { modulo, r2, m }
    }
}

/// Montgomery reduction (REDC): returns `value * R^{-1} mod modulo`, in `0..modulo`.
///
/// Requires `m == -modulo^{-1} mod R`, `modulo < 2^31` and `value < modulo * R`
/// (e.g. any product of two values below `modulo`). Under those bounds both
/// `q * modulo` and `value` are below `2^63`, so their sum cannot overflow.
fn reduce(value: u64, modulo: u64, m: u64) -> u64 {
    debug_assert!(value < modulo << B);
    // q is chosen so that `value + q * modulo` is divisible by R.
    let q = ((value & MASK) * m) & MASK;
    let mut ret = (q * modulo + value) >> B;
    // value < modulo * R implies ret < 2 * modulo, so one subtraction suffices.
    if ret >= modulo {
        ret -= modulo;
    }
    ret
}
56
57impl FF for MontgomeryBuilder {
58 type Element = Montgomery;
59 fn from_u64(&self, mut value: u64) -> Self::Element {
60 if value >= self.modulo as u64 {
61 value %= self.modulo as u64;
62 }
63
64 let value = reduce(value * self.r2, self.modulo as u64, self.m);
65 Montgomery::__new(value, self.modulo as u64, self.r2, self.m)
66 }
67
68 fn from_i64(&self, mut value: i64) -> Self::Element {
69 value %= self.modulo as i64;
70 if value < 0 {
71 value += self.modulo as i64;
72 }
73
74 let value = reduce(value as u64 * self.r2, self.modulo as u64, self.m);
75 Montgomery::__new(value, self.modulo as u64, self.r2, self.m)
76 }
77 fn modulo(&self) -> u32 {
78 self.modulo
79 }
80}
81
/// An element of `Z / modulo Z` stored in Montgomery form (`x * R mod modulo`,
/// with `R = 2^32`).
///
/// Each element carries copies of its builder's constants, so arithmetic needs
/// no external context.
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct Montgomery {
    /// Montgomery representation of the element, kept in `0..modulo`.
    value: u64,
    /// The modulus, widened to `u64`.
    modulo: u64,
    /// `R^2 mod modulo`, copied from the builder.
    r2: u64,
    /// `-modulo^{-1} mod R`, copied from the builder.
    m: u64,
}
90
91impl FFElem for Montgomery {
92 #[inline]
93 fn value(self) -> u32 {
94 reduce(self.value, self.modulo, self.m) as u32
95 }
96
97 #[inline]
98 fn modulo(self) -> u32 {
99 self.modulo as u32
100 }
101
102 fn pow(self, mut p: u64) -> Self {
103 let mut value = reduce(self.r2, self.modulo, self.m);
104 let mut a = self.value;
105
106 while p > 0 {
107 if (p & 1) != 0 {
108 value = reduce(value * a, self.modulo, self.m);
109 }
110 a = reduce(a * a, self.modulo, self.m);
111 p >>= 1;
112 }
113
114 Self { value, ..self }
115 }
116}
117
118impl Montgomery {
119 fn __new(value: u64, modulo: u64, r2: u64, m: u64) -> Self {
120 Self {
121 value,
122 modulo,
123 r2,
124 m,
125 }
126 }
127}
128
129impl_ops!(Add for Montgomery, |mut x: Self, y| {
130 x += y;
131 x
132});
133
134impl_ops!(Sub for Montgomery, |mut x: Self, y| {
135 x -= y;
136 x
137});
138
139impl_ops!(Mul for Montgomery, |mut x: Self, y| {
140 x *= y;
141 x
142});
143
144impl_ops!(Div for Montgomery, |mut x: Self, y| {
145 x /= y;
146 x
147});
148
149impl_ops!(AddAssign for Montgomery, |x: &mut Self, y: Self| {
150 x.value += y.value;
151 if x.value >= x.modulo {
152 x.value -= x.modulo;
153 }
154});
155
156impl_ops!(SubAssign for Montgomery, |x: &mut Self, y: Self| {
157 if x.value < y.value {
158 x.value += x.modulo;
159 }
160 x.value -= y.value;
161});
162
163impl_ops!(MulAssign for Montgomery, |x: &mut Self, y: Self| x.value =
164 reduce(x.value * y.value, x.modulo, x.m));
165
166impl_ops!(DivAssign for Montgomery, |x: &mut Self, y: Self| *x *= y.inv());
167
168impl_ops!(Neg for Montgomery, |mut x: Self| {
169 if x.value != 0 {
170 x.value = x.modulo - x.value;
171 }
172 x
173});
174
#[cfg(test)]
mod tests {
    use super::*;
    use crate::iter::collect::CollectVec;
    use crate::num::{const_modint::*, modint::*};
    use crate::timer;
    use rand::Rng;

    #[derive(Clone, Copy, Debug)]
    enum Ops {
        Add(u64),
        Sub(u64),
        Mul(u64),
        Div(u64),
        Neg,
    }

    /// Runs one random operation sequence through `ConstModInt`, `ModInt` and
    /// `Montgomery`, timing each, and checks that all three agree at the end.
    #[test]
    fn test() {
        const MOD: u32 = 998244353;

        let mut rng = rand::thread_rng();

        let constmodint = ConstModIntBuilder::<MOD>;
        let modint = ModIntBuilder::new(MOD);
        let montgomery = MontgomeryBuilder::new(MOD);

        let mut ans = constmodint.from_u64(1);
        let mut ans2 = modint.from_u64(1);
        let mut res = montgomery.from_u64(1);

        let ops = (0..1000000)
            .map(|_| {
                let x = rng.gen_range(1..MOD) as u64;

                let op = rng.gen_range(0..5);
                match op {
                    0 => Ops::Add(x),
                    1 => Ops::Sub(x),
                    2 => Ops::Mul(x),
                    3 => Ops::Div(x),
                    4 => Ops::Neg,
                    _ => unreachable!(),
                }
            })
            .collect_vec();

        timer! {{
            for &op in &ops {
                match op {
                    Ops::Add(x) => ans += constmodint.from_u64(x),
                    Ops::Sub(x) => ans -= constmodint.from_u64(x),
                    Ops::Mul(x) => ans *= constmodint.from_u64(x),
                    Ops::Div(x) => ans /= constmodint.from_u64(x),
                    Ops::Neg => ans = -ans
                }
            }
        }};

        timer! {{
            for &op in &ops {
                match op {
                    Ops::Add(x) => ans2 += modint.from_u64(x),
                    Ops::Sub(x) => ans2 -= modint.from_u64(x),
                    Ops::Mul(x) => ans2 *= modint.from_u64(x),
                    Ops::Div(x) => ans2 /= modint.from_u64(x),
                    Ops::Neg => ans2 = -ans2
                }
            }
        }};

        timer! {{
            for &op in &ops {
                match op {
                    Ops::Add(x) => res += montgomery.from_u64(x),
                    Ops::Sub(x) => res -= montgomery.from_u64(x),
                    Ops::Mul(x) => res *= montgomery.from_u64(x),
                    Ops::Div(x) => res /= montgomery.from_u64(x),
                    Ops::Neg => res = -res
                }
            }
        }};

        dbg!(ans.value());
        dbg!(ans2.value());
        dbg!(res.value());

        assert_eq!(ans.value(), ans2.value());
        assert_eq!(ans.value(), res.value());
    }

    /// `from_i64` must map every input to its least non-negative residue.
    #[test]
    fn from_i64_matches_euclidean_remainder() {
        const MOD: u32 = 998244353;
        let montgomery = MontgomeryBuilder::new(MOD);
        for &x in &[
            0i64,
            1,
            -1,
            MOD as i64,
            -(MOD as i64),
            1 << 40,
            i64::MIN,
            i64::MAX,
        ] {
            let expected = x.rem_euclid(MOD as i64) as u32;
            assert_eq!(montgomery.from_i64(x).value(), expected);
        }
    }

    /// `pow` against small exact powers and Fermat's little theorem.
    #[test]
    fn pow_matches_fermat() {
        const MOD: u32 = 998244353;
        let montgomery = MontgomeryBuilder::new(MOD);
        let x = montgomery.from_u64(3);
        assert_eq!(x.pow(0).value(), 1);
        assert_eq!(x.pow(5).value(), 243);
        assert_eq!(x.pow(MOD as u64 - 1).value(), 1);
        assert_eq!((x * x.pow(MOD as u64 - 2)).value(), 1);
    }

    /// Arithmetic stays exact at the largest modulus `MontgomeryBuilder` supports.
    #[test]
    fn largest_supported_modulus() {
        const MOD: u32 = 2_147_483_647; // 2^31 - 1
        let montgomery = MontgomeryBuilder::new(MOD);
        let (a, b) = (123_456_789u64, 2_000_000_000u64);
        let expected = (a * b % MOD as u64) as u32;
        assert_eq!(
            (montgomery.from_u64(a) * montgomery.from_u64(b)).value(),
            expected
        );

        let minus_one = montgomery.from_u64(MOD as u64 - 1);
        assert_eq!((minus_one * minus_one).value(), 1);
        assert_eq!((minus_one + minus_one).value(), MOD - 2);
        assert_eq!((-minus_one).value(), 1);
    }
}