// haar_lib/flow/assignment.rs

/// One arc of the residual graph built by `assignment`.
///
/// Arcs always come in pairs created by `add_edge`: a forward arc and a
/// backward arc pointing the other way. `rev` is the index of the partner arc
/// inside the adjacency list of `to`.
#[derive(Clone, Debug)]
struct Edge {
    /// Head node of the arc.
    to: usize,
    /// Position of the paired arc in `edges[to]`.
    rev: usize,
    /// Whether this arc currently has (unit) residual capacity.
    cap: bool,
    /// Cost of sending one unit along this arc; the backward arc stores the negation.
    cost: i64,
    /// `true` for the backward half of a pair, `false` for the forward half.
    is_rev: bool,
}

/// Adds the arc `u -> v` with one unit of capacity and cost `cost`, plus its
/// zero-capacity partner `v -> u` with cost `-cost`, linking the two through
/// their `rev` fields.
fn add_edge(edges: &mut [Vec<Edge>], u: usize, v: usize, cost: i64) {
    // Slots the two new arcs will occupy, read before anything is pushed.
    let back_slot = edges[v].len();
    let fwd_slot = edges[u].len();

    edges[u].push(Edge {
        to: v,
        rev: back_slot,
        cap: true,
        cost,
        is_rev: false,
    });

    edges[v].push(Edge {
        to: u,
        rev: fwd_slot,
        cap: false,
        cost: -cost,
        is_rev: true,
    });
}
34pub fn assignment(a: Vec<Vec<i64>>) -> (i64, Vec<usize>) {
36 let n = a.len();
37 assert!(a.iter().all(|v| v.len() == n));
38
39 let size = n * 2 + 1;
40 let mut edges = vec![vec![]; n * 2 + 1];
41
42 for i in 0..n {
43 for j in 0..n {
44 add_edge(&mut edges, i, n + j, a[i][j]);
45 }
46 }
47
48 let sink = 2 * n;
49
50 for i in 0..n {
51 add_edge(&mut edges, n + i, sink, 0);
52 }
53
54 let mut min_cost = 0;
55 let mut h = vec![0; size];
56
57 for src in 0..n {
58 let mut prev = vec![(0, 0); size];
59 let mut cost = vec![None; size];
60 let mut pq = vec![None; size];
61
62 cost[src] = Some(0);
63 pq[src] = Some(0);
64
65 loop {
66 let Some((c, v)) = pq
67 .iter()
68 .enumerate()
69 .filter_map(|(i, c)| c.map(|x| (x, i)))
70 .min()
71 else {
72 break;
73 };
74
75 pq[v] = None;
76
77 let h_v = h[v];
78
79 for (i, e) in edges[v].iter().enumerate() {
80 if e.cap
81 && (cost[e.to].is_none() || cost[e.to].unwrap() + h[e.to] > c + h_v + e.cost)
82 {
83 cost[e.to] = Some(c + e.cost + h_v - h[e.to]);
84 prev[e.to] = (v, i);
85 pq[e.to] = pq[e.to].map(|x| x.min(cost[e.to].unwrap())).or(cost[e.to]);
86 }
87 }
88 }
89
90 for i in 0..size {
91 if let Some(x) = cost[i] {
92 h[i] += x;
93 }
94 }
95
96 min_cost += h[sink];
97
98 let mut cur = sink;
99 while cur != src {
100 let e = &mut edges[prev[cur].0][prev[cur].1];
101 e.cap ^= true;
102 let rev = e.rev;
103 edges[cur][rev].cap ^= true;
104 cur = prev[cur].0;
105 }
106 }
107
108 let assignment = (0..n)
109 .map(|i| {
110 let k = edges[i].iter().find(|e| !e.is_rev && !e.cap).unwrap().to;
111 k - n
112 })
113 .collect();
114
115 (min_cost, assignment)
116}