diff options
author | Xiaoxu Guo <shimo11370@proton.me> | 2024-06-15 09:32:12 +0800 |
---|---|---|
committer | Xiaoxu Guo <shimo11370@proton.me> | 2024-06-15 09:32:12 +0800 |
commit | 60eccba2a1c013f6831394429b0a521a9afb8db0 (patch) | |
tree | 0067cd6bbcc3fac5dc33149e55d7c21dc22b68aa | |
parent | 911d4d75651fbdbb430421b5b35ba7315edc6912 (diff) | |
download | shoka-60eccba2a1c013f6831394429b0a521a9afb8db0.tar.gz shoka-60eccba2a1c013f6831394429b0a521a9afb8db0.zip |
added capacity scaling min cost flow
-rw-r--r-- | cap_scaling_min_cost_flow.h | 114 |
1 files changed, 114 insertions, 0 deletions
diff --git a/cap_scaling_min_cost_flow.h b/cap_scaling_min_cost_flow.h new file mode 100644 index 0000000..2ec729c --- /dev/null +++ b/cap_scaling_min_cost_flow.h @@ -0,0 +1,114 @@ +#include "snippets/min_pow_of_two.h" +#include "snippets/min_pq.h" +#include "types/graph/adjacent_list_base.h" + +#include <algorithm> +#include <cassert> +#include <concepts> +#include <functional> + +template <std::integral CapT, typename CostT> +class CapScalingMinCostFlow + : public AdjacentListBase<std::tuple<int, CapT, CostT>> { + using Base = AdjacentListBase<std::tuple<int, CapT, CostT>>; + using Base::edges; + + using OutFn = std::function<void(CapT, CostT)>; + + void update(int u, int v, int i, CapT delta) { + excess[u] -= delta; + excess[v] += delta; + std::get<1>(edges[i]) -= delta; + std::get<1>(edges[i ^ 1]) += delta; + } + + bool augment(int s, int delta, const OutFn &output) { + std::ranges::fill(visited, false); + std::ranges::fill(dist, std::numeric_limits<CostT>::max()); + while (!pq.empty()) { + pq.pop(); + } + pq.emplace(dist[s] = 0, s); + while (!pq.empty()) { + auto [du, u] = pq.top(); + pq.pop(); + if (du == dist[u]) { + visited[u] = true; + if (excess[u] <= -delta) { + // augment along the path s->t + auto t = u; + output(delta, dist[t] + pi[s] - pi[t]); + for (int u = 0; u < n; u++) { + if (visited[u]) { + pi[u] = pi[u] - dist[u] + dist[t]; + } + } + for (int v = t; v != s;) { + auto i = pre[v]; + assert(~i); + auto u = std::get<0>(edges[i ^ 1]); + update(u, v, i, delta); + v = u; + } + return true; + } + for (int i = Base::head[u]; ~i; i = Base::next[i]) { + auto [v, c, w] = edges[i]; + auto rw = w - pi[u] + pi[v]; + if (c >= delta && dist[v] > dist[u] + rw) { + pre[v] = i; + pq.emplace(dist[v] = dist[u] + rw, v); + } + } + } + } + return false; + } + + int n; + CapT maxc = 1; + std::vector<int> visited, pre; + MinPQ<std::pair<CostT, int>> pq; + std::vector<CapT> excess; + std::vector<CostT> pi, dist; + +public: + explicit CapScalingMinCostFlow(int n_) + : Base{n_}, n{n_}, visited(n), pre(n), excess(n), pi(n), dist(n) {} + + void add_edge(int u, int v, CapT c, CostT w) { + Base::add(u, v, c, w); + Base::add(v, u, 0, -w); + maxc = std::max(maxc, c); + } + + CapT operator()(int source, int sink, const OutFn &output) { + auto delta = min_pow_of_two(maxc + 1) >> 1; + CapT flow{0}; + while (delta) { + for (int i = 0; i < static_cast<int>(edges.size()); i++) { + auto &[v, c, w] = edges[i]; + auto u = std::get<0>(edges[i ^ 1]); + if (c >= delta && w - pi[u] + pi[v] < 0) { + update(u, v, i, delta); + output(delta, w); + } + } + for (int s = 0; s < n; s++) { + while (excess[s] >= delta) { + assert(augment(s, delta, output)); + } + } + do { + flow += delta; + excess[source] += delta; + excess[sink] -= delta; + } while (augment(source, delta, output)); + flow -= delta; + excess[source] -= delta; + excess[sink] += delta; + delta >>= 1; + } + return flow; + } +}; |