haar_lib/math/
polynomial.rs

1//! $\mathbb{F}_p$上の多項式
2use std::ops::{Add, AddAssign, Index, IndexMut, Sub, SubAssign};
3
4use crate::math::ntt::NTT;
5use crate::num::const_modint::*;
6
7/// $\mathbb{F}_p$上の多項式
8#[derive(Clone, Debug, Default)]
9pub struct Polynomial<const P: u32> {
10    pub(crate) data: Vec<ConstModInt<P>>,
11}
12
13impl<const P: u32> Polynomial<P> {
14    /// 零多項式を得る。
15    pub fn zero() -> Self {
16        Self { data: vec![] }
17    }
18
19    /// 定数項のみをもつ多項式を生成する。
20    pub fn constant(a: ConstModInt<P>) -> Self {
21        if a.value() == 0 {
22            Self::zero()
23        } else {
24            Self { data: vec![a] }
25        }
26    }
27
28    /// $x^i$の係数を得る。
29    pub fn coeff_of(&self, i: usize) -> ConstModInt<P> {
30        self.data.get(i).map_or(ConstModInt::new(0), |a| *a)
31    }
32
33    /// 多項式に値`p`を代入した結果を求める。
34    pub fn eval(&self, p: ConstModInt<P>) -> ConstModInt<P> {
35        let mut ret = ConstModInt::new(0);
36        let mut x = ConstModInt::new(1);
37
38        for &a in &self.data {
39            ret += a * x;
40            x *= p;
41        }
42
43        ret
44    }
45
46    /// 内部の`Vec`の長さを返す。
47    pub fn len(&self) -> usize {
48        self.data.len()
49    }
50
51    /// 項数が`0`のとき`true`を返す。
52    pub fn is_empty(&self) -> bool {
53        self.data.is_empty()
54    }
55
56    /// 係数が`0`の高次項を縮める。
57    pub fn shrink(&mut self) {
58        while self.data.last().is_some_and(|x| x.value() == 0) {
59            self.data.pop();
60        }
61    }
62
63    /// [`len()`](Self::len())を超えないように、先頭`t`項をもつ多項式を返す。
64    pub fn get_until(&self, t: usize) -> Self {
65        Self {
66            data: self.data[..t.min(self.len())].to_vec(),
67        }
68    }
69
70    /// 多項式の次数を返す。
71    ///
72    /// `self`が零多項式のときは`None`を返す。
73    ///
74    /// **Time complexity** $O(n)$
75    pub fn deg(&self) -> Option<usize> {
76        (0..self.len()).rev().find(|&i| self.data[i].value() != 0)
77    }
78
79    /// 多項式を`k`倍する。
80    pub fn scale(&mut self, k: ConstModInt<P>) {
81        self.data.iter_mut().for_each(|x| *x *= k);
82    }
83
84    /// 多項式を微分する。
85    pub fn differentiate(&mut self) {
86        let n = self.len();
87        if n > 0 {
88            for i in 0..n - 1 {
89                self.data[i] = self.data[i + 1] * ConstModInt::new(i as u32 + 1);
90            }
91            self.data.pop();
92        }
93    }
94
95    /// 多項式を積分する。
96    pub fn integrate(&mut self) {
97        let n = self.len();
98        let mut invs = vec![ConstModInt::new(1); n + 1];
99        for i in 2..=n {
100            invs[i] = -invs[P as usize % i] * ConstModInt::new(P / i as u32);
101        }
102        self.data.push(0.into());
103        for i in (0..n).rev() {
104            self.data[i + 1] = self.data[i] * invs[i + 1];
105        }
106        self.data[0] = 0.into();
107    }
108
109    /// 係数を`k`次だけ高次側にずらす。ただし、$x^n$の項以降は無視する。
110    ///
111    /// $(a_0 + a_1 x + a_2 x^2 + \ldots + a_{n-1} x^{n-1}) \times x^k \pmod {x^n}$
112    pub fn shift_higher(&mut self, k: usize) {
113        let n = self.len();
114        for i in (k..n).rev() {
115            self.data[i] = self.data[i - k];
116        }
117        for i in 0..k {
118            self.data[i] = 0.into();
119        }
120    }
121
122    /// 係数を`k`次だけ低次側にずらす。ただし、負の次数の項は無視する。
123    pub fn shift_lower(&mut self, k: usize) {
124        let n = self.len();
125        for i in 0..n.saturating_sub(k) {
126            self.data[i] = self.data[i + k];
127        }
128        for i in n.saturating_sub(k)..n {
129            self.data[i] = 0.into();
130        }
131    }
132}
133
134impl<const P: u32> AddAssign for Polynomial<P> {
135    fn add_assign(&mut self, b: Polynomial<P>) {
136        if self.len() < b.len() {
137            self.data.resize(b.len(), ConstModInt::new(0));
138        }
139        for (a, b) in self.data.iter_mut().zip(b.data) {
140            *a += b;
141        }
142    }
143}
144
145impl<const P: u32> Add for Polynomial<P> {
146    type Output = Self;
147    fn add(mut self, b: Polynomial<P>) -> Polynomial<P> {
148        self += b;
149        self
150    }
151}
152
153impl<const P: u32> SubAssign for Polynomial<P> {
154    fn sub_assign(&mut self, b: Polynomial<P>) {
155        if self.len() < b.len() {
156            self.data.resize(b.len(), ConstModInt::new(0));
157        }
158        for (a, b) in self.data.iter_mut().zip(b.data) {
159            *a -= b;
160        }
161    }
162}
163
164impl<const P: u32> Sub for Polynomial<P> {
165    type Output = Self;
166    fn sub(mut self, b: Polynomial<P>) -> Polynomial<P> {
167        self -= b;
168        self
169    }
170}
171
172impl<const P: u32> PartialEq for Polynomial<P> {
173    fn eq(&self, other: &Self) -> bool {
174        let n = self.len().max(other.len());
175        for i in 0..n {
176            if self.coeff_of(i) != other.coeff_of(i) {
177                return false;
178            }
179        }
180        true
181    }
182}
183
184impl<const P: u32> From<Polynomial<P>> for Vec<ConstModInt<P>> {
185    fn from(value: Polynomial<P>) -> Self {
186        value.data
187    }
188}
189
190impl<T, const P: u32> From<Vec<T>> for Polynomial<P>
191where
192    T: Into<ConstModInt<P>>,
193{
194    fn from(value: Vec<T>) -> Self {
195        Self {
196            data: value.into_iter().map(Into::into).collect(),
197        }
198    }
199}
200
201impl<const P: u32> AsRef<[ConstModInt<P>]> for Polynomial<P> {
202    fn as_ref(&self) -> &[ConstModInt<P>] {
203        &self.data
204    }
205}
206
207impl<const P: u32> AsMut<Vec<ConstModInt<P>>> for Polynomial<P> {
208    fn as_mut(&mut self) -> &mut Vec<ConstModInt<P>> {
209        &mut self.data
210    }
211}
212
213impl<const P: u32> Index<usize> for Polynomial<P> {
214    type Output = ConstModInt<P>;
215    fn index(&self, index: usize) -> &Self::Output {
216        &self.data[index]
217    }
218}
219
220impl<const P: u32> IndexMut<usize> for Polynomial<P> {
221    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
222        &mut self.data[index]
223    }
224}
225
226/// 多項式の演算を扱う。
227pub struct PolynomialOperator<'a, const P: u32, const PR: u32> {
228    pub(crate) ntt: &'a NTT<P, PR>,
229}
230
231impl<'a, const P: u32, const PR: u32> PolynomialOperator<'a, P, PR> {
232    /// [`NTT<P>`]を基に`PolynomialOperator<P>`を生成する。
233    pub fn new(ntt: &'a NTT<P, PR>) -> Self {
234        Self { ntt }
235    }
236
237    /// 多項式`a`に多項式`b`を掛ける。
238    pub fn mul_assign(&self, a: &mut Polynomial<P>, mut b: Polynomial<P>) {
239        let k = a.len() + b.len() - 1;
240
241        let n = k.next_power_of_two();
242        a.data.resize(n, 0.into());
243        self.ntt.ntt(&mut a.data);
244
245        b.data.resize(n, 0.into());
246        self.ntt.ntt(&mut b.data);
247
248        a.data.iter_mut().zip(b.data).for_each(|(x, y)| *x *= y);
249        self.ntt.intt(&mut a.data);
250
251        a.data.truncate(k);
252    }
253
254    /// 多項式`a`と多項式`b`の積を返す。
255    pub fn mul(&self, mut a: Polynomial<P>, b: Polynomial<P>) -> Polynomial<P> {
256        self.mul_assign(&mut a, b);
257        a
258    }
259
260    /// 多項式`a`の2乗を返す。
261    pub fn sq(&self, mut a: Polynomial<P>) -> Polynomial<P> {
262        let k = a.len() * 2 - 1;
263        let n = k.next_power_of_two();
264
265        a.data.resize(n, 0.into());
266        self.ntt.ntt(&mut a.data);
267        a.data.iter_mut().for_each(|x| *x *= *x);
268        self.ntt.intt(&mut a.data);
269
270        a.data.truncate(k);
271        a
272    }
273
274    #[allow(missing_docs)]
275    pub fn inv(&self, a: Polynomial<P>, n: usize) -> Polynomial<P> {
276        let mut t = 1;
277        let mut ret = vec![a.data[0].inv()];
278        let a: Vec<_> = a.into();
279
280        while t <= n * 2 {
281            let k = (t * 2 - 1).next_power_of_two();
282
283            let mut s = ret.clone();
284            s.resize(k, 0.into());
285            self.ntt.ntt(&mut s);
286            s.iter_mut().for_each(|x| *x *= *x);
287
288            let mut a = a[..t.min(a.len())].to_vec();
289            a.resize(k, 0.into());
290            self.ntt.ntt(&mut a);
291
292            s.iter_mut().zip(a).for_each(|(x, y)| *x *= y);
293            self.ntt.intt(&mut s);
294
295            ret.resize(t, 0.into());
296            ret.iter_mut()
297                .zip(s)
298                .for_each(|(x, y)| *x = *x * 2.into() - y);
299
300            t *= 2;
301        }
302
303        ret.into()
304    }
305
306    /// 多項式`a`の多項式`b`による商を返す。
307    pub fn div(&self, mut a: Polynomial<P>, mut b: Polynomial<P>) -> Polynomial<P> {
308        if a.len() < b.len() {
309            return Polynomial::zero();
310        }
311
312        let m = a.len() - b.len();
313
314        a.data.reverse();
315        b.data.reverse();
316
317        b = self.inv(b, m);
318        b.data.resize(m + 1, 0.into());
319
320        let mut q = self.mul(a, b);
321        q.data.resize(m + 1, 0.into());
322        q.data.reverse();
323        q.shrink();
324        q
325    }
326
327    /// 多項式`a`の多項式`b`による剰余を返す。
328    pub fn rem(&self, a: Polynomial<P>, b: Polynomial<P>) -> Polynomial<P> {
329        self.divrem(a, b).1
330    }
331
332    /// 多項式`a`の多項式`b`による商と剰余を返す。
333    pub fn divrem(&self, a: Polynomial<P>, b: Polynomial<P>) -> (Polynomial<P>, Polynomial<P>) {
334        if a.len() < b.len() {
335            return (Polynomial::zero(), a);
336        }
337
338        let q = self.div(a.clone(), b.clone());
339
340        let d = b.len() - 1;
341        let mut r = a.sub(self.mul(b, q.clone()));
342        r.data.truncate(d);
343        r.shrink();
344
345        (q, r)
346    }
347}
348
349#[cfg(test)]
350mod tests {
351    use crate::num::const_modint::ConstModIntBuilder;
352
353    use super::*;
354
355    const M: u32 = 998244353;
356
357    #[test]
358    fn test() {
359        let ff = ConstModIntBuilder::<M>;
360        let ntt = NTT::<M, 3>::new();
361        let po = PolynomialOperator::new(&ntt);
362
363        let a: Vec<_> = vec![5, 4, 3, 2, 1]
364            .into_iter()
365            .map(|x| ff.from_u64(x))
366            .collect();
367        let a = Polynomial::from(a);
368
369        let b: Vec<_> = vec![1, 2, 3, 4, 5]
370            .into_iter()
371            .map(|x| ff.from_u64(x))
372            .collect();
373        let b = Polynomial::from(b);
374
375        let (q, r) = po.divrem(a.clone(), b.clone());
376
377        let a_ = po.mul(q, b.clone()) + r;
378        assert_eq!(a, a_);
379    }
380
381    #[test]
382    fn test_deg() {
383        let check = |a: Vec<usize>, d: Option<usize>| {
384            assert_eq!(Polynomial::<M>::from(a).deg(), d);
385        };
386
387        check(vec![1, 2, 3], Some(2));
388        check(vec![1, 2, 3, 0, 0, 0], Some(2));
389        check(vec![], None);
390        check(vec![0, 0, 0, 0], None);
391    }
392}