haar_lib/algo/
wildcard_pattern_matching.rs

1//! ワイルドカードパターンマッチング
2//!
3//! # References
4//! - <https://qiita.com/MatsuTaku/items/cd5581fab97d7e74a7b3>
5//! - <https://noshi91.hatenablog.com/entry/2024/05/26/060854>
6//!
7//! # Problems
8//! - <https://judge.yosupo.jp/problem/wildcard_pattern_matching>
9
10use std::collections::hash_map::DefaultHasher;
11use std::hash::{Hash, Hasher};
12
13use crate::iter::collect::CollectVec;
14use crate::math::ntt::*;
15use crate::num::ff::FFElem;
16
17/// `seq`の`|pat|`長の各連続部分列が`pat`と一致するかを判定する。
18/// `wildcard`はワイルドカードとして扱う。
19pub fn wildcard_pattern_matching<T>(seq: Vec<T>, pat: Vec<T>, wildcard: T) -> Vec<bool>
20where
21    T: Hash + Eq + Copy,
22{
23    assert!(pat.len() <= seq.len());
24
25    let m = seq.len() - pat.len() + 1;
26    let n = (seq.len() + pat.len() - 1).next_power_of_two();
27    let ntt = NTT998244353::new();
28
29    let mut s = vec![0.into(); n];
30
31    for (i, x) in seq.into_iter().enumerate() {
32        if x != wildcard {
33            s[i] = hash(x).into();
34        }
35    }
36
37    let mut p = vec![0.into(); n];
38
39    for (i, x) in pat.into_iter().enumerate() {
40        let i = (n - i) % n;
41
42        if x != wildcard {
43            p[i] = hash(x).into()
44        }
45    }
46
47    let pr = p
48        .iter()
49        .enumerate()
50        .map(|(i, &x)| x * hash(i).into())
51        .collect_vec();
52
53    let mut s2 = s.iter().map(|&x| x * x).collect_vec();
54    let mut p1r = pr.clone();
55    ntt.ntt(&mut s2);
56    ntt.ntt(&mut p1r);
57    s2.iter_mut().zip(p1r).for_each(|(x, y)| *x *= y);
58
59    let mut s1 = s.clone();
60    let mut p2r = pr.into_iter().zip(p).map(|(x, y)| x * y).collect_vec();
61    ntt.ntt(&mut s1);
62    ntt.ntt(&mut p2r);
63    s1.iter_mut().zip(p2r).for_each(|(x, y)| *x *= y);
64
65    let mut ret = vec![0.into(); n];
66    ret.iter_mut().zip(s2).for_each(|(x, y)| *x += y);
67    ret.iter_mut().zip(s1).for_each(|(x, y)| *x -= y);
68    ntt.intt(&mut ret);
69    ret.into_iter().take(m).map(|x| x.value() == 0).collect()
70}
71
72fn hash<T: Hash>(x: T) -> u64 {
73    let mut s = DefaultHasher::new();
74    x.hash(&mut s);
75    s.finish()
76}