haar_lib/math/polynomial/
mod.rs

1//! 多項式
2
3pub mod multipoint_eval;
4pub mod polynomial_interpolation;
5pub mod polynomial_taylor_shift;
6pub mod shift_sampling_points;
7pub mod sparse;
8
9use std::ops::{
10    Add, AddAssign, Div, DivAssign, Index, IndexMut, Mul, MulAssign, Rem, RemAssign, Sub, SubAssign,
11};
12
13use crate::math::convolution::ntt::NTT;
14use crate::math::prime_mod::PrimeMod;
15use crate::num::const_modint::*;
16
17/// $\mathbb{F}_p$上の多項式
18#[derive(Clone, Debug, Default)]
19pub struct Polynomial<P: PrimeMod> {
20    pub(crate) data: Vec<ConstModInt<P>>,
21}
22
23impl<P: PrimeMod> Polynomial<P> {
24    pub(crate) const NTT: NTT<P> = NTT::<P>::new();
25
26    /// 零多項式を得る。
27    pub fn zero() -> Self {
28        Self { data: vec![] }
29    }
30
31    /// 定数項のみをもつ多項式を生成する。
32    pub fn constant(a: ConstModInt<P>) -> Self {
33        if a.value() == 0 {
34            Self::zero()
35        } else {
36            Self { data: vec![a] }
37        }
38    }
39
40    /// $x^i$の係数を得る。
41    pub fn coeff_of(&self, i: usize) -> ConstModInt<P> {
42        self.data.get(i).map_or(ConstModInt::new(0), |a| *a)
43    }
44
45    /// 多項式に値`p`を代入した結果を求める。
46    pub fn eval(&self, p: ConstModInt<P>) -> ConstModInt<P> {
47        let mut ret = ConstModInt::new(0);
48        let mut x = ConstModInt::new(1);
49
50        for &a in &self.data {
51            ret += a * x;
52            x *= p;
53        }
54
55        ret
56    }
57
58    /// 内部の`Vec`の長さを返す。
59    pub fn len(&self) -> usize {
60        self.data.len()
61    }
62
63    /// 項数が`0`のとき`true`を返す。
64    pub fn is_empty(&self) -> bool {
65        self.data.is_empty()
66    }
67
68    /// 係数が`0`の高次項を縮める。
69    pub fn shrink(&mut self) {
70        while self.data.last().is_some_and(|x| x.value() == 0) {
71            self.data.pop();
72        }
73    }
74
75    /// [`len()`](Self::len())を超えないように、先頭`t`項をもつ多項式を返す。
76    pub fn get_until(&self, t: usize) -> Self {
77        Self {
78            data: self.data[..t.min(self.len())].to_vec(),
79        }
80    }
81
82    /// 多項式の次数を返す。
83    ///
84    /// `self`が零多項式のときは`None`を返す。
85    ///
86    /// **Time complexity** $O(n)$
87    pub fn deg(&self) -> Option<usize> {
88        (0..self.len()).rev().find(|&i| self.data[i].value() != 0)
89    }
90
91    /// 多項式を`k`倍する。
92    pub fn scale(&mut self, k: ConstModInt<P>) {
93        self.data.iter_mut().for_each(|x| *x *= k);
94    }
95
96    /// 多項式を微分する。
97    pub fn differentiate(&mut self) {
98        let n = self.len();
99        if n > 0 {
100            for i in 0..n - 1 {
101                self.data[i] = self.data[i + 1] * ConstModInt::new(i as u32 + 1);
102            }
103            self.data.pop();
104        }
105    }
106
107    /// 多項式を積分する。
108    pub fn integrate(&mut self) {
109        let n = self.len();
110        let mut invs = vec![ConstModInt::new(1); n + 1];
111        for i in 2..=n {
112            invs[i] = -invs[P::PRIME_NUM as usize % i] * ConstModInt::new(P::PRIME_NUM / i as u32);
113        }
114        self.data.push(0.into());
115        for i in (0..n).rev() {
116            self.data[i + 1] = self.data[i] * invs[i + 1];
117        }
118        self.data[0] = 0.into();
119    }
120
121    /// 係数を`k`次だけ高次側にずらす。ただし、$x^n$の項以降は無視する。
122    ///
123    /// $(a_0 + a_1 x + a_2 x^2 + \ldots + a_{n-1} x^{n-1}) \times x^k \pmod {x^n}$
124    pub fn shift_higher(&mut self, k: usize) {
125        let n = self.len();
126        for i in (k..n).rev() {
127            self.data[i] = self.data[i - k];
128        }
129        for i in 0..k {
130            self.data[i] = 0.into();
131        }
132    }
133
134    /// 係数を`k`次だけ低次側にずらす。ただし、負の次数の項は無視する。
135    pub fn shift_lower(&mut self, k: usize) {
136        let n = self.len();
137        for i in 0..n.saturating_sub(k) {
138            self.data[i] = self.data[i + k];
139        }
140        for i in n.saturating_sub(k)..n {
141            self.data[i] = 0.into();
142        }
143    }
144
145    /// 多項式の列の積を計算する。
146    pub fn prod(mut a: Vec<Self>) -> Self {
147        match a.len() {
148            0 => Self::constant(1.into()),
149            1 => a.pop().unwrap(),
150            n => {
151                let b = a.split_off(n / 2);
152                Self::prod(a) * Self::prod(b)
153            }
154        }
155    }
156
157    /// 多項式の$p$乗を計算する。
158    pub fn pow(self, mut p: u64) -> Self {
159        let mut ret = Self::constant(1.into());
160        let mut a = self;
161
162        while p > 0 {
163            if p & 1 == 1 {
164                ret *= a.clone();
165            }
166
167            a = a.sq();
168            p >>= 1;
169        }
170
171        ret
172    }
173
174    /// 多項式`a`の2乗を返す。
175    pub fn sq(mut self) -> Self {
176        let k = self.len() * 2 - 1;
177        let n = k.next_power_of_two();
178
179        self.data.resize(n, 0.into());
180        Self::NTT.ntt(&mut self.data);
181        self.data.iter_mut().for_each(|x| *x *= *x);
182        Self::NTT.intt(&mut self.data);
183
184        self.data.truncate(k);
185        self
186    }
187
188    #[allow(missing_docs)]
189    pub fn inv(self, n: usize) -> Self {
190        let mut t = 1;
191        let mut ret = vec![self.data[0].inv()];
192        let a: Vec<_> = self.into();
193
194        while t <= n * 2 {
195            let k = (t * 2 - 1).next_power_of_two();
196
197            let mut s = ret.clone();
198            s.resize(k, 0.into());
199            Self::NTT.ntt(&mut s);
200            s.iter_mut().for_each(|x| *x *= *x);
201
202            let mut a = a[..t.min(a.len())].to_vec();
203            a.resize(k, 0.into());
204            Self::NTT.ntt(&mut a);
205
206            s.iter_mut().zip(a).for_each(|(x, y)| *x *= y);
207            Self::NTT.intt(&mut s);
208
209            ret.resize(t, 0.into());
210            ret.iter_mut()
211                .zip(s)
212                .for_each(|(x, y)| *x = *x * 2.into() - y);
213
214            t *= 2;
215        }
216
217        ret.into()
218    }
219
220    /// 多項式`a`の多項式`b`による商と剰余を返す。
221    pub fn divrem(self, b: Self) -> (Self, Self) {
222        if self.len() < b.len() {
223            return (Self::zero(), self);
224        }
225
226        let q = self.clone() / b.clone();
227
228        let d = b.len() - 1;
229        let mut r = self - b * q.clone();
230        r.data.truncate(d);
231        r.shrink();
232
233        (q, r)
234    }
235}
236
237impl<P: PrimeMod> AddAssign for Polynomial<P> {
238    fn add_assign(&mut self, b: Self) {
239        if self.len() < b.len() {
240            self.data.resize(b.len(), ConstModInt::new(0));
241        }
242        for (a, b) in self.data.iter_mut().zip(b.data) {
243            *a += b;
244        }
245    }
246}
247
248impl<P: PrimeMod> Add for Polynomial<P> {
249    type Output = Self;
250    fn add(mut self, b: Self) -> Self {
251        self += b;
252        self
253    }
254}
255
256impl<P: PrimeMod> SubAssign for Polynomial<P> {
257    fn sub_assign(&mut self, b: Self) {
258        if self.len() < b.len() {
259            self.data.resize(b.len(), ConstModInt::new(0));
260        }
261        for (a, b) in self.data.iter_mut().zip(b.data) {
262            *a -= b;
263        }
264    }
265}
266
267impl<P: PrimeMod> Sub for Polynomial<P> {
268    type Output = Self;
269    fn sub(mut self, b: Self) -> Self {
270        self -= b;
271        self
272    }
273}
274
275impl<P: PrimeMod> MulAssign for Polynomial<P> {
276    fn mul_assign(&mut self, mut rhs: Self) {
277        let k = self.len() + rhs.len() - 1;
278
279        let n = k.next_power_of_two();
280        self.data.resize(n, 0.into());
281        Self::NTT.ntt(&mut self.data);
282
283        rhs.data.resize(n, 0.into());
284        Self::NTT.ntt(&mut rhs.data);
285
286        self.data
287            .iter_mut()
288            .zip(rhs.data)
289            .for_each(|(x, y)| *x *= y);
290        Self::NTT.intt(&mut self.data);
291
292        self.data.truncate(k);
293    }
294}
295
296impl<P: PrimeMod> Mul for Polynomial<P> {
297    type Output = Self;
298    fn mul(mut self, rhs: Self) -> Self::Output {
299        self *= rhs;
300        self
301    }
302}
303
304impl<P: PrimeMod> DivAssign for Polynomial<P> {
305    fn div_assign(&mut self, rhs: Self) {
306        *self = self.clone() / rhs;
307    }
308}
309
310impl<P: PrimeMod> Div for Polynomial<P> {
311    type Output = Self;
312    fn div(mut self, mut rhs: Self) -> Self::Output {
313        if self.len() < rhs.len() {
314            return Self::zero();
315        }
316
317        let m = self.len() - rhs.len();
318
319        self.data.reverse();
320        rhs.data.reverse();
321
322        rhs = rhs.inv(m);
323        rhs.data.resize(m + 1, 0.into());
324
325        let mut q = self * rhs;
326        q.data.resize(m + 1, 0.into());
327        q.data.reverse();
328        q.shrink();
329        q
330    }
331}
332
333impl<P: PrimeMod> RemAssign for Polynomial<P> {
334    fn rem_assign(&mut self, rhs: Self) {
335        *self = self.clone() % rhs;
336    }
337}
338
339impl<P: PrimeMod> Rem for Polynomial<P> {
340    type Output = Self;
341    fn rem(self, rhs: Self) -> Self::Output {
342        self.divrem(rhs).1
343    }
344}
345
346impl<P: PrimeMod> PartialEq for Polynomial<P> {
347    fn eq(&self, other: &Self) -> bool {
348        let n = self.len().max(other.len());
349        for i in 0..n {
350            if self.coeff_of(i) != other.coeff_of(i) {
351                return false;
352            }
353        }
354        true
355    }
356}
357
358impl<P: PrimeMod> Eq for Polynomial<P> {}
359
360impl<P: PrimeMod> From<Polynomial<P>> for Vec<ConstModInt<P>> {
361    fn from(value: Polynomial<P>) -> Self {
362        value.data
363    }
364}
365
366impl<T, P: PrimeMod> From<Vec<T>> for Polynomial<P>
367where
368    T: Into<ConstModInt<P>>,
369{
370    fn from(value: Vec<T>) -> Self {
371        Self {
372            data: value.into_iter().map(Into::into).collect(),
373        }
374    }
375}
376
377impl<P: PrimeMod> AsRef<[ConstModInt<P>]> for Polynomial<P> {
378    fn as_ref(&self) -> &[ConstModInt<P>] {
379        &self.data
380    }
381}
382
383impl<P: PrimeMod> AsMut<Vec<ConstModInt<P>>> for Polynomial<P> {
384    fn as_mut(&mut self) -> &mut Vec<ConstModInt<P>> {
385        &mut self.data
386    }
387}
388
389impl<P: PrimeMod> Index<usize> for Polynomial<P> {
390    type Output = ConstModInt<P>;
391    fn index(&self, index: usize) -> &Self::Output {
392        &self.data[index]
393    }
394}
395
396impl<P: PrimeMod> IndexMut<usize> for Polynomial<P> {
397    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
398        &mut self.data[index]
399    }
400}
401
#[cfg(test)]
mod tests {
    use crate::{math::prime_mod::Prime, num::const_modint::ConstModIntBuilder};

    use super::*;

    type P = Prime<998244353>;

    #[test]
    fn test() {
        let ff = ConstModIntBuilder::<P>::new();
        let to_poly = |coeffs: [u64; 5]| {
            Polynomial::<P>::from(
                coeffs
                    .iter()
                    .map(|&c| ff.from_u64(c))
                    .collect::<Vec<_>>(),
            )
        };

        let a = to_poly([5, 4, 3, 2, 1]);
        let b = to_poly([1, 2, 3, 4, 5]);

        // Division identity: a == q * b + r.
        let (q, r) = a.clone().divrem(b.clone());
        assert_eq!(a, q * b + r);
    }

    #[test]
    fn test_deg() {
        let cases: [(Vec<usize>, Option<usize>); 4] = [
            (vec![1, 2, 3], Some(2)),
            (vec![1, 2, 3, 0, 0, 0], Some(2)),
            (vec![], None),
            (vec![0, 0, 0, 0], None),
        ];

        for (coeffs, expected) in cases {
            assert_eq!(Polynomial::<P>::from(coeffs).deg(), expected);
        }
    }
}