1use std::ops::{Add, AddAssign, Index, IndexMut, Sub, SubAssign};
3
4use crate::math::ntt::NTT;
5use crate::num::const_modint::*;
6
7#[derive(Clone, Debug, Default)]
9pub struct Polynomial<const P: u32> {
10 pub(crate) data: Vec<ConstModInt<P>>,
11}
12
13impl<const P: u32> Polynomial<P> {
14 pub fn zero() -> Self {
16 Self { data: vec![] }
17 }
18
19 pub fn constant(a: ConstModInt<P>) -> Self {
21 if a.value() == 0 {
22 Self::zero()
23 } else {
24 Self { data: vec![a] }
25 }
26 }
27
28 pub fn coeff_of(&self, i: usize) -> ConstModInt<P> {
30 self.data.get(i).map_or(ConstModInt::new(0), |a| *a)
31 }
32
33 pub fn eval(&self, p: ConstModInt<P>) -> ConstModInt<P> {
35 let mut ret = ConstModInt::new(0);
36 let mut x = ConstModInt::new(1);
37
38 for &a in &self.data {
39 ret += a * x;
40 x *= p;
41 }
42
43 ret
44 }
45
46 pub fn len(&self) -> usize {
48 self.data.len()
49 }
50
51 pub fn is_empty(&self) -> bool {
53 self.data.is_empty()
54 }
55
56 pub fn shrink(&mut self) {
58 while self.data.last().is_some_and(|x| x.value() == 0) {
59 self.data.pop();
60 }
61 }
62
63 pub fn get_until(&self, t: usize) -> Self {
65 Self {
66 data: self.data[..t.min(self.len())].to_vec(),
67 }
68 }
69
70 pub fn deg(&self) -> Option<usize> {
76 (0..self.len()).rev().find(|&i| self.data[i].value() != 0)
77 }
78
79 pub fn scale(&mut self, k: ConstModInt<P>) {
81 self.data.iter_mut().for_each(|x| *x *= k);
82 }
83
84 pub fn differentiate(&mut self) {
86 let n = self.len();
87 if n > 0 {
88 for i in 0..n - 1 {
89 self.data[i] = self.data[i + 1] * ConstModInt::new(i as u32 + 1);
90 }
91 self.data.pop();
92 }
93 }
94
95 pub fn integrate(&mut self) {
97 let n = self.len();
98 let mut invs = vec![ConstModInt::new(1); n + 1];
99 for i in 2..=n {
100 invs[i] = -invs[P as usize % i] * ConstModInt::new(P / i as u32);
101 }
102 self.data.push(0.into());
103 for i in (0..n).rev() {
104 self.data[i + 1] = self.data[i] * invs[i + 1];
105 }
106 self.data[0] = 0.into();
107 }
108
109 pub fn shift_higher(&mut self, k: usize) {
113 let n = self.len();
114 for i in (k..n).rev() {
115 self.data[i] = self.data[i - k];
116 }
117 for i in 0..k {
118 self.data[i] = 0.into();
119 }
120 }
121
122 pub fn shift_lower(&mut self, k: usize) {
124 let n = self.len();
125 for i in 0..n.saturating_sub(k) {
126 self.data[i] = self.data[i + k];
127 }
128 for i in n.saturating_sub(k)..n {
129 self.data[i] = 0.into();
130 }
131 }
132}
133
134impl<const P: u32> AddAssign for Polynomial<P> {
135 fn add_assign(&mut self, b: Polynomial<P>) {
136 if self.len() < b.len() {
137 self.data.resize(b.len(), ConstModInt::new(0));
138 }
139 for (a, b) in self.data.iter_mut().zip(b.data) {
140 *a += b;
141 }
142 }
143}
144
145impl<const P: u32> Add for Polynomial<P> {
146 type Output = Self;
147 fn add(mut self, b: Polynomial<P>) -> Polynomial<P> {
148 self += b;
149 self
150 }
151}
152
153impl<const P: u32> SubAssign for Polynomial<P> {
154 fn sub_assign(&mut self, b: Polynomial<P>) {
155 if self.len() < b.len() {
156 self.data.resize(b.len(), ConstModInt::new(0));
157 }
158 for (a, b) in self.data.iter_mut().zip(b.data) {
159 *a -= b;
160 }
161 }
162}
163
164impl<const P: u32> Sub for Polynomial<P> {
165 type Output = Self;
166 fn sub(mut self, b: Polynomial<P>) -> Polynomial<P> {
167 self -= b;
168 self
169 }
170}
171
172impl<const P: u32> PartialEq for Polynomial<P> {
173 fn eq(&self, other: &Self) -> bool {
174 let n = self.len().max(other.len());
175 for i in 0..n {
176 if self.coeff_of(i) != other.coeff_of(i) {
177 return false;
178 }
179 }
180 true
181 }
182}
183
184impl<const P: u32> From<Polynomial<P>> for Vec<ConstModInt<P>> {
185 fn from(value: Polynomial<P>) -> Self {
186 value.data
187 }
188}
189
190impl<T, const P: u32> From<Vec<T>> for Polynomial<P>
191where
192 T: Into<ConstModInt<P>>,
193{
194 fn from(value: Vec<T>) -> Self {
195 Self {
196 data: value.into_iter().map(Into::into).collect(),
197 }
198 }
199}
200
201impl<const P: u32> AsRef<[ConstModInt<P>]> for Polynomial<P> {
202 fn as_ref(&self) -> &[ConstModInt<P>] {
203 &self.data
204 }
205}
206
207impl<const P: u32> AsMut<Vec<ConstModInt<P>>> for Polynomial<P> {
208 fn as_mut(&mut self) -> &mut Vec<ConstModInt<P>> {
209 &mut self.data
210 }
211}
212
213impl<const P: u32> Index<usize> for Polynomial<P> {
214 type Output = ConstModInt<P>;
215 fn index(&self, index: usize) -> &Self::Output {
216 &self.data[index]
217 }
218}
219
220impl<const P: u32> IndexMut<usize> for Polynomial<P> {
221 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
222 &mut self.data[index]
223 }
224}
225
226pub struct PolynomialOperator<'a, const P: u32, const PR: u32> {
228 pub(crate) ntt: &'a NTT<P, PR>,
229}
230
231impl<'a, const P: u32, const PR: u32> PolynomialOperator<'a, P, PR> {
232 pub fn new(ntt: &'a NTT<P, PR>) -> Self {
234 Self { ntt }
235 }
236
237 pub fn mul_assign(&self, a: &mut Polynomial<P>, mut b: Polynomial<P>) {
239 let k = a.len() + b.len() - 1;
240
241 let n = k.next_power_of_two();
242 a.data.resize(n, 0.into());
243 self.ntt.ntt(&mut a.data);
244
245 b.data.resize(n, 0.into());
246 self.ntt.ntt(&mut b.data);
247
248 a.data.iter_mut().zip(b.data).for_each(|(x, y)| *x *= y);
249 self.ntt.intt(&mut a.data);
250
251 a.data.truncate(k);
252 }
253
254 pub fn mul(&self, mut a: Polynomial<P>, b: Polynomial<P>) -> Polynomial<P> {
256 self.mul_assign(&mut a, b);
257 a
258 }
259
260 pub fn sq(&self, mut a: Polynomial<P>) -> Polynomial<P> {
262 let k = a.len() * 2 - 1;
263 let n = k.next_power_of_two();
264
265 a.data.resize(n, 0.into());
266 self.ntt.ntt(&mut a.data);
267 a.data.iter_mut().for_each(|x| *x *= *x);
268 self.ntt.intt(&mut a.data);
269
270 a.data.truncate(k);
271 a
272 }
273
274 #[allow(missing_docs)]
275 pub fn inv(&self, a: Polynomial<P>, n: usize) -> Polynomial<P> {
276 let mut t = 1;
277 let mut ret = vec![a.data[0].inv()];
278 let a: Vec<_> = a.into();
279
280 while t <= n * 2 {
281 let k = (t * 2 - 1).next_power_of_two();
282
283 let mut s = ret.clone();
284 s.resize(k, 0.into());
285 self.ntt.ntt(&mut s);
286 s.iter_mut().for_each(|x| *x *= *x);
287
288 let mut a = a[..t.min(a.len())].to_vec();
289 a.resize(k, 0.into());
290 self.ntt.ntt(&mut a);
291
292 s.iter_mut().zip(a).for_each(|(x, y)| *x *= y);
293 self.ntt.intt(&mut s);
294
295 ret.resize(t, 0.into());
296 ret.iter_mut()
297 .zip(s)
298 .for_each(|(x, y)| *x = *x * 2.into() - y);
299
300 t *= 2;
301 }
302
303 ret.into()
304 }
305
306 pub fn div(&self, mut a: Polynomial<P>, mut b: Polynomial<P>) -> Polynomial<P> {
308 if a.len() < b.len() {
309 return Polynomial::zero();
310 }
311
312 let m = a.len() - b.len();
313
314 a.data.reverse();
315 b.data.reverse();
316
317 b = self.inv(b, m);
318 b.data.resize(m + 1, 0.into());
319
320 let mut q = self.mul(a, b);
321 q.data.resize(m + 1, 0.into());
322 q.data.reverse();
323 q.shrink();
324 q
325 }
326
327 pub fn rem(&self, a: Polynomial<P>, b: Polynomial<P>) -> Polynomial<P> {
329 self.divrem(a, b).1
330 }
331
332 pub fn divrem(&self, a: Polynomial<P>, b: Polynomial<P>) -> (Polynomial<P>, Polynomial<P>) {
334 if a.len() < b.len() {
335 return (Polynomial::zero(), a);
336 }
337
338 let q = self.div(a.clone(), b.clone());
339
340 let d = b.len() - 1;
341 let mut r = a.sub(self.mul(b, q.clone()));
342 r.data.truncate(d);
343 r.shrink();
344
345 (q, r)
346 }
347}
348
#[cfg(test)]
mod tests {
    use crate::num::const_modint::ConstModIntBuilder;

    use super::*;

    const M: u32 = 998244353;

    /// Builds a polynomial modulo `M` from coefficients in ascending degree.
    fn poly(coeffs: &[u64]) -> Polynomial<M> {
        let ff = ConstModIntBuilder::<M>;
        let data: Vec<_> = coeffs.iter().map(|&c| ff.from_u64(c)).collect();
        Polynomial::from(data)
    }

    #[test]
    fn test() {
        let ntt = NTT::<M, 3>::new();
        let po = PolynomialOperator::new(&ntt);

        let a = poly(&[5, 4, 3, 2, 1]);
        let b = poly(&[1, 2, 3, 4, 5]);

        let (q, r) = po.divrem(a.clone(), b.clone());

        // Division identity: a == q * b + r.
        let reconstructed = po.mul(q, b) + r;
        assert_eq!(a, reconstructed);
    }

    #[test]
    fn test_deg() {
        let cases: [(Vec<usize>, Option<usize>); 4] = [
            (vec![1, 2, 3], Some(2)),
            (vec![1, 2, 3, 0, 0, 0], Some(2)),
            (vec![], None),
            (vec![0, 0, 0, 0], None),
        ];

        for (coeffs, expected) in cases {
            assert_eq!(Polynomial::<M>::from(coeffs).deg(), expected);
        }
    }
}