haar_lib/algo/
bsearch.rs

1//! 単調増加な判定関数上の二分探索
2use std::ops::{Add, Div, Sub};
3
4/// [`bsearch_ng_ok`]、[`bsearch_ok_ng`]の返り値
5#[derive(Clone, Copy, Debug)]
6pub enum SearchResult<T> {
7    /// `ng`以下で条件を満たさず、`ok`以上で条件を満たす。
8    NgOk {
9        /// 条件を満たさない最大値
10        ng: T,
11        /// 条件を満たす最小値
12        ok: T,
13    },
14    /// `ok`以下で条件を満たし、`ng`以上で条件を満たさない。
15    OkNg {
16        /// 条件を満たす最大値
17        ok: T,
18        /// 条件を満たさない最小値
19        ng: T,
20    },
21    /// 全体で条件を満たす。
22    AllOk,
23    /// 全体で条件を満たさない。
24    AllNg,
25}
26
27/// 二分探索
28///
29/// `f`は、`lower..=upper`の範囲で、ある値を境界にそれ未満では常に`false`、それ以上では常に`true`となる関数
30///
31/// **Time complexity** $O(\log n)$
32pub fn bsearch_ng_ok<
33    T: Copy + PartialOrd + Add<Output = T> + Sub<Output = T> + Div<Output = T> + From<u8>,
34>(
35    mut lower: T,
36    mut upper: T,
37    f: impl Fn(T) -> bool,
38) -> SearchResult<T> {
39    assert!(lower < upper);
40
41    if f(lower) {
42        // all ok
43        return SearchResult::AllOk;
44    } else if !f(upper) {
45        // all ng
46        return SearchResult::AllNg;
47    }
48
49    while upper - lower > T::from(1) {
50        let mid = (lower + upper) / T::from(2);
51
52        if f(mid) {
53            upper = mid;
54        } else {
55            lower = mid
56        }
57    }
58
59    SearchResult::NgOk {
60        ng: lower,
61        ok: upper,
62    }
63}
64
65/// 二分探索
66///
67/// `f`は、`lower..=upper`の範囲で、ある値を境界にそれ未満では常に`true`、それ以上では常に`false`となる関数
68///
69/// **Time complexity** $O(\log n)$
70pub fn bsearch_ok_ng<
71    T: Copy + PartialOrd + Add<Output = T> + Sub<Output = T> + Div<Output = T> + From<u8>,
72>(
73    lower: T,
74    upper: T,
75    f: impl Fn(T) -> bool,
76) -> SearchResult<T> {
77    assert!(lower < upper);
78
79    match bsearch_ng_ok(lower, upper, |x| !f(x)) {
80        SearchResult::AllNg => SearchResult::AllOk,
81        SearchResult::AllOk => SearchResult::AllNg,
82        SearchResult::NgOk { ng, ok } => SearchResult::OkNg { ok: ng, ng: ok },
83        _ => unreachable!(),
84    }
85}
86
87#[cfg(test)]
88mod tests {
89    use super::*;
90    use std::iter::repeat_n;
91
92    #[test]
93    fn test() {
94        let n = 100;
95
96        for k in 0..=n {
97            let a = repeat_n(0, k)
98                .chain(repeat_n(1, n - k))
99                .collect::<Vec<u64>>();
100
101            let check = |i| a[i] > 0;
102
103            let res = bsearch_ng_ok(0, n - 1, check);
104
105            match res {
106                SearchResult::NgOk { ng, ok } => {
107                    assert!(!check(ng));
108                    assert!(check(ok));
109                    assert_eq!(ng + 1, ok);
110                }
111                SearchResult::AllOk => {
112                    assert!((0..n).all(check));
113                }
114                SearchResult::AllNg => {
115                    assert!((0..n).all(|i| !check(i)));
116                }
117                _ => {}
118            }
119        }
120    }
121}