haar_lib/math/
factorial_large.rs

//! 階乗 (factorial) $n! \pmod P$ ($0 \le n \lt P$)
//!
//! # References
//! - <https://suisen-kyopro.hatenablog.com/entry/2023/11/22/201600>
5
6use crate::math::prime_mod::PrimeMod;
7use crate::{math::polynomial::shift_sampling_points::*, num::const_modint::*};
8
9/// 階乗を計算する。
10#[derive(Clone, Debug)]
11pub struct Factorial<P: PrimeMod> {
12    r: u32,
13    prod: Vec<ConstModInt<P>>,
14}
15
16impl<P: PrimeMod> Default for Factorial<P> {
17    fn default() -> Self {
18        Self::new()
19    }
20}
21
22impl<P: PrimeMod> Factorial<P> {
23    /// 前計算を行う。
24    pub fn new() -> Self {
25        let k = 9;
26        let r = 1 << k;
27
28        let mut f = vec![ConstModInt::new(1)];
29
30        for i in 0..k {
31            let n = f.len();
32            let mut g = shift_sampling_points::<P>(f.clone(), n as u32, n * 3);
33            f.append(&mut g);
34
35            f = f
36                .chunks_exact(2)
37                .enumerate()
38                .map(|(j, f)| f[0] * f[1] * ((2 * j + 1) << i).into())
39                .collect();
40        }
41
42        let block_num = (P::PRIME_NUM / r) as usize;
43        if f.len() < block_num {
44            let mut g = shift_sampling_points::<P>(f.clone(), f.len() as u32, block_num - f.len());
45            f.append(&mut g);
46        }
47
48        let mut prod = vec![1.into(); f.len() + 1];
49        for (i, fi) in f.into_iter().enumerate() {
50            prod[i + 1] = prod[i] * fi * r.into() * (i + 1).into();
51        }
52
53        Self { r, prod }
54    }
55
56    /// $n! \pmod P$を計算する。
57    pub fn factorial(&self, n: u32) -> ConstModInt<P> {
58        if n >= P::PRIME_NUM {
59            return 0.into();
60        }
61
62        let k = n / self.r;
63        let p = k * self.r;
64        let mut ret = self.prod[k as usize];
65
66        for i in p + 1..=n {
67            ret *= i.into()
68        }
69
70        ret
71    }
72}