haar_lib/math/combinatorics/
binomial_coefficient.rs

//! Binomial coefficients
//!
//! # References
//! - <https://ferin-tech.hatenablog.com/entry/2018/01/17/010829>
//!
//! # Problems
//! - <https://judge.yosupo.jp/problem/binomial_coefficient>
8use crate::math::{
9    crt::crt_vec,
10    mod_ops::{inv::*, pow::*},
11};
12
/// Computes binomial coefficients $_nC_k \pmod{p^q}$, where $p$ is a prime.
///
/// The lookup tables hold one entry per residue modulo $p^q$, so both the
/// memory footprint and the `Debug` output grow linearly with the modulus.
#[derive(Clone, Debug)]
pub struct ExtLucas {
    /// `prod[i]` is the product of the integers in `1..=i` that are not
    /// multiples of `p`, reduced modulo `m`.
    prod: Vec<u64>,
    /// `inv[i]` is the modular inverse of `prod[i]` modulo `m`.
    inv: Vec<u64>,
    /// Prime base of the modulus.
    p: u64,
    /// Exponent of the modulus.
    q: u64,
    /// The modulus itself, `p^q`.
    m: u64,
}
22
23impl ExtLucas {
24    /// 素数$p$に対して$\pmod{p^q}$で[`ExtLucas`]を用意する。
25    pub fn new(p: u64, q: u64) -> Self {
26        let m = p.pow(q as u32);
27
28        let mut prod: Vec<u64> = vec![1; m as usize];
29        let mut inv: Vec<u64> = vec![1; m as usize];
30
31        for i in 1..m as usize {
32            prod[i] = prod[i - 1] * (if i as u64 % p == 0 { 1 } else { i as u64 }) % m;
33        }
34
35        inv[m as usize - 1] = mod_inv(prod[m as usize - 1], m).unwrap();
36        for i in (1..m as usize).rev() {
37            inv[i - 1] = inv[i] * (if i as u64 % p == 0 { 1 } else { i as u64 }) % m;
38        }
39
40        Self { prod, inv, p, q, m }
41    }
42
43    /// $_nC_k \pmod{p^q}$を計算する。
44    pub fn calc(&self, mut n: u64, mut k: u64) -> u64 {
45        if n < k {
46            return 0;
47        }
48
49        let mut r = n - k;
50        let mut e = 0;
51        let mut eq = 0;
52        let mut ret = 1;
53
54        let mut i = 0;
55        loop {
56            if n == 0 {
57                break;
58            }
59
60            ret *= self.prod[(n % self.m) as usize];
61            ret %= self.m;
62            ret *= self.inv[(k % self.m) as usize];
63            ret %= self.m;
64            ret *= self.inv[(r % self.m) as usize];
65            ret %= self.m;
66
67            n /= self.p;
68            k /= self.p;
69            r /= self.p;
70
71            e += n - k - r;
72
73            if e >= self.q {
74                return 0;
75            }
76
77            i += 1;
78            if i >= self.q {
79                eq += n - k - r;
80            }
81        }
82
83        if (self.p != 2 || self.q < 3) && eq % 2 == 1 {
84            ret = self.m - ret;
85        }
86
87        ret *= mod_pow(self.p, e, self.m);
88        ret %= self.m;
89
90        ret
91    }
92}
93
94/// 二項係数$_nC_k \pmod m$を計算する。
95#[derive(Clone)]
96pub struct BinomialCoefficient {
97    lu: Vec<ExtLucas>,
98    ms: Vec<u64>,
99}
100
101impl BinomialCoefficient {
102    /// $\pmod m$で[`BinomialCoefficient`]を用意する。
103    pub fn new(mut m: u64) -> Self {
104        let mut m_primes = vec![];
105        let mut ms = vec![];
106        let mut lu = vec![];
107
108        let mut i = 2;
109        while i * i <= m {
110            if m % i == 0 {
111                let mut t = 1;
112                let mut c = 0;
113                while m % i == 0 {
114                    m /= i;
115                    c += 1;
116                    t *= i;
117                }
118                m_primes.push((i, c));
119                ms.push(t);
120            }
121            i += 1;
122        }
123
124        if m != 1 {
125            m_primes.push((m, 1));
126            ms.push(m);
127        }
128
129        for (p, q) in m_primes {
130            lu.push(ExtLucas::new(p, q));
131        }
132
133        Self { lu, ms }
134    }
135
136    /// $_nC_k \pmod m$を計算する。
137    pub fn calc(&self, n: u64, k: u64) -> u64 {
138        if n < k {
139            0
140        } else {
141            let bs = self.lu.iter().map(|lu| lu.calc(n, k));
142            let a = bs
143                .zip(self.ms.iter())
144                .map(|(a, &b)| (a as i64, b))
145                .collect::<Vec<_>>();
146            crt_vec(&a).unwrap().0 as u64
147        }
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Small and large arguments under a prime modulus.
    #[test]
    fn prime_modulus() {
        let binom = BinomialCoefficient::new(10007);

        assert_eq!(binom.calc(4, 2), 6);
        assert_eq!(binom.calc(0, 0), 1);
        assert_eq!(binom.calc(1000000007, 998244353), 0);
    }

    /// A composite modulus, checked against the first entries of row 20.
    #[test]
    fn composite_modulus() {
        let binom = BinomialCoefficient::new(60);
        let expected: [u64; 11] = [1, 20, 10, 0, 45, 24, 0, 0, 30, 20, 16];

        let mut actual = Vec::with_capacity(expected.len());
        for k in 0..=10 {
            actual.push(binom.calc(20, k));
        }

        assert_eq!(actual, expected);
    }
}