haar_lib/ds/
dynamic_lazy_segtree.rs

//! Dynamic lazy segment tree (動的遅延セグメント木).
2use crate::algebra::action::Action;
3use std::ops::Range;
4use std::ptr;
5
/// A single node of the dynamic lazy segment tree.
///
/// `value` is the aggregate of the node's range *without* its own pending tag
/// `lazy`; propagation folds `lazy` into `value` and pushes it to the children.
/// Children are raw pointers that stay null until first needed. Note that the
/// derived `Clone` copies the child pointers (a shallow copy).
#[derive(Clone, Debug)]
struct Node<A: Action> {
    /// Aggregate over this node's range, not yet including `lazy`.
    value: A::Output,
    /// Pending update for this node's whole range.
    lazy: A::Lazy,
    /// Child covering the lower half of the range, or null if not allocated.
    left: *mut Node<A>,
    /// Child covering the upper half of the range, or null if not allocated.
    right: *mut Node<A>,
}
13
14impl<A: Action> Node<A> {
15    fn new() -> Self {
16        Self {
17            value: A::fold_id(),
18            lazy: A::update_id(),
19            left: ptr::null_mut(),
20            right: ptr::null_mut(),
21        }
22    }
23}
24
25/// 動的遅延セグメント木
26#[derive(Clone, Debug)]
27pub struct DynamicLazySegtree<A: Action> {
28    root: *mut Node<A>,
29    to: usize,
30}
31
32impl<A: Action> Default for DynamicLazySegtree<A> {
33    fn default() -> Self {
34        Self::new()
35    }
36}
37
38impl<A: Action> DynamicLazySegtree<A> {
39    /// `DynamicLazySegtree<A>`を生成する。
40    pub fn new() -> Self {
41        Self {
42            root: Box::into_raw(Box::new(Node::new())),
43            to: 1,
44        }
45    }
46}
47
48impl<A: Action> DynamicLazySegtree<A>
49where
50    A::Output: Clone + PartialEq,
51    A::Lazy: Clone + PartialEq,
52{
53    fn _propagate(&self, t: *mut Node<A>, from: usize, to: usize) {
54        assert!(!t.is_null());
55        let lazy = unsafe { (*t).lazy.clone() };
56
57        if lazy == A::update_id() {
58            return;
59        }
60        if to - from > 1 {
61            unsafe {
62                if (*t).left.is_null() {
63                    (*t).left = Box::into_raw(Box::new(Node::new()));
64                }
65                let left = (*t).left;
66                (*left).lazy = A::update((*left).lazy.clone(), lazy.clone());
67
68                if (*t).right.is_null() {
69                    (*t).right = Box::into_raw(Box::new(Node::new()));
70                }
71                let right = (*t).right;
72                (*right).lazy = A::update((*right).lazy.clone(), lazy.clone());
73            }
74        }
75        let len = to - from;
76        unsafe {
77            (*t).value = A::convert((*t).value.clone(), lazy, len);
78            (*t).lazy = A::update_id();
79        }
80    }
81
82    fn _update(
83        &self,
84        mut cur: *mut Node<A>,
85        from: usize,
86        to: usize,
87        s: usize,
88        t: usize,
89        value: A::Lazy,
90    ) -> *mut Node<A> {
91        if cur.is_null() {
92            cur = Box::into_raw(Box::new(Node::new()));
93        }
94
95        self._propagate(cur, from, to);
96
97        if to - from == 1 {
98            if s <= from && to <= t {
99                unsafe {
100                    (*cur).lazy = A::update((*cur).lazy.clone(), value);
101                }
102            }
103            self._propagate(cur, from, to);
104            return cur;
105        }
106
107        if to < s || t < from {
108            return cur;
109        }
110        if s <= from && to <= t {
111            unsafe {
112                (*cur).lazy = A::update((*cur).lazy.clone(), value);
113            }
114            self._propagate(cur, from, to);
115            return cur;
116        }
117
118        let mid = (from + to) / 2;
119        unsafe {
120            (*cur).left = self._update((*cur).left, from, mid, s, t, value.clone());
121            (*cur).right = self._update((*cur).right, mid, to, s, t, value);
122            (*cur).value = A::fold((*(*cur).left).value.clone(), (*(*cur).right).value.clone());
123        }
124        cur
125    }
126
127    /// 範囲`s..t`を`value`で更新する。
128    pub fn update(&mut self, Range { start: s, end: t }: Range<usize>, value: A::Lazy) {
129        loop {
130            if t <= self.to {
131                break;
132            }
133            self.to *= 2;
134
135            let mut new_root = Box::new(Node::new());
136            new_root.left = self.root;
137
138            self.root = Box::into_raw(new_root);
139        }
140
141        self._update(self.root, 0, self.to, s, t, value);
142    }
143
144    fn _fold(&self, cur: *mut Node<A>, from: usize, to: usize, s: usize, t: usize) -> A::Output {
145        if cur.is_null() {
146            return A::fold_id();
147        }
148
149        self._propagate(cur, from, to);
150        if to <= s || t <= from {
151            return A::fold_id();
152        }
153        if s <= from && to <= t {
154            return unsafe { (*cur).value.clone() };
155        }
156
157        let mid = (from + to) / 2;
158        let lv = self._fold(unsafe { (*cur).left }, from, mid, s, t);
159        let rv = self._fold(unsafe { (*cur).right }, mid, to, s, t);
160
161        A::fold(lv, rv)
162    }
163
164    /// 範囲`s..t`で計算を集約する。
165    pub fn fold(&mut self, Range { start: s, end: t }: Range<usize>) -> A::Output {
166        self._fold(self.root, 0, self.to, s, t)
167    }
168}