haar_lib/ds/
merge_sort_tree.rs

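//! Merge-sort Tree
//!
//! Static structure answering "sum and count of the elements `<= ub` in `a[l..r]`" queries
//! in $O((\log N)^2)$ time after $O(N \log N)$ preprocessing.
//!
//! # Problems
//! - <https://atcoder.jp/contests/abc339/tasks/abc339_g>
//! - <https://atcoder.jp/contests/abc351/tasks/abc351_f>
//! - <https://judge.yosupo.jp/problem/static_range_sum_with_upper_bound>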
7
8use crate::algo::{bsearch_slice::BinarySearch, merge::inplace_merge};
9use crate::misc::range::range_bounds_to_range;
10use crate::num::one_zero::Zero;
11use std::ops::{Add, AddAssign, RangeBounds};
12
/// Merge-sort Tree
///
/// A segment tree in which every node keeps the elements of its segment in sorted order,
/// together with the prefix sums of that sorted sequence.
#[derive(Clone, Debug)]
pub struct MergeSortTree<T> {
    /// `data[i]` holds the elements covered by node `i` in sorted order. Node 1 is the root and
    /// the children of node `i` are `2 * i` and `2 * i + 1`. Nodes lying entirely past the end
    /// of the input stay empty.
    data: Vec<Vec<T>>,
    /// `accum[i][k]` is the sum of the first `k` elements of `data[i]`.
    accum: Vec<Vec<T>>,
    /// Number of node slots: twice the number of leaves, which is a power of two.
    size: usize,
    /// Length of the sequence the tree was built from.
    original_size: usize,
}
20
21impl<T> MergeSortTree<T>
22where
23    T: Copy + Clone + Zero + Add<Output = T> + AddAssign + PartialOrd + Ord,
24{
25    /// **Time complexity** $O(n \log n)$
26    ///
27    /// **Space complexity** $O(n \log n)$
28    pub fn new(mut a: Vec<T>) -> Self {
29        let n = a.len();
30        let size = n.next_power_of_two() * 2;
31
32        let mut this = Self {
33            data: vec![vec![]; size],
34            accum: vec![vec![]; size],
35            size,
36            original_size: n,
37        };
38
39        this._init(1, &mut a, 0, size / 2);
40
41        this
42    }
43
44    fn _init(&mut self, i: usize, a: &mut [T], l: usize, r: usize) {
45        if a.len() <= l {
46            return;
47        }
48
49        if r - l == 1 {
50            self.data[i] = a[l..r].to_vec();
51        } else {
52            let mid = (l + r) / 2;
53            self._init(i << 1, a, l, mid);
54            self._init((i << 1) | 1, a, mid, r);
55
56            if a.len() <= mid {
57                self.data[i] = a[l..].to_vec();
58            } else {
59                let k = mid - l;
60                let end = r.min(a.len());
61                inplace_merge(&mut a[l..end], k);
62
63                self.data[i] = a[l..end].to_vec();
64            }
65        }
66        self.accum[i] = Self::_accum(&self.data[i]);
67    }
68
69    fn _accum(a: &[T]) -> Vec<T> {
70        let mut ret = vec![T::zero(); a.len() + 1];
71        for (i, x) in a.iter().enumerate() {
72            ret[i + 1] = ret[i] + *x;
73        }
74        ret
75    }
76
77    /// `ub`以下の総和を求める
78    ///
79    /// **Time complexity** $O((\log N) ^ 2)$
80    pub fn sum_count_le(&self, range: impl RangeBounds<usize>, ub: T) -> (T, usize) {
81        let (l, r) = range_bounds_to_range(range, 0, self.original_size);
82        assert!(l <= r && r <= self.original_size);
83
84        let mut l = l + self.size / 2;
85        let mut r = r + self.size / 2;
86        let mut sum = T::zero();
87        let mut count = 0;
88
89        while l < r {
90            if r & 1 == 1 {
91                r -= 1;
92                let i = self.data[r].upper_bound(&ub);
93                count += i;
94                sum += self.accum[r][i];
95            }
96            if l & 1 == 1 {
97                let i = self.data[l].upper_bound(&ub);
98                count += i;
99                sum += self.accum[l][i];
100                l += 1;
101            }
102            r >>= 1;
103            l >>= 1;
104        }
105
106        (sum, count)
107    }
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use my_testtools::*;
    use rand::Rng;
    use std::ops::Range;

    /// Brute-force reference: sum and count of the elements of `a` that are `<= x`.
    fn naive(a: &[u64], x: u64) -> (u64, usize) {
        a.iter()
            .filter(|&&y| y <= x)
            .fold((0, 0), |(s, c), &y| (s + y, c + 1))
    }

    /// Randomized comparison against a brute-force scan on a non-power-of-two length.
    #[test]
    fn test() {
        let mut rng = rand::thread_rng();

        let n = 300;
        let t = 300;

        let a = (0..n).map(|_| rng.gen::<u64>() % 10000).collect::<Vec<_>>();
        let s = MergeSortTree::new(a.clone());

        for _ in 0..t {
            let Range { start: l, end: r } = rand_range(&mut rng, 0..n);
            let x = rng.gen::<u64>() % 10000;

            assert_eq!(s.sum_count_le(l..r, x), naive(&a[l..r], x));
        }
    }

    /// Exhaustively checks every range and threshold for small lengths. This covers the empty
    /// input and the lengths around powers of two, where padding leaves and empty right
    /// children are involved. It also checks unbounded range forms.
    #[test]
    fn test_small_sizes() {
        for n in 0..=17 {
            // Deterministic, unsorted data with duplicates.
            let a = (0..n).map(|i| (i * 7 + 3) as u64 % 5).collect::<Vec<_>>();
            let s = MergeSortTree::new(a.clone());

            for l in 0..=n {
                for r in l..=n {
                    for x in 0..=5 {
                        assert_eq!(s.sum_count_le(l..r, x), naive(&a[l..r], x));
                    }
                }
            }

            assert_eq!(s.sum_count_le(.., u64::MAX), naive(&a, u64::MAX));
            assert_eq!(s.sum_count_le(..0, u64::MAX), (0, 0));
        }
    }
}
139}