haar_lib/num/montgomery/
mod.rs

//! Montgomery multiplication
2//!
3//! # References
4//! - <https://ja.wikipedia.org/wiki/%E3%83%A2%E3%83%B3%E3%82%B4%E3%83%A1%E3%83%AA%E4%B9%97%E7%AE%97>
5use crate::impl_ops;
6use crate::num::ff::*;
7
/// Factory for [`Montgomery`] values that share one modulus.
///
/// Holds the modulus together with the constants that Montgomery reduction
/// needs, so they are computed once in `new` rather than per element.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MontgomeryBuilder {
    /// The modulus; `new` asserts that it is odd.
    modulo: u32,
    /// `R^2 mod modulo` with `R = 2^32`; multiplying by it and reducing
    /// converts a plain residue into Montgomery form.
    r2: u64,
    /// Reduction constant satisfying `m * modulo ≡ -1 (mod R)`.
    m: u64,
}
15
/// Bit width of the Montgomery radix.
const B: u32 = 32;
/// The Montgomery radix `R = 2^B`.
const R: u64 = 1u64 << B;
/// Mask selecting the low `B` bits, so `x & MASK` equals `x mod R`.
const MASK: u64 = R - 1;
19
20impl MontgomeryBuilder {
21    /// `modulo`を法とする[`MontgomeryBuilder`]を生成する。
22    pub fn new(modulo: u32) -> Self {
23        assert!(modulo % 2 != 0);
24        assert!(modulo > 0);
25
26        let r = R % modulo as u64;
27        let r2 = r * r % modulo as u64;
28        let m = {
29            let mut ret = 0;
30            let mut r = R;
31            let mut i = 1;
32            let mut t = 0;
33            while r > 1 {
34                if t % 2 == 0 {
35                    t += modulo;
36                    ret += i;
37                }
38                t >>= 1;
39                r >>= 1;
40                i <<= 1;
41            }
42            ret
43        };
44
45        Self { modulo, r2, m }
46    }
47}
48
49fn reduce(value: u64, modulo: u64, m: u64) -> u64 {
50    let mut ret = ((((value & MASK) * m) & MASK) * modulo + value) >> B;
51    if ret >= modulo {
52        ret -= modulo;
53    }
54    ret
55}
56
57impl FF for MontgomeryBuilder {
58    type Element = Montgomery;
59    fn from_u64(&self, mut value: u64) -> Self::Element {
60        if value >= self.modulo as u64 {
61            value %= self.modulo as u64;
62        }
63
64        let value = reduce(value * self.r2, self.modulo as u64, self.m);
65        Montgomery::__new(value, self.modulo as u64, self.r2, self.m)
66    }
67
68    fn from_i64(&self, mut value: i64) -> Self::Element {
69        value %= self.modulo as i64;
70        if value < 0 {
71            value += self.modulo as i64;
72        }
73
74        let value = reduce(value as u64 * self.r2, self.modulo as u64, self.m);
75        Montgomery::__new(value, self.modulo as u64, self.r2, self.m)
76    }
77    fn modulo(&self) -> u32 {
78        self.modulo
79    }
80}
81
/// A residue modulo `modulo`, stored in Montgomery form.
///
/// Each element carries its own copy of the reduction constants, so the
/// arithmetic operators need no access to the builder that created it.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct Montgomery {
    /// The residue in Montgomery form (the plain residue multiplied by `R`, mod `modulo`).
    value: u64,
    /// The modulus, widened to `u64` for use in reduction.
    modulo: u64,
    /// `R^2 mod modulo`, as held by the builder.
    r2: u64,
    /// Reduction constant with `m * modulo ≡ -1 (mod R)`.
    m: u64,
}
90
91impl FFElem for Montgomery {
92    #[inline]
93    fn value(self) -> u32 {
94        reduce(self.value, self.modulo, self.m) as u32
95    }
96
97    #[inline]
98    fn modulo(self) -> u32 {
99        self.modulo as u32
100    }
101
102    fn pow(self, mut p: u64) -> Self {
103        let mut value = reduce(self.r2, self.modulo, self.m);
104        let mut a = self.value;
105
106        while p > 0 {
107            if (p & 1) != 0 {
108                value = reduce(value * a, self.modulo, self.m);
109            }
110            a = reduce(a * a, self.modulo, self.m);
111            p >>= 1;
112        }
113
114        Self { value, ..self }
115    }
116}
117
118impl Montgomery {
119    fn __new(value: u64, modulo: u64, r2: u64, m: u64) -> Self {
120        Self {
121            value,
122            modulo,
123            r2,
124            m,
125        }
126    }
127}
128
129impl_ops!(Add for Montgomery, |mut x: Self, y| {
130    x += y;
131    x
132});
133
134impl_ops!(Sub for Montgomery, |mut x: Self, y| {
135    x -= y;
136    x
137});
138
139impl_ops!(Mul for Montgomery, |mut x: Self, y| {
140    x *= y;
141    x
142});
143
144impl_ops!(Div for Montgomery, |mut x: Self, y| {
145    x /= y;
146    x
147});
148
149impl_ops!(AddAssign for Montgomery, |x: &mut Self, y: Self| {
150    x.value += y.value;
151    if x.value >= x.modulo {
152        x.value -= x.modulo;
153    }
154});
155
156impl_ops!(SubAssign for Montgomery, |x: &mut Self, y: Self| {
157    if x.value < y.value {
158        x.value += x.modulo;
159    }
160    x.value -= y.value;
161});
162
163impl_ops!(MulAssign for Montgomery, |x: &mut Self, y: Self| x.value =
164    reduce(x.value * y.value, x.modulo, x.m));
165
166impl_ops!(DivAssign for Montgomery, |x: &mut Self, y: Self| *x *= y.inv());
167
168impl_ops!(Neg for Montgomery, |mut x: Self| {
169    if x.value != 0 {
170        x.value = x.modulo - x.value;
171    }
172    x
173});
174
#[cfg(test)]
mod tests {
    use super::*;
    use crate::iter::collect::CollectVec;
    use crate::num::{const_modint::*, modint::*};
    use crate::timer;
    use rand::Rng;

    /// One randomly generated arithmetic step applied to a running value.
    #[derive(Clone, Copy, Debug)]
    enum Ops {
        Add(u64),
        Sub(u64),
        Mul(u64),
        Div(u64),
        Neg,
    }

    /// Replays the same random sequence of operations through the const
    /// modint, the runtime modint and the Montgomery implementation, timing
    /// each pass, and checks that all three end on the same residue.
    #[test]
    fn test() {
        const MOD: u32 = 998244353;

        let mut rng = rand::thread_rng();

        let constmodint = ConstModIntBuilder::<MOD>;
        let modint = ModIntBuilder::new(MOD);
        let montgomery = MontgomeryBuilder::new(MOD);

        let mut ans = constmodint.from_u64(1);
        let mut ans2 = modint.from_u64(1);
        let mut res = montgomery.from_u64(1);

        // Operands are drawn from 1..MOD, so a division never sees zero.
        let ops = (0..1_000_000)
            .map(|_| {
                let operand = u64::from(rng.gen_range(1..MOD));

                match rng.gen_range(0..5) {
                    0 => Ops::Add(operand),
                    1 => Ops::Sub(operand),
                    2 => Ops::Mul(operand),
                    3 => Ops::Div(operand),
                    4 => Ops::Neg,
                    _ => unreachable!(),
                }
            })
            .collect_vec();

        // Pass 1: compile-time modulus.
        timer! {{
            for step in ops.iter().copied() {
                match step {
                    Ops::Add(v) => ans += constmodint.from_u64(v),
                    Ops::Sub(v) => ans -= constmodint.from_u64(v),
                    Ops::Mul(v) => ans *= constmodint.from_u64(v),
                    Ops::Div(v) => ans /= constmodint.from_u64(v),
                    Ops::Neg => ans = -ans
                }
            }
        }};

        // Pass 2: runtime modulus.
        timer! {{
            for step in ops.iter().copied() {
                match step {
                    Ops::Add(v) => ans2 += modint.from_u64(v),
                    Ops::Sub(v) => ans2 -= modint.from_u64(v),
                    Ops::Mul(v) => ans2 *= modint.from_u64(v),
                    Ops::Div(v) => ans2 /= modint.from_u64(v),
                    Ops::Neg => ans2 = -ans2
                }
            }
        }};

        // Pass 3: Montgomery form.
        timer! {{
            for step in ops.iter().copied() {
                match step {
                    Ops::Add(v) => res += montgomery.from_u64(v),
                    Ops::Sub(v) => res -= montgomery.from_u64(v),
                    Ops::Mul(v) => res *= montgomery.from_u64(v),
                    Ops::Div(v) => res /= montgomery.from_u64(v),
                    Ops::Neg => res = -res
                }
            }
        }};

        dbg!(ans.value());
        dbg!(ans2.value());
        dbg!(res.value());

        assert_eq!(ans.value(), ans2.value());
        assert_eq!(ans.value(), res.value());
    }
}