haar_lib/num/modint/
mod.rs

1//! 実行時にmod Mが決まるModInt
2
3use crate::impl_ops;
4pub use crate::num::ff::*;
5use std::{
6    fmt,
7    fmt::{Debug, Display, Formatter},
8    ops::Neg,
9};
10
/// Factory that stamps out [`ModInt`] values sharing one modulus chosen at runtime.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ModIntBuilder {
    /// Modulus attached to every [`ModInt`] this builder produces.
    modulo: u32,
}
16
17impl ModIntBuilder {
18    /// `modulo`を法とする`ModIntBuilder`を生成する。
19    pub fn new(modulo: u32) -> Self {
20        assert!(modulo > 0);
21        Self { modulo }
22    }
23}
24
25impl ZZ for ModIntBuilder {
26    type Element = ModInt;
27    fn from_u64(&self, value: u64) -> Self::Element {
28        ModInt::new((value % self.modulo as u64) as u32, self.modulo)
29    }
30
31    fn from_i64(&self, value: i64) -> Self::Element {
32        let value = ((value % self.modulo as i64) + self.modulo as i64) as u32;
33        ModInt::new(value, self.modulo)
34    }
35    fn modulo(&self) -> u32 {
36        self.modulo
37    }
38}
39
40impl FF for ModIntBuilder {}
41
/// A residue class: an integer `value` reduced modulo a runtime `modulo`.
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct ModInt {
    /// The residue; `ModInt::new` reduces it into `[0, modulo)`.
    value: u32,
    /// The modulus; the arithmetic operators assert both operands share it.
    modulo: u32,
}
48
49impl ZZElem for ModInt {
50    #[inline]
51    fn value(self) -> u32 {
52        self.value
53    }
54
55    #[inline]
56    fn modulo(self) -> u32 {
57        self.modulo
58    }
59
60    fn pow(self, mut p: u64) -> Self {
61        let mut ret: u64 = 1;
62        let mut a = self.value as u64;
63
64        while p > 0 {
65            if (p & 1) != 0 {
66                ret *= a;
67                ret %= self.modulo as u64;
68            }
69
70            a *= a;
71            a %= self.modulo as u64;
72
73            p >>= 1;
74        }
75
76        Self::new_unchecked(ret as u32, self.modulo)
77    }
78}
79
80impl FFElem for ModInt {}
81
82impl ModInt {
83    /// `value`を値にもち、`modulo`を法とする`ModInt`を生成する。
84    pub fn new(value: u32, modulo: u32) -> Self {
85        assert!(modulo > 0);
86        let value = if value < modulo {
87            value
88        } else {
89            value % modulo
90        };
91        Self { value, modulo }
92    }
93
94    #[inline]
95    fn new_unchecked(value: u32, modulo: u32) -> Self {
96        Self { value, modulo }
97    }
98}
99
100impl Display for ModInt {
101    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
102        write!(f, "{}", self.value)
103    }
104}
105
106impl Debug for ModInt {
107    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
108        write!(f, "{} (mod {})", self.value, self.modulo)
109    }
110}
111
112impl_ops!(Add for ModInt, |x: Self, y: Self| {
113    assert_eq!(x.modulo, y.modulo);
114    let a = x.value + y.value;
115    Self::new_unchecked(
116        if a < x.modulo { a } else { a - x.modulo },
117        x.modulo,
118    )
119});
120impl_ops!(Sub for ModInt, |x: Self, y: Self| {
121    assert_eq!(x.modulo, y.modulo);
122    let a = if x.value < y.value {
123        x.value + x.modulo - y.value
124    } else {
125        x.value - y.value
126    };
127    Self::new_unchecked(a, x.modulo)
128});
129impl_ops!(Mul for ModInt, |x: Self, y: Self| {
130    assert_eq!(x.modulo, y.modulo);
131    let a = x.value as u64 * y.value as u64;
132    let value = if a < x.modulo as u64 {
133        a as u32
134    } else {
135        (a % x.modulo as u64) as u32
136    };
137
138    Self::new_unchecked(value, x.modulo)
139});
140impl_ops!(Div for ModInt, |x: Self, y: Self| x * y.inv());
141
142impl_ops!(AddAssign for ModInt, |x: &mut Self, y| *x = *x + y);
143impl_ops!(SubAssign for ModInt, |x: &mut Self, y| *x = *x - y);
144impl_ops!(MulAssign for ModInt, |x: &mut Self, y| *x = *x * y);
145impl_ops!(DivAssign for ModInt, |x: &mut Self, y| *x = *x / y);
146
147impl Neg for ModInt {
148    type Output = Self;
149    fn neg(self) -> Self {
150        Self::new_unchecked(
151            if self.value == 0 {
152                0
153            } else {
154                self.modulo - self.value
155            },
156            self.modulo,
157        )
158    }
159}
160
161impl From<ModInt> for u32 {
162    fn from(value: ModInt) -> Self {
163        value.value
164    }
165}