haar_lib/algo/
wildcard_pattern_matching.rs1use std::collections::hash_map::DefaultHasher;
11use std::hash::{Hash, Hasher};
12
13use crate::iter::collect::CollectVec;
14use crate::math::convolution::ntt::*;
15use crate::math::prime_mod::Prime;
16use crate::num::ff::*;
17
18pub fn wildcard_pattern_matching<T>(seq: Vec<T>, pat: Vec<T>, wildcard: T) -> Vec<bool>
21where
22 T: Hash + Eq + Copy,
23{
24 assert!(pat.len() <= seq.len());
25
26 let m = seq.len() - pat.len() + 1;
27 let n = (seq.len() + pat.len() - 1).next_power_of_two();
28 let ntt = NTT::<Prime<998244353>>::new();
29
30 let mut s = vec![0.into(); n];
31
32 for (i, x) in seq.into_iter().enumerate() {
33 if x != wildcard {
34 s[i] = hash(x).into();
35 }
36 }
37
38 let mut p = vec![0.into(); n];
39
40 for (i, x) in pat.into_iter().enumerate() {
41 let i = (n - i) % n;
42
43 if x != wildcard {
44 p[i] = hash(x).into()
45 }
46 }
47
48 let pr = p
49 .iter()
50 .enumerate()
51 .map(|(i, &x)| x * hash(i).into())
52 .collect_vec();
53
54 let mut s2 = s.iter().map(|&x| x * x).collect_vec();
55 let mut p1r = pr.clone();
56 ntt.ntt(&mut s2);
57 ntt.ntt(&mut p1r);
58 s2.iter_mut().zip(p1r).for_each(|(x, y)| *x *= y);
59
60 let mut s1 = s.clone();
61 let mut p2r = pr.into_iter().zip(p).map(|(x, y)| x * y).collect_vec();
62 ntt.ntt(&mut s1);
63 ntt.ntt(&mut p2r);
64 s1.iter_mut().zip(p2r).for_each(|(x, y)| *x *= y);
65
66 let mut ret = vec![0.into(); n];
67 ret.iter_mut().zip(s2).for_each(|(x, y)| *x += y);
68 ret.iter_mut().zip(s1).for_each(|(x, y)| *x -= y);
69 ntt.intt(&mut ret);
70 ret.into_iter().take(m).map(|x| x.value() == 0).collect()
71}
72
/// Hashes `x` with a freshly constructed [`DefaultHasher`] and returns the 64-bit digest.
///
/// A new hasher is created on every call, so the result depends only on `x`.
fn hash<T>(x: T) -> u64
where
    T: Hash,
{
    let mut hasher = DefaultHasher::new();
    Hash::hash(&x, &mut hasher);
    Hasher::finish(&hasher)
}