// haar_lib/ds/starry_sky_tree_count.rs

//! Range add with range max (or min) query that also reports the total count of elements attaining the optimum.
2
3use crate::misc::range::range_bounds_to_range;
4use crate::num::one_zero::Zero;
5use std::{
6    cmp::{max, min},
7    ops::{Add, RangeBounds, Sub},
8};
9
/// Selects whether the tree answers range-maximum or range-minimum queries.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Mode {
    /// Range maximum.
    Max,
    /// Range minimum.
    Min,
}

impl Mode {
    /// Returns whichever of `a` and `b` this mode prefers:
    /// the larger for [`Mode::Max`], the smaller for [`Mode::Min`].
    fn op<T: Ord>(self, a: T, b: T) -> T {
        match self {
            Mode::Max => max(a, b),
            Mode::Min => min(a, b),
        }
    }
}
27
28/// 区間加算・個数総和付き区間Max(Min)ができるデータ構造。
29pub struct StarrySkyTreeCount<T> {
30    size: usize,
31    original_size: usize,
32    data: Vec<T>,
33    count: Vec<u64>,
34    mode: Mode,
35}
36
37impl<T> StarrySkyTreeCount<T>
38where
39    T: Add<Output = T> + Sub<Output = T> + Ord + Copy + Zero,
40{
41    /// **Time complexity** $O(n)$
42    pub fn new(coeffs: Vec<u64>, mode: Mode) -> Self {
43        let n = coeffs.len();
44        let size = n.next_power_of_two() * 2;
45        let zero = T::zero();
46
47        let mut count = vec![0; size];
48        for (i, &x) in coeffs.iter().enumerate() {
49            count[size / 2 + i] = x;
50        }
51        for i in (1..size / 2).rev() {
52            count[i] = count[i << 1] + count[(i << 1) | 1];
53        }
54
55        Self {
56            size,
57            original_size: n,
58            data: vec![zero; size],
59            count,
60            mode,
61        }
62    }
63
64    fn rec(&self, s: usize, t: usize, i: usize, l: usize, r: usize, value: T) -> Option<(T, u64)> {
65        if r <= s || t <= l {
66            return None;
67        }
68        if s <= l && r <= t {
69            return Some((value + self.data[i], self.count[i]));
70        }
71
72        let m = (l + r) / 2;
73        let a = self.rec(s, t, i << 1, l, m, value + self.data[i]);
74        let b = self.rec(s, t, (i << 1) | 1, m, r, value + self.data[i]);
75
76        match (a, b) {
77            (None, _) => b,
78            (_, None) => a,
79            (Some((a, ca)), Some((b, cb))) => {
80                let t = self.mode.op(a, b);
81                if a == b {
82                    Some((a, ca + cb))
83                } else if a == t {
84                    Some((a, ca))
85                } else {
86                    Some((b, cb))
87                }
88            }
89        }
90    }
91
92    /// **Time complexity** $O(\log n)$
93    pub fn fold(&self, range: impl RangeBounds<usize>) -> Option<(T, u64)> {
94        let (l, r) = range_bounds_to_range(range, 0, self.original_size);
95        self.rec(l, r, 1, 0, self.size / 2, T::zero())
96    }
97
98    fn bottom_up(&mut self, mut i: usize) {
99        if i > self.size {
100            return;
101        }
102
103        while i >= 1 {
104            if i < self.size / 2 {
105                let d = self.mode.op(self.data[i << 1], self.data[(i << 1) | 1]);
106
107                self.data[i << 1] = self.data[i << 1] - d;
108                self.data[(i << 1) | 1] = self.data[(i << 1) | 1] - d;
109                self.data[i] = self.data[i] + d;
110
111                let l = self.data[i << 1];
112                let r = self.data[(i << 1) | 1];
113                let t = self.mode.op(l, r);
114
115                if l == r {
116                    self.count[i] = self.count[i << 1] + self.count[(i << 1) | 1];
117                } else if l == t {
118                    self.count[i] = self.count[i << 1];
119                } else {
120                    self.count[i] = self.count[(i << 1) | 1];
121                }
122            }
123
124            i >>= 1;
125        }
126    }
127
128    /// **Time complexity** $O(\log n)$
129    pub fn update(&mut self, range: impl RangeBounds<usize>, value: T) {
130        let (l, r) = range_bounds_to_range(range, 0, self.original_size);
131
132        let hsize = self.size / 2;
133        let mut ll = l + hsize;
134        let mut rr = r + hsize;
135
136        while ll < rr {
137            if (rr & 1) != 0 {
138                rr -= 1;
139                self.data[rr] = self.data[rr] + value;
140            }
141            if (ll & 1) != 0 {
142                self.data[ll] = self.data[ll] + value;
143                ll += 1;
144            }
145            ll >>= 1;
146            rr >>= 1;
147        }
148
149        self.bottom_up(l + hsize);
150        self.bottom_up(r + hsize);
151    }
152}