1use std::ops::{Index, IndexMut};
3
4use crate::math::ntt::NTT;
5use crate::num::const_modint::*;
6
7#[derive(Clone, Debug)]
9pub struct Polynomial<const P: u32> {
10 data: Vec<ConstModInt<P>>,
11}
12
13impl<const P: u32> Polynomial<P> {
14 pub fn zero() -> Self {
16 Self { data: vec![] }
17 }
18
19 pub fn constant(a: ConstModInt<P>) -> Self {
21 if a.value() == 0 {
22 Self::zero()
23 } else {
24 Self { data: vec![a] }
25 }
26 }
27
28 pub fn coeff_of(&self, i: usize) -> ConstModInt<P> {
30 self.data.get(i).map_or(ConstModInt::new(0), |a| *a)
31 }
32
33 pub fn eval(&self, p: ConstModInt<P>) -> ConstModInt<P> {
35 let mut ret = ConstModInt::new(0);
36 let mut x = ConstModInt::new(1);
37
38 for &a in &self.data {
39 ret += a * x;
40 x *= p;
41 }
42
43 ret
44 }
45
46 pub fn len(&self) -> usize {
48 self.data.len()
49 }
50
51 pub fn is_empty(&self) -> bool {
53 self.data.is_empty()
54 }
55
56 pub fn shrink(&mut self) {
58 while self.data.last().is_some_and(|x| x.value() == 0) {
59 self.data.pop();
60 }
61 }
62
63 pub fn get_until(&self, t: usize) -> Self {
65 Self {
66 data: self.data[..t.min(self.len())].to_vec(),
67 }
68 }
69
70 pub fn deg(&self) -> Option<usize> {
76 (0..self.len()).rev().find(|&i| self.data[i].value() != 0)
77 }
78}
79
80impl<const P: u32> PartialEq for Polynomial<P> {
81 fn eq(&self, other: &Self) -> bool {
82 let n = self.len().max(other.len());
83 for i in 0..n {
84 if self.coeff_of(i) != other.coeff_of(i) {
85 return false;
86 }
87 }
88 true
89 }
90}
91
92impl<const P: u32> From<Polynomial<P>> for Vec<ConstModInt<P>> {
93 fn from(value: Polynomial<P>) -> Self {
94 value.data
95 }
96}
97
98impl<T, const P: u32> From<Vec<T>> for Polynomial<P>
99where
100 T: Into<ConstModInt<P>>,
101{
102 fn from(value: Vec<T>) -> Self {
103 Self {
104 data: value.into_iter().map(Into::into).collect(),
105 }
106 }
107}
108
109impl<const P: u32> AsRef<[ConstModInt<P>]> for Polynomial<P> {
110 fn as_ref(&self) -> &[ConstModInt<P>] {
111 &self.data
112 }
113}
114
115impl<const P: u32> AsMut<Vec<ConstModInt<P>>> for Polynomial<P> {
116 fn as_mut(&mut self) -> &mut Vec<ConstModInt<P>> {
117 &mut self.data
118 }
119}
120
121impl<const P: u32> Index<usize> for Polynomial<P> {
122 type Output = ConstModInt<P>;
123 fn index(&self, index: usize) -> &Self::Output {
124 &self.data[index]
125 }
126}
127
128impl<const P: u32> IndexMut<usize> for Polynomial<P> {
129 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
130 &mut self.data[index]
131 }
132}
133
134pub struct PolynomialOperator<'a, const P: u32, const PR: u32> {
136 pub(crate) ntt: &'a NTT<P, PR>,
137}
138
139impl<'a, const P: u32, const PR: u32> PolynomialOperator<'a, P, PR> {
140 pub fn new(ntt: &'a NTT<P, PR>) -> Self {
142 Self { ntt }
143 }
144
145 pub fn add_assign(&self, a: &mut Polynomial<P>, b: Polynomial<P>) {
147 if a.len() < b.len() {
148 a.data.resize(b.len(), ConstModInt::new(0));
149 }
150 for (a, b) in a.data.iter_mut().zip(b.data.into_iter()) {
151 *a += b;
152 }
153 }
154
155 pub fn add(&self, mut a: Polynomial<P>, b: Polynomial<P>) -> Polynomial<P> {
157 self.add_assign(&mut a, b);
158 a
159 }
160
161 pub fn sub_assign(&self, a: &mut Polynomial<P>, b: Polynomial<P>) {
163 if a.len() < b.len() {
164 a.data.resize(b.len(), ConstModInt::new(0));
165 }
166 for (a, b) in a.data.iter_mut().zip(b.data.into_iter()) {
167 *a -= b;
168 }
169 }
170
171 pub fn sub(&self, mut a: Polynomial<P>, b: Polynomial<P>) -> Polynomial<P> {
173 self.sub_assign(&mut a, b);
174 a
175 }
176
177 pub fn mul_assign(&self, a: &mut Polynomial<P>, b: Polynomial<P>) {
179 let k = a.len() + b.len() - 1;
180 a.data = self.ntt.convolve(a.data.clone(), b.data);
181 a.data.truncate(k);
182 }
183
184 pub fn mul(&self, mut a: Polynomial<P>, b: Polynomial<P>) -> Polynomial<P> {
186 self.mul_assign(&mut a, b);
187 a
188 }
189
190 pub fn sq(&self, a: Polynomial<P>) -> Polynomial<P> {
192 self.mul(a.clone(), a)
193 }
194
195 pub fn scale(&self, a: Polynomial<P>, k: ConstModInt<P>) -> Polynomial<P> {
197 Polynomial {
198 data: a.data.into_iter().map(|x| x * k).collect(),
199 }
200 }
201
202 #[allow(missing_docs)]
203 pub fn inv(&self, a: Polynomial<P>, n: usize) -> Polynomial<P> {
204 let mut ret = Polynomial::constant(a.data[0].inv());
205 let mut t = 1;
206
207 while t <= n * 2 {
208 ret = self.sub(
209 self.scale(ret.clone(), ConstModInt::new(2)),
210 self.mul(self.sq(ret).get_until(t), a.clone().get_until(t)),
211 );
212 ret.data.truncate(t);
213 t *= 2;
214 }
215
216 ret
217 }
218
219 pub fn divmod(&self, a: Polynomial<P>, b: Polynomial<P>) -> (Polynomial<P>, Polynomial<P>) {
221 if a.len() < b.len() {
222 return (Polynomial::zero(), a);
223 }
224
225 let m = a.len() - b.len();
226
227 let mut g = a.clone();
228 g.data.reverse();
229
230 let mut f = b.clone();
231 f.data.reverse();
232
233 f = self.inv(f, m);
234 f.data.resize(m + 1, ConstModInt::new(0));
235
236 let mut q = self.mul(f, g);
237 q.data.resize(m + 1, ConstModInt::new(0));
238 q.data.reverse();
239
240 let d = b.len() - 1;
241 let mut r = self.sub(a, self.mul(b, q.clone()));
242 r.data.truncate(d);
243
244 r.shrink();
245 q.shrink();
246
247 (q, r)
248 }
249
250 pub fn differentiate(&self, a: Polynomial<P>) -> Polynomial<P> {
252 let mut a: Vec<_> = a.into();
253 let n = a.len();
254 if n > 0 {
255 for i in 0..n - 1 {
256 a[i] = a[i + 1] * ConstModInt::new(i as u32 + 1);
257 }
258 a.pop();
259 }
260 a.into()
261 }
262
263 pub fn integrate(&self, a: Polynomial<P>) -> Polynomial<P> {
265 let mut a: Vec<_> = a.into();
266 let n = a.len();
267 let mut invs = vec![ConstModInt::new(1); n + 1];
268 for i in 2..=n {
269 invs[i] = -invs[P as usize % i] * ConstModInt::new(P / i as u32);
270 }
271 a.push(ConstModInt::new(0));
272 for i in (0..n).rev() {
273 a[i + 1] = a[i] * invs[i + 1];
274 }
275 a[0] = ConstModInt::new(0);
276
277 a.into()
278 }
279
280 pub fn shift_higher(&self, a: Polynomial<P>, k: usize) -> Polynomial<P> {
284 let a: Vec<_> = a.into();
285 let n = a.len();
286 let mut ret = vec![ConstModInt::new(0); n];
287
288 ret[k..n].copy_from_slice(&a[..(n - k)]);
289
290 ret.into()
291 }
292
293 pub fn shift_lower(&self, a: Polynomial<P>, k: usize) -> Polynomial<P> {
295 let a: Vec<_> = a.into();
296 let n = a.len();
297 let mut ret = vec![ConstModInt::new(0); n];
298
299 for i in (0..n.saturating_sub(k)).rev() {
300 ret[i] = a[i + k];
301 }
302
303 ret.into()
304 }
305}
306
#[cfg(test)]
mod tests {
    use crate::num::const_modint::ConstModIntBuilder;

    use super::*;

    const M: u32 = 998244353;

    #[test]
    fn test() {
        let ff = ConstModIntBuilder::<M>;
        let ntt = NTT::<M, 3>::new();
        let po = PolynomialOperator::new(&ntt);

        // Builds a polynomial from coefficients listed in ascending degree.
        let to_poly = |coeffs: [u64; 5]| {
            Polynomial::<M>::from(coeffs.iter().map(|&c| ff.from_u64(c)).collect::<Vec<_>>())
        };

        let a = to_poly([5, 4, 3, 2, 1]);
        let b = to_poly([1, 2, 3, 4, 5]);

        // Division identity: a == q * b + r.
        let (q, r) = po.divmod(a.clone(), b.clone());
        let reconstructed = po.add(po.mul(q, b.clone()), r);
        assert_eq!(a, reconstructed);
    }

    #[test]
    fn test_deg() {
        let cases: [(Vec<usize>, Option<usize>); 4] = [
            (vec![1, 2, 3], Some(2)),
            (vec![1, 2, 3, 0, 0, 0], Some(2)),
            (vec![], None),
            (vec![0, 0, 0, 0], None),
        ];

        for (coeffs, expected) in cases {
            assert_eq!(Polynomial::<M>::from(coeffs).deg(), expected);
        }
    }
}