haar_lib/algo/
wildcard_pattern_matching.rs

1//! ワイルドカードパターンマッチング
2//!
3//! # References
4//! - <https://qiita.com/MatsuTaku/items/cd5581fab97d7e74a7b3>
5//! - <https://noshi91.hatenablog.com/entry/2024/05/26/060854>
6//!
7//! # Problems
8//! - <https://judge.yosupo.jp/problem/wildcard_pattern_matching>
9
10use std::collections::hash_map::DefaultHasher;
11use std::hash::{Hash, Hasher};
12
13use crate::iter::collect::CollectVec;
14use crate::math::convolution::ntt::*;
15use crate::math::prime_mod::Prime;
16use crate::num::ff::*;
17
18/// `seq`の`|pat|`長の各連続部分列が`pat`と一致するかを判定する。
19/// `wildcard`はワイルドカードとして扱う。
20pub fn wildcard_pattern_matching<T>(seq: Vec<T>, pat: Vec<T>, wildcard: T) -> Vec<bool>
21where
22    T: Hash + Eq + Copy,
23{
24    assert!(pat.len() <= seq.len());
25
26    let m = seq.len() - pat.len() + 1;
27    let n = (seq.len() + pat.len() - 1).next_power_of_two();
28    let ntt = NTT::<Prime<998244353>>::new();
29
30    let mut s = vec![0.into(); n];
31
32    for (i, x) in seq.into_iter().enumerate() {
33        if x != wildcard {
34            s[i] = hash(x).into();
35        }
36    }
37
38    let mut p = vec![0.into(); n];
39
40    for (i, x) in pat.into_iter().enumerate() {
41        let i = (n - i) % n;
42
43        if x != wildcard {
44            p[i] = hash(x).into()
45        }
46    }
47
48    let pr = p
49        .iter()
50        .enumerate()
51        .map(|(i, &x)| x * hash(i).into())
52        .collect_vec();
53
54    let mut s2 = s.iter().map(|&x| x * x).collect_vec();
55    let mut p1r = pr.clone();
56    ntt.ntt(&mut s2);
57    ntt.ntt(&mut p1r);
58    s2.iter_mut().zip(p1r).for_each(|(x, y)| *x *= y);
59
60    let mut s1 = s.clone();
61    let mut p2r = pr.into_iter().zip(p).map(|(x, y)| x * y).collect_vec();
62    ntt.ntt(&mut s1);
63    ntt.ntt(&mut p2r);
64    s1.iter_mut().zip(p2r).for_each(|(x, y)| *x *= y);
65
66    let mut ret = vec![0.into(); n];
67    ret.iter_mut().zip(s2).for_each(|(x, y)| *x += y);
68    ret.iter_mut().zip(s1).for_each(|(x, y)| *x -= y);
69    ntt.intt(&mut ret);
70    ret.into_iter().take(m).map(|x| x.value() == 0).collect()
71}
72
73fn hash<T: Hash>(x: T) -> u64 {
74    let mut s = DefaultHasher::new();
75    x.hash(&mut s);
76    s.finish()
77}