haar_lib/math/
polynomial.rs

1//! $\mathbb{F}_p$上の多項式
2use std::ops::{Index, IndexMut};
3
4use crate::math::ntt::NTT;
5use crate::num::const_modint::*;
6
7/// $\mathbb{F}_p$上の多項式
8#[derive(Clone, Debug)]
9pub struct Polynomial<const P: u32> {
10    data: Vec<ConstModInt<P>>,
11}
12
13impl<const P: u32> Polynomial<P> {
14    /// 零多項式を得る。
15    pub fn zero() -> Self {
16        Self { data: vec![] }
17    }
18
19    /// 定数項のみをもつ多項式を生成する。
20    pub fn constant(a: ConstModInt<P>) -> Self {
21        if a.value() == 0 {
22            Self::zero()
23        } else {
24            Self { data: vec![a] }
25        }
26    }
27
28    /// $x^i$の係数を得る。
29    pub fn coeff_of(&self, i: usize) -> ConstModInt<P> {
30        self.data.get(i).map_or(ConstModInt::new(0), |a| *a)
31    }
32
33    /// 多項式に値`p`を代入した結果を求める。
34    pub fn eval(&self, p: ConstModInt<P>) -> ConstModInt<P> {
35        let mut ret = ConstModInt::new(0);
36        let mut x = ConstModInt::new(1);
37
38        for &a in &self.data {
39            ret += a * x;
40            x *= p;
41        }
42
43        ret
44    }
45
46    /// 内部の`Vec`の長さを返す。
47    pub fn len(&self) -> usize {
48        self.data.len()
49    }
50
51    /// 項数が`0`のとき`true`を返す。
52    pub fn is_empty(&self) -> bool {
53        self.data.is_empty()
54    }
55
56    /// 係数が`0`の高次項を縮める。
57    pub fn shrink(&mut self) {
58        while self.data.last().is_some_and(|x| x.value() == 0) {
59            self.data.pop();
60        }
61    }
62
63    /// [`len()`](Self::len())を超えないように、先頭`t`項をもつ多項式を返す。
64    pub fn get_until(&self, t: usize) -> Self {
65        Self {
66            data: self.data[..t.min(self.len())].to_vec(),
67        }
68    }
69
70    /// 多項式の次数を返す。
71    ///
72    /// `self`が零多項式のときは`None`を返す。
73    ///
74    /// **Time complexity** $O(n)$
75    pub fn deg(&self) -> Option<usize> {
76        (0..self.len()).rev().find(|&i| self.data[i].value() != 0)
77    }
78}
79
80impl<const P: u32> PartialEq for Polynomial<P> {
81    fn eq(&self, other: &Self) -> bool {
82        let n = self.len().max(other.len());
83        for i in 0..n {
84            if self.coeff_of(i) != other.coeff_of(i) {
85                return false;
86            }
87        }
88        true
89    }
90}
91
92impl<const P: u32> From<Polynomial<P>> for Vec<ConstModInt<P>> {
93    fn from(value: Polynomial<P>) -> Self {
94        value.data
95    }
96}
97
98impl<T, const P: u32> From<Vec<T>> for Polynomial<P>
99where
100    T: Into<ConstModInt<P>>,
101{
102    fn from(value: Vec<T>) -> Self {
103        Self {
104            data: value.into_iter().map(Into::into).collect(),
105        }
106    }
107}
108
109impl<const P: u32> AsRef<[ConstModInt<P>]> for Polynomial<P> {
110    fn as_ref(&self) -> &[ConstModInt<P>] {
111        &self.data
112    }
113}
114
115impl<const P: u32> AsMut<Vec<ConstModInt<P>>> for Polynomial<P> {
116    fn as_mut(&mut self) -> &mut Vec<ConstModInt<P>> {
117        &mut self.data
118    }
119}
120
121impl<const P: u32> Index<usize> for Polynomial<P> {
122    type Output = ConstModInt<P>;
123    fn index(&self, index: usize) -> &Self::Output {
124        &self.data[index]
125    }
126}
127
128impl<const P: u32> IndexMut<usize> for Polynomial<P> {
129    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
130        &mut self.data[index]
131    }
132}
133
134/// 多項式の演算を扱う。
135pub struct PolynomialOperator<'a, const P: u32, const PR: u32> {
136    pub(crate) ntt: &'a NTT<P, PR>,
137}
138
139impl<'a, const P: u32, const PR: u32> PolynomialOperator<'a, P, PR> {
140    /// [`NTT<P>`]を基に`PolynomialOperator<P>`を生成する。
141    pub fn new(ntt: &'a NTT<P, PR>) -> Self {
142        Self { ntt }
143    }
144
145    /// 多項式`a`に多項式`b`を足す。
146    pub fn add_assign(&self, a: &mut Polynomial<P>, b: Polynomial<P>) {
147        if a.len() < b.len() {
148            a.data.resize(b.len(), ConstModInt::new(0));
149        }
150        for (a, b) in a.data.iter_mut().zip(b.data.into_iter()) {
151            *a += b;
152        }
153    }
154
155    /// 多項式`a`と多項式`b`の和を返す。
156    pub fn add(&self, mut a: Polynomial<P>, b: Polynomial<P>) -> Polynomial<P> {
157        self.add_assign(&mut a, b);
158        a
159    }
160
161    /// 多項式`a`から多項式`b`を引く。
162    pub fn sub_assign(&self, a: &mut Polynomial<P>, b: Polynomial<P>) {
163        if a.len() < b.len() {
164            a.data.resize(b.len(), ConstModInt::new(0));
165        }
166        for (a, b) in a.data.iter_mut().zip(b.data.into_iter()) {
167            *a -= b;
168        }
169    }
170
171    /// 多項式`a`と多項式`b`の差を返す。
172    pub fn sub(&self, mut a: Polynomial<P>, b: Polynomial<P>) -> Polynomial<P> {
173        self.sub_assign(&mut a, b);
174        a
175    }
176
177    /// 多項式`a`に多項式`b`を掛ける。
178    pub fn mul_assign(&self, a: &mut Polynomial<P>, b: Polynomial<P>) {
179        let k = a.len() + b.len() - 1;
180        a.data = self.ntt.convolve(a.data.clone(), b.data);
181        a.data.truncate(k);
182    }
183
184    /// 多項式`a`と多項式`b`の積を返す。
185    pub fn mul(&self, mut a: Polynomial<P>, b: Polynomial<P>) -> Polynomial<P> {
186        self.mul_assign(&mut a, b);
187        a
188    }
189
190    /// 多項式`a`の2乗を返す。
191    pub fn sq(&self, a: Polynomial<P>) -> Polynomial<P> {
192        self.mul(a.clone(), a)
193    }
194
195    /// 多項式`a`の`k`倍を返す。
196    pub fn scale(&self, a: Polynomial<P>, k: ConstModInt<P>) -> Polynomial<P> {
197        Polynomial {
198            data: a.data.into_iter().map(|x| x * k).collect(),
199        }
200    }
201
202    #[allow(missing_docs)]
203    pub fn inv(&self, a: Polynomial<P>, n: usize) -> Polynomial<P> {
204        let mut ret = Polynomial::constant(a.data[0].inv());
205        let mut t = 1;
206
207        while t <= n * 2 {
208            ret = self.sub(
209                self.scale(ret.clone(), ConstModInt::new(2)),
210                self.mul(self.sq(ret).get_until(t), a.clone().get_until(t)),
211            );
212            ret.data.truncate(t);
213            t *= 2;
214        }
215
216        ret
217    }
218
219    /// 多項式`a`の多項式`b`による商と剰余を返す。
220    pub fn divmod(&self, a: Polynomial<P>, b: Polynomial<P>) -> (Polynomial<P>, Polynomial<P>) {
221        if a.len() < b.len() {
222            return (Polynomial::zero(), a);
223        }
224
225        let m = a.len() - b.len();
226
227        let mut g = a.clone();
228        g.data.reverse();
229
230        let mut f = b.clone();
231        f.data.reverse();
232
233        f = self.inv(f, m);
234        f.data.resize(m + 1, ConstModInt::new(0));
235
236        let mut q = self.mul(f, g);
237        q.data.resize(m + 1, ConstModInt::new(0));
238        q.data.reverse();
239
240        let d = b.len() - 1;
241        let mut r = self.sub(a, self.mul(b, q.clone()));
242        r.data.truncate(d);
243
244        r.shrink();
245        q.shrink();
246
247        (q, r)
248    }
249
250    /// 多項式の微分を返す。
251    pub fn differentiate(&self, a: Polynomial<P>) -> Polynomial<P> {
252        let mut a: Vec<_> = a.into();
253        let n = a.len();
254        if n > 0 {
255            for i in 0..n - 1 {
256                a[i] = a[i + 1] * ConstModInt::new(i as u32 + 1);
257            }
258            a.pop();
259        }
260        a.into()
261    }
262
263    /// 多項式の積分を返す。
264    pub fn integrate(&self, a: Polynomial<P>) -> Polynomial<P> {
265        let mut a: Vec<_> = a.into();
266        let n = a.len();
267        let mut invs = vec![ConstModInt::new(1); n + 1];
268        for i in 2..=n {
269            invs[i] = -invs[P as usize % i] * ConstModInt::new(P / i as u32);
270        }
271        a.push(ConstModInt::new(0));
272        for i in (0..n).rev() {
273            a[i + 1] = a[i] * invs[i + 1];
274        }
275        a[0] = ConstModInt::new(0);
276
277        a.into()
278    }
279
280    /// 係数を`k`次だけ高次側にずらす。ただし、$x^n$の項以降は無視する。
281    ///
282    /// $(a_0 + a_1 x + a_2 x^2 + \ldots + a_{n-1} x^{n-1}) \times x^k \pmod {x^n}$
283    pub fn shift_higher(&self, a: Polynomial<P>, k: usize) -> Polynomial<P> {
284        let a: Vec<_> = a.into();
285        let n = a.len();
286        let mut ret = vec![ConstModInt::new(0); n];
287
288        ret[k..n].copy_from_slice(&a[..(n - k)]);
289
290        ret.into()
291    }
292
293    /// 係数を`k`次だけ低次側にずらす。ただし、負の次数の項は無視する。
294    pub fn shift_lower(&self, a: Polynomial<P>, k: usize) -> Polynomial<P> {
295        let a: Vec<_> = a.into();
296        let n = a.len();
297        let mut ret = vec![ConstModInt::new(0); n];
298
299        for i in (0..n.saturating_sub(k)).rev() {
300            ret[i] = a[i + k];
301        }
302
303        ret.into()
304    }
305}
306
307#[cfg(test)]
308mod tests {
309    use crate::num::const_modint::ConstModIntBuilder;
310
311    use super::*;
312
313    const M: u32 = 998244353;
314
315    #[test]
316    fn test() {
317        let ff = ConstModIntBuilder::<M>;
318        let ntt = NTT::<M, 3>::new();
319        let po = PolynomialOperator::new(&ntt);
320
321        let a: Vec<_> = vec![5, 4, 3, 2, 1]
322            .into_iter()
323            .map(|x| ff.from_u64(x))
324            .collect();
325        let a = Polynomial::from(a);
326
327        let b: Vec<_> = vec![1, 2, 3, 4, 5]
328            .into_iter()
329            .map(|x| ff.from_u64(x))
330            .collect();
331        let b = Polynomial::from(b);
332
333        let (q, r) = po.divmod(a.clone(), b.clone());
334
335        let a_ = po.add(po.mul(q, b.clone()), r);
336        assert_eq!(a, a_);
337    }
338
339    #[test]
340    fn test_deg() {
341        let check = |a: Vec<usize>, d: Option<usize>| {
342            assert_eq!(Polynomial::<M>::from(a).deg(), d);
343        };
344
345        check(vec![1, 2, 3], Some(2));
346        check(vec![1, 2, 3, 0, 0, 0], Some(2));
347        check(vec![], None);
348        check(vec![0, 0, 0, 0], None);
349    }
350}