haar_lib/tree/
rerooting.rs1use crate::tree::*;
4
5pub struct RerootingDP<'a, T, U, E> {
15 init: U,
16 up: Box<dyn 'a + Fn(T, &'a E) -> U>,
17 merge: Box<dyn 'a + Fn(U, U) -> U>,
18 apply: Box<dyn 'a + Fn(U, usize) -> T>,
19}
20
21impl<'a, T, U, E> RerootingDP<'a, T, U, E>
22where
23 E: TreeEdgeTrait,
24 T: Clone,
25 U: Clone,
26{
27 pub fn new(
29 init: U,
30 up: Box<impl 'a + Fn(T, &'a E) -> U>,
31 merge: Box<impl 'a + Fn(U, U) -> U>,
32 apply: Box<impl 'a + Fn(U, usize) -> T>,
33 ) -> Self {
34 Self {
35 init,
36 up,
37 merge,
38 apply,
39 }
40 }
41
42 pub fn run(&self, tree: &'a Tree<E>) -> Vec<T> {
44 let size = tree.len();
45 let mut dp = (0..size)
46 .map(|i| vec![None; tree.nodes[i].neighbors_size()])
47 .collect::<Vec<_>>();
48
49 self.rec1(tree, &mut dp, 0, None);
50 self.rec2(tree, &mut dp, 0, None, None);
51
52 tree.nodes
53 .iter()
54 .enumerate()
55 .map(|(i, nodes)| {
56 let acc = nodes
57 .neighbors()
58 .enumerate()
59 .filter_map(|(j, e)| dp[i][j].as_ref().map(|res| (self.up)(res.clone(), e)))
60 .fold(self.init.clone(), |x, y| (self.merge)(x, y));
61 (self.apply)(acc, i)
62 })
63 .collect()
64 }
65
66 fn rec1(
67 &self,
68 tree: &'a Tree<E>,
69 dp: &mut Vec<Vec<Option<T>>>,
70 cur: usize,
71 par: Option<usize>,
72 ) -> T {
73 let acc = tree.nodes[cur]
74 .neighbors()
75 .enumerate()
76 .filter(|(_, e)| par.is_none_or(|u| u != e.to()))
77 .map(|(i, e)| {
78 let res = self.rec1(tree, dp, e.to(), Some(cur));
79 dp[cur][i] = Some(res.clone());
80 (self.up)(res, e)
81 })
82 .fold(self.init.clone(), |x, y| (self.merge)(x, y));
83
84 (self.apply)(acc, cur)
85 }
86
87 fn rec2(
88 &self,
89 tree: &'a Tree<E>,
90 dp: &mut Vec<Vec<Option<T>>>,
91 cur: usize,
92 par: Option<usize>,
93 value: Option<T>,
94 ) {
95 let len = tree.nodes[cur].neighbors_size();
96
97 for (i, e) in tree.nodes[cur].neighbors().enumerate() {
98 if par.is_some_and(|u| u == e.to()) {
99 dp[cur][i] = value.clone();
100 }
101 }
102
103 let mut left = vec![self.init.clone(); len + 1];
104 let mut right = vec![self.init.clone(); len + 1];
105
106 if len > 1 {
107 for (i, e) in tree.nodes[cur].neighbors().take(len - 1).enumerate() {
108 left[i + 1] = if let Some(res) = dp[cur][i].clone() {
109 (self.merge)(left[i].clone(), (self.up)(res, e))
110 } else {
111 left[i].clone()
112 };
113 }
114
115 for (i, e) in tree.nodes[cur].neighbors().rev().take(len - 1).enumerate() {
116 let i = len - i - 1;
117 right[i - 1] = if let Some(res) = dp[cur][i].clone() {
118 (self.merge)(right[i].clone(), (self.up)(res, e))
119 } else {
120 right[i].clone()
121 };
122 }
123 }
124
125 for (i, e) in tree.nodes[cur].neighbors().enumerate() {
126 if par.is_some_and(|u| u == e.to()) {
127 continue;
128 }
129
130 self.rec2(
131 tree,
132 dp,
133 e.to(),
134 Some(cur),
135 Some((self.apply)(
136 (self.merge)(left[i].clone(), right[i].clone()),
137 cur,
138 )),
139 );
140 }
141 }
142}