haar_lib/tree/
rerooting.rs

1//! 全方位木DP
2
3use crate::tree::*;
4
/// Rerooting DP on a tree (全方位木DP): computes a DP result for every vertex
/// as if that vertex were the root.
///
/// # Type parameters
/// - `T`: result stored per vertex / per subtree.
/// - `U`: accumulated value obtained by lifting neighbouring results across
///   their edges and merging them.
/// - `E`: edge type of the tree.
///
/// # References
/// - <https://null-mn.hatenablog.com/entry/2020/04/14/124151>
///
/// # Problems
/// - [EDPC V - Subtree](https://atcoder.jp/contests/dp/submissions/57560435)
/// - <https://atcoder.jp/contests/abc160/tasks/abc160_f>
/// - <https://judge.yosupo.jp/problem/tree_path_composite_sum>
pub struct RerootingDP<'a, T, U, E> {
    /// Starting value of every accumulation of lifted neighbour results.
    init: U,
    /// Lifts the result of a neighbouring subtree across the edge leading to it.
    up: Box<dyn 'a + Fn(T, &'a E) -> U>,
    /// Combines two accumulated values into one.
    merge: Box<dyn 'a + Fn(U, U) -> U>,
    /// Turns the accumulated value of a vertex (second argument: its index)
    /// into that vertex's result.
    apply: Box<dyn 'a + Fn(U, usize) -> T>,
}
20
21impl<'a, T, U, E> RerootingDP<'a, T, U, E>
22where
23    E: TreeEdgeTrait,
24    T: Clone,
25    U: Clone,
26{
27    /// `RerootingDP`を構築する。
28    pub fn new(
29        init: U,
30        up: Box<impl 'a + Fn(T, &'a E) -> U>,
31        merge: Box<impl 'a + Fn(U, U) -> U>,
32        apply: Box<impl 'a + Fn(U, usize) -> T>,
33    ) -> Self {
34        Self {
35            init,
36            up,
37            merge,
38            apply,
39        }
40    }
41
42    /// `tree`上で、全方位DPを実行する。
43    pub fn run(&self, tree: &'a Tree<E>) -> Vec<T> {
44        let size = tree.len();
45        let mut dp = (0..size)
46            .map(|i| vec![None; tree.nodes[i].neighbors_size()])
47            .collect::<Vec<_>>();
48
49        self.rec1(tree, &mut dp, 0, None);
50        self.rec2(tree, &mut dp, 0, None, None);
51
52        tree.nodes
53            .iter()
54            .enumerate()
55            .map(|(i, nodes)| {
56                let acc = nodes
57                    .neighbors()
58                    .enumerate()
59                    .filter_map(|(j, e)| dp[i][j].as_ref().map(|res| (self.up)(res.clone(), e)))
60                    .fold(self.init.clone(), |x, y| (self.merge)(x, y));
61                (self.apply)(acc, i)
62            })
63            .collect()
64    }
65
66    fn rec1(
67        &self,
68        tree: &'a Tree<E>,
69        dp: &mut Vec<Vec<Option<T>>>,
70        cur: usize,
71        par: Option<usize>,
72    ) -> T {
73        let acc = tree.nodes[cur]
74            .neighbors()
75            .enumerate()
76            .filter(|(_, e)| par.is_none_or(|u| u != e.to()))
77            .map(|(i, e)| {
78                let res = self.rec1(tree, dp, e.to(), Some(cur));
79                dp[cur][i] = Some(res.clone());
80                (self.up)(res, e)
81            })
82            .fold(self.init.clone(), |x, y| (self.merge)(x, y));
83
84        (self.apply)(acc, cur)
85    }
86
87    fn rec2(
88        &self,
89        tree: &'a Tree<E>,
90        dp: &mut Vec<Vec<Option<T>>>,
91        cur: usize,
92        par: Option<usize>,
93        value: Option<T>,
94    ) {
95        let len = tree.nodes[cur].neighbors_size();
96
97        for (i, e) in tree.nodes[cur].neighbors().enumerate() {
98            if par.is_some_and(|u| u == e.to()) {
99                dp[cur][i] = value.clone();
100            }
101        }
102
103        let mut left = vec![self.init.clone(); len + 1];
104        let mut right = vec![self.init.clone(); len + 1];
105
106        if len > 1 {
107            for (i, e) in tree.nodes[cur].neighbors().take(len - 1).enumerate() {
108                left[i + 1] = if let Some(res) = dp[cur][i].clone() {
109                    (self.merge)(left[i].clone(), (self.up)(res, e))
110                } else {
111                    left[i].clone()
112                };
113            }
114
115            for (i, e) in tree.nodes[cur].neighbors().rev().take(len - 1).enumerate() {
116                let i = len - i - 1;
117                right[i - 1] = if let Some(res) = dp[cur][i].clone() {
118                    (self.merge)(right[i].clone(), (self.up)(res, e))
119                } else {
120                    right[i].clone()
121                };
122            }
123        }
124
125        for (i, e) in tree.nodes[cur].neighbors().enumerate() {
126            if par.is_some_and(|u| u == e.to()) {
127                continue;
128            }
129
130            self.rec2(
131                tree,
132                dp,
133                e.to(),
134                Some(cur),
135                Some((self.apply)(
136                    (self.merge)(left[i].clone(), right[i].clone()),
137                    cur,
138                )),
139            );
140        }
141    }
142}