haar_lib/algo/
wildcard_pattern_matching.rs1use std::collections::hash_map::DefaultHasher;
11use std::hash::{Hash, Hasher};
12
13use crate::iter::collect::CollectVec;
14use crate::math::ntt::*;
15use crate::num::ff::FFElem;
16
17pub fn wildcard_pattern_matching<T>(seq: Vec<T>, pat: Vec<T>, wildcard: T) -> Vec<bool>
20where
21 T: Hash + Eq + Copy,
22{
23 assert!(pat.len() <= seq.len());
24
25 let m = seq.len() - pat.len() + 1;
26 let n = (seq.len() + pat.len() - 1).next_power_of_two();
27 let ntt = NTT998244353::new();
28
29 let mut s = vec![0.into(); n];
30
31 for (i, x) in seq.into_iter().enumerate() {
32 if x != wildcard {
33 s[i] = hash(x).into();
34 }
35 }
36
37 let mut p = vec![0.into(); n];
38
39 for (i, x) in pat.into_iter().enumerate() {
40 let i = (n - i) % n;
41
42 if x != wildcard {
43 p[i] = hash(x).into()
44 }
45 }
46
47 let pr = p
48 .iter()
49 .enumerate()
50 .map(|(i, &x)| x * hash(i).into())
51 .collect_vec();
52
53 let mut s2 = s.iter().map(|&x| x * x).collect_vec();
54 let mut p1r = pr.clone();
55 ntt.ntt(&mut s2);
56 ntt.ntt(&mut p1r);
57 s2.iter_mut().zip(p1r).for_each(|(x, y)| *x *= y);
58
59 let mut s1 = s.clone();
60 let mut p2r = pr.into_iter().zip(p).map(|(x, y)| x * y).collect_vec();
61 ntt.ntt(&mut s1);
62 ntt.ntt(&mut p2r);
63 s1.iter_mut().zip(p2r).for_each(|(x, y)| *x *= y);
64
65 let mut ret = vec![0.into(); n];
66 ret.iter_mut().zip(s2).for_each(|(x, y)| *x += y);
67 ret.iter_mut().zip(s1).for_each(|(x, y)| *x -= y);
68 ntt.intt(&mut ret);
69 ret.into_iter().take(m).map(|x| x.value() == 0).collect()
70}
71
/// Returns the 64-bit digest of `x` from a freshly constructed `DefaultHasher`.
///
/// Each call starts from `DefaultHasher::new()`, so equal inputs yield equal
/// digests within one build. The concrete values are not guaranteed to stay
/// the same across Rust releases.
fn hash<T: Hash>(x: T) -> u64 {
    let mut hasher = DefaultHasher::new();
    Hash::hash(&x, &mut hasher);
    Hasher::finish(&hasher)
}