haar_lib/num/montgomery/
mod.rs

1//! Montgomery乗算
2//!
3//! # References
4//! - <https://ja.wikipedia.org/wiki/%E3%83%A2%E3%83%B3%E3%82%B4%E3%83%A1%E3%83%AA%E4%B9%97%E7%AE%97>
5use crate::impl_ops;
6use crate::num::ff::*;
7
/// Builder that creates [`Montgomery`] values for one fixed odd modulus.
///
/// It stores the per-modulus constants needed by Montgomery reduction with radix `R = 2^32`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MontgomeryBuilder {
    /// The modulus `N` (always odd).
    modulo: u32,
    /// `R^2 mod N`, used to move an ordinary residue into Montgomery form.
    r2: u64,
    /// `-N^{-1} mod R`, the multiplier used by the reduction step.
    m: u64,
}

/// Number of bits of the Montgomery radix `R = 2^B`.
const B: u32 = 32;
/// The Montgomery radix `R = 2^32`.
const R: u64 = 1 << B;
/// Mask selecting the low `B` bits, i.e. reduction modulo `R`.
const MASK: u64 = R - 1;

impl MontgomeryBuilder {
    /// Creates a [`MontgomeryBuilder`] for the modulus `modulo`.
    ///
    /// Every odd `u32` modulus is supported, including values `>= 2^31`.
    ///
    /// # Panics
    ///
    /// Panics if `modulo` is even (this includes `0`).
    pub fn new(modulo: u32) -> Self {
        assert!(modulo % 2 != 0);
        assert!(modulo > 0);

        let r = R % modulo as u64;
        // r < modulo < 2^32, so r * r fits in u64.
        let r2 = r * r % modulo as u64;

        // modulo^{-1} mod 2^32 via Newton's iteration x <- x * (2 - modulo * x).
        // For odd `modulo`, modulo * modulo == 1 (mod 8), so the seed x = modulo is
        // correct to 3 bits. Each step doubles the correct bits (3 -> 6 -> 12 -> 24 -> 48 >= 32).
        // Wrapping u32 arithmetic is exactly arithmetic mod 2^32, so unlike an
        // accumulate-and-shift search this cannot overflow for moduli >= 2^31.
        let mut inv = modulo;
        for _ in 0..4 {
            inv = inv.wrapping_mul(2u32.wrapping_sub(modulo.wrapping_mul(inv)));
        }
        debug_assert_eq!(modulo.wrapping_mul(inv), 1);
        let m = inv.wrapping_neg() as u64;

        Self { modulo, r2, m }
    }
}

/// Montgomery reduction (REDC): returns `value * R^{-1} mod modulo` in `[0, modulo)`.
///
/// Requires odd `modulo < R`, `m == -modulo^{-1} mod R`, and `value < modulo * R`.
/// Every caller in this module passes either a residue or a product of two residues,
/// so `value < modulo^2`.
fn reduce(value: u64, modulo: u64, m: u64) -> u64 {
    // t = value * m mod R, chosen so that value + t * modulo == 0 (mod R).
    let t = ((value & MASK) * m) & MASK;
    // t < 2^32 and modulo < 2^32, so this product fits in u64.
    let tn = t * modulo;
    // Adding value + tn directly can overflow u64 for large moduli, so sum the high halves.
    // The low halves add up to 0 mod R. Their sum is therefore either 0 (exactly when
    // value's low half is 0) or exactly R, which contributes a carry of 1.
    let carry = ((value & MASK) != 0) as u64;
    let mut ret = (value >> B) + (tn >> B) + carry;
    // The exact quotient is < value / R + modulo < 2 * modulo, so one subtraction suffices.
    if ret >= modulo {
        ret -= modulo;
    }
    ret
}
56
57impl ZZ for MontgomeryBuilder {
58    type Element = Montgomery;
59    fn from_u64(&self, mut value: u64) -> Self::Element {
60        if value >= self.modulo as u64 {
61            value %= self.modulo as u64;
62        }
63
64        let value = reduce(value * self.r2, self.modulo as u64, self.m);
65        Montgomery::__new(value, self.modulo as u64, self.r2, self.m)
66    }
67
68    fn from_i64(&self, mut value: i64) -> Self::Element {
69        value %= self.modulo as i64;
70        if value < 0 {
71            value += self.modulo as i64;
72        }
73
74        let value = reduce(value as u64 * self.r2, self.modulo as u64, self.m);
75        Montgomery::__new(value, self.modulo as u64, self.r2, self.m)
76    }
77    fn modulo(&self) -> u32 {
78        self.modulo
79    }
80}
81
82impl FF for MontgomeryBuilder {}
83
/// A residue modulo `modulo`, stored in Montgomery form.
#[derive(Copy, Clone, PartialEq, Eq)]
pub struct Montgomery {
    // Montgomery representation `x * R mod modulo` (with `R = 2^32`) of the residue `x`.
    value: u64,
    // The modulus.
    modulo: u64,
    // `R^2 mod modulo`, copied from the builder.
    r2: u64,
    // Reduction multiplier `m` copied from the builder (used by `reduce`).
    m: u64,
}
92
93impl ZZElem for Montgomery {
94    #[inline]
95    fn value(self) -> u32 {
96        reduce(self.value, self.modulo, self.m) as u32
97    }
98
99    #[inline]
100    fn modulo(self) -> u32 {
101        self.modulo as u32
102    }
103
104    fn pow(self, mut p: u64) -> Self {
105        let mut value = reduce(self.r2, self.modulo, self.m);
106        let mut a = self.value;
107
108        while p > 0 {
109            if (p & 1) != 0 {
110                value = reduce(value * a, self.modulo, self.m);
111            }
112            a = reduce(a * a, self.modulo, self.m);
113            p >>= 1;
114        }
115
116        Self { value, ..self }
117    }
118}
119
120impl FFElem for Montgomery {}
121
122impl Montgomery {
123    fn __new(value: u64, modulo: u64, r2: u64, m: u64) -> Self {
124        Self {
125            value,
126            modulo,
127            r2,
128            m,
129        }
130    }
131}
132
133impl_ops!(Add for Montgomery, |mut x: Self, y| {
134    x += y;
135    x
136});
137
138impl_ops!(Sub for Montgomery, |mut x: Self, y| {
139    x -= y;
140    x
141});
142
143impl_ops!(Mul for Montgomery, |mut x: Self, y| {
144    x *= y;
145    x
146});
147
148impl_ops!(Div for Montgomery, |mut x: Self, y| {
149    x /= y;
150    x
151});
152
153impl_ops!(AddAssign for Montgomery, |x: &mut Self, y: Self| {
154    x.value += y.value;
155    if x.value >= x.modulo {
156        x.value -= x.modulo;
157    }
158});
159
160impl_ops!(SubAssign for Montgomery, |x: &mut Self, y: Self| {
161    if x.value < y.value {
162        x.value += x.modulo;
163    }
164    x.value -= y.value;
165});
166
167impl_ops!(MulAssign for Montgomery, |x: &mut Self, y: Self| x.value =
168    reduce(x.value * y.value, x.modulo, x.m));
169
170impl_ops!(DivAssign for Montgomery, |x: &mut Self, y: Self| *x *= y.inv());
171
172impl_ops!(Neg for Montgomery, |mut x: Self| {
173    if x.value != 0 {
174        x.value = x.modulo - x.value;
175    }
176    x
177});
178
// Cross-checks `Montgomery` against `ConstModInt` and `ModInt` by applying the same
// random sequence of operations to all three and comparing the final residues.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::iter::collect::CollectVec;
    use crate::math::prime_mod::Prime;
    use crate::num::{const_modint::*, modint::*};
    use crate::timer;
    use rand::Rng;

    /// One random step applied to the running value; the payload is the right-hand operand.
    #[derive(Clone, Copy, Debug)]
    enum Ops {
        Add(u64),
        Sub(u64),
        Mul(u64),
        Div(u64),
        Neg,
    }

    #[test]
    fn test() {
        const MOD: u32 = 998244353;

        let mut rng = rand::thread_rng();

        // Three implementations of arithmetic modulo the same prime.
        let constmodint = ConstModIntBuilder::<Prime<MOD>>::new();
        let modint = ModIntBuilder::new(MOD);
        let montgomery = MontgomeryBuilder::new(MOD);

        let mut ans = constmodint.from_u64(1);
        let mut ans2 = modint.from_u64(1);
        let mut res = montgomery.from_u64(1);

        // Pre-generate the operation sequence so every implementation replays identical input.
        let ops = std::iter::repeat_with(|| {
            // Operands come from [1, MOD), never 0 mod the prime MOD, so Div is always defined.
            let x = rng.gen_range(1..MOD) as u64;

            let op = rng.gen_range(0..5);
            match op {
                0 => Ops::Add(x),
                1 => Ops::Sub(x),
                2 => Ops::Mul(x),
                3 => Ops::Div(x),
                4 => Ops::Neg,
                _ => unreachable!(),
            }
        })
        .take(1000000)
        .collect_vec();

        // NOTE(review): `timer!` is assumed to report the elapsed time of the enclosed
        // block (see `crate::timer`) — confirm. Each block replays `ops` on one implementation.
        timer! {{
            for &op in &ops {
                match op {
                    Ops::Add(x) => ans += constmodint.from_u64(x),
                    Ops::Sub(x) => ans -= constmodint.from_u64(x),
                    Ops::Mul(x) => ans *= constmodint.from_u64(x),
                    Ops::Div(x) => ans /= constmodint.from_u64(x),
                    Ops::Neg => ans = -ans
                }
            }
        }};

        timer! {{
            for &op in &ops {
                match op {
                    Ops::Add(x) => ans2 += modint.from_u64(x),
                    Ops::Sub(x) => ans2 -= modint.from_u64(x),
                    Ops::Mul(x) => ans2 *= modint.from_u64(x),
                    Ops::Div(x) => ans2 /= modint.from_u64(x),
                    Ops::Neg => ans2 = -ans2
                }
            }
        }};

        timer! {{
            for &op in &ops {
                match op {
                    Ops::Add(x) => res += montgomery.from_u64(x),
                    Ops::Sub(x) => res -= montgomery.from_u64(x),
                    Ops::Mul(x) => res *= montgomery.from_u64(x),
                    Ops::Div(x) => res /= montgomery.from_u64(x),
                    Ops::Neg => res = -res
                }
            }
        }};

        dbg!(ans.value());
        dbg!(ans2.value());
        dbg!(res.value());

        // All three implementations must agree on the final residue.
        assert_eq!(ans.value(), ans2.value());
        assert_eq!(ans.value(), res.value());
    }
}
270}