// haar_lib/ds/merge_sort_tree.rs
use crate::algo::{bsearch_slice::BinarySearch, merge::inplace_merge};
use crate::misc::range::range_bounds_to_range;
use crate::num::one_zero::Zero;
use std::ops::{Add, AddAssign, RangeBounds};

/// Merge sort tree: a segment tree in which every node keeps the elements of
/// its segment (as ordered by the build step) together with their prefix sums.
pub struct MergeSortTree<T> {
    /// Per-node element lists, stored in a 1-indexed heap layout (node 1 is the
    /// root, children of `i` are `2 * i` and `2 * i + 1`; slot 0 is unused).
    data: Vec<Vec<T>>,
    /// Prefix sums of the matching `data` entry: `accum[i][k]` is the sum of the
    /// first `k` elements of `data[i]`.
    accum: Vec<Vec<T>>,
    /// Number of node slots, i.e. twice the (power-of-two) leaf count.
    size: usize,
    /// Length of the sequence the tree was built from.
    original_size: usize,
}
21impl<T> MergeSortTree<T>
22where
23 T: Copy + Clone + Zero + Add<Output = T> + AddAssign + PartialOrd + Ord,
24{
25 pub fn new(mut a: Vec<T>) -> Self {
29 let n = a.len();
30 let size = n.next_power_of_two() * 2;
31
32 let mut this = Self {
33 data: vec![vec![]; size],
34 accum: vec![vec![]; size],
35 size,
36 original_size: n,
37 };
38
39 this._init(1, &mut a, 0, size / 2);
40
41 this
42 }
43
44 fn _init(&mut self, i: usize, a: &mut [T], l: usize, r: usize) {
45 if a.len() <= l {
46 return;
47 }
48
49 if r - l == 1 {
50 self.data[i] = a[l..r].to_vec();
51 } else {
52 let mid = (l + r) / 2;
53 self._init(i << 1, a, l, mid);
54 self._init((i << 1) | 1, a, mid, r);
55
56 if a.len() <= mid {
57 self.data[i] = a[l..].to_vec();
58 } else {
59 let k = mid - l;
60 let end = r.min(a.len());
61 inplace_merge(&mut a[l..end], k);
62
63 self.data[i] = a[l..end].to_vec();
64 }
65 }
66 self.accum[i] = Self::_accum(&self.data[i]);
67 }
68
69 fn _accum(a: &[T]) -> Vec<T> {
70 let mut ret = vec![T::zero(); a.len() + 1];
71 for (i, x) in a.iter().enumerate() {
72 ret[i + 1] = ret[i] + *x;
73 }
74 ret
75 }
76
77 pub fn sum_count_le(&self, range: impl RangeBounds<usize>, ub: T) -> (T, usize) {
81 let (l, r) = range_bounds_to_range(range, 0, self.original_size);
82 assert!(l <= r && r <= self.original_size);
83
84 let mut l = l + self.size / 2;
85 let mut r = r + self.size / 2;
86 let mut sum = T::zero();
87 let mut count = 0;
88
89 while l < r {
90 if r & 1 == 1 {
91 r -= 1;
92 let i = self.data[r].upper_bound(&ub);
93 count += i;
94 sum += self.accum[r][i];
95 }
96 if l & 1 == 1 {
97 let i = self.data[l].upper_bound(&ub);
98 count += i;
99 sum += self.accum[l][i];
100 l += 1;
101 }
102 r >>= 1;
103 l >>= 1;
104 }
105
106 (sum, count)
107 }
108}
#[cfg(test)]
mod tests {
    use super::*;
    use my_testtools::*;
    use rand::Rng;
    use std::ops::Range;

    /// Cross-checks `sum_count_le` against a brute-force filter over random
    /// data and random query ranges.
    #[test]
    fn test() {
        let mut rng = rand::thread_rng();

        let n = 300;
        let t = 300;

        let a = (0..n).map(|_| rng.gen::<u64>() % 10000).collect::<Vec<_>>();
        let s = MergeSortTree::new(a.clone());

        for _ in 0..t {
            // NOTE(review): `rand_range` presumably comes from `my_testtools`
            // and yields a random sub-range of `0..n` - confirm there.
            let Range { start: l, end: r } = rand_range(&mut rng, 0..n);
            let x = rng.gen::<u64>() % 10000;

            let (res_sum, res_count) = s.sum_count_le(l..r, x);
            let ans_sum = a[l..r].iter().filter(|&&y| y <= x).sum::<u64>();
            let ans_count = a[l..r].iter().filter(|&&y| y <= x).count();

            assert_eq!(res_sum, ans_sum);
            assert_eq!(res_count, ans_count);
        }
    }
}
139}