haar_lib/math/
binomial_coefficient.rs

//! Binomial coefficients (二項係数) $_nC_k$ modulo an arbitrary positive integer.
//!
//! # References
//! - <https://ferin-tech.hatenablog.com/entry/2018/01/17/010829>
//!
//! # Problems
//! - <https://judge.yosupo.jp/problem/binomial_coefficient>
8use crate::math::{
9    crt::crt_vec,
10    mod_ops::{inv::*, pow::*},
11};
12
/// Computes binomial coefficients $_nC_k \pmod{p^q}$ for a prime $p$.
///
/// Holds tables of the "p-free" factorials modulo $m = p^q$ (products of the integers
/// up to $i$ that are not divisible by $p$) together with their modular inverses.
#[derive(Clone, Debug)]
pub struct ExtLucas {
    /// `prod[i]`: product of all `1 <= j <= i` with `j % p != 0`, modulo `m`.
    prod: Vec<u64>,
    /// `inv[i]`: inverse of `prod[i]` modulo `m`.
    inv: Vec<u64>,
    /// The prime $p$.
    p: u64,
    /// The exponent $q$.
    q: u64,
    /// The modulus $m = p^q$.
    m: u64,
}
22
23impl ExtLucas {
24    /// 素数$p$に対して$\pmod{p^q}$で[`ExtLucas`]を用意する。
25    pub fn new(p: u64, q: u64) -> Self {
26        let m = p.pow(q as u32);
27
28        let mut prod: Vec<u64> = vec![1; m as usize];
29        let mut inv: Vec<u64> = vec![1; m as usize];
30
31        for i in 1..m as usize {
32            prod[i] = prod[i - 1] * (if i as u64 % p == 0 { 1 } else { i as u64 }) % m;
33        }
34
35        inv[m as usize - 1] = mod_inv(prod[m as usize - 1], m).unwrap();
36        for i in (1..m as usize).rev() {
37            inv[i - 1] = inv[i] * (if i as u64 % p == 0 { 1 } else { i as u64 }) % m;
38        }
39
40        Self { prod, inv, p, q, m }
41    }
42
43    /// $_nC_k \pmod{p^q}$を計算する。
44    pub fn calc(&self, mut n: u64, mut k: u64) -> u64 {
45        assert!(n >= k);
46
47        let mut r = n - k;
48        let mut e = 0;
49        let mut eq = 0;
50        let mut ret = 1;
51
52        let mut i = 0;
53        loop {
54            if n == 0 {
55                break;
56            }
57
58            ret *= self.prod[(n % self.m) as usize];
59            ret %= self.m;
60            ret *= self.inv[(k % self.m) as usize];
61            ret %= self.m;
62            ret *= self.inv[(r % self.m) as usize];
63            ret %= self.m;
64
65            n /= self.p;
66            k /= self.p;
67            r /= self.p;
68
69            e += n - k - r;
70
71            if e >= self.q {
72                return 0;
73            }
74
75            i += 1;
76            if i >= self.q {
77                eq += n - k - r;
78            }
79        }
80
81        if (self.p != 2 || self.q < 3) && eq % 2 == 1 {
82            ret = self.m - ret;
83        }
84
85        ret *= mod_pow(self.p, e, self.m);
86        ret %= self.m;
87
88        ret
89    }
90}
91
92/// 二項係数$_nC_k \pmod m$を計算する。
93#[derive(Clone)]
94pub struct BinomialCoefficient {
95    lu: Vec<ExtLucas>,
96    ms: Vec<u64>,
97}
98
99impl BinomialCoefficient {
100    /// $\pmod m$で[`BinomialCoefficient`]を用意する。
101    pub fn new(mut m: u64) -> Self {
102        let mut m_primes = vec![];
103        let mut ms = vec![];
104        let mut lu = vec![];
105
106        let mut i = 2;
107        while i * i <= m {
108            if m % i == 0 {
109                let mut t = 1;
110                let mut c = 0;
111                while m % i == 0 {
112                    m /= i;
113                    c += 1;
114                    t *= i;
115                }
116                m_primes.push((i, c));
117                ms.push(t);
118            }
119            i += 1;
120        }
121
122        if m != 1 {
123            m_primes.push((m, 1));
124            ms.push(m);
125        }
126
127        for (p, q) in m_primes {
128            lu.push(ExtLucas::new(p, q));
129        }
130
131        Self { lu, ms }
132    }
133
134    /// $_nC_k \pmod m$を計算する。
135    pub fn calc(&self, n: u64, k: u64) -> u64 {
136        if n < k {
137            0
138        } else {
139            let bs = self.lu.iter().map(|lu| lu.calc(n, k));
140            let a = bs
141                .zip(self.ms.iter())
142                .map(|(a, &b)| (a as i64, b))
143                .collect::<Vec<_>>();
144            crt_vec(&a).unwrap().0 as u64
145        }
146    }
147}
148
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test() {
        let c = BinomialCoefficient::new(10007);
        assert_eq!(c.calc(4, 2), 6);
        assert_eq!(c.calc(0, 0), 1);
        assert_eq!(c.calc(1000000007, 998244353), 0);
        // k > n is defined to be 0.
        assert_eq!(c.calc(3, 5), 0);

        let c = BinomialCoefficient::new(60);
        assert_eq!(
            (0..=10).map(|i| c.calc(20, i)).collect::<Vec<_>>(),
            [1, 20, 10, 0, 45, 24, 0, 0, 30, 20, 16]
        );
    }

    /// Cross-checks every $_nC_k$ with $n \le 40$ against Pascal's triangle. The moduli
    /// cover primes, prime powers (including $2^q$ with $q \ge 3`, which takes the other
    /// branch of the sign rule in `ExtLucas::calc`) and composites of several primes.
    #[test]
    fn matches_pascal_triangle() {
        for &m in &[2, 3, 4, 8, 9, 27, 128, 243, 625, 60, 720, 1001, 10007] {
            let c = BinomialCoefficient::new(m);
            let mut row: Vec<u64> = vec![1];
            for n in 0..=40u64 {
                for (k, &expected) in row.iter().enumerate() {
                    assert_eq!(c.calc(n, k as u64), expected, "C({}, {}) mod {}", n, k, m);
                }
                row = std::iter::once(1)
                    .chain(row.windows(2).map(|w| (w[0] + w[1]) % m))
                    .chain(std::iter::once(1))
                    .collect();
            }
        }
    }
}