aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorXiaoxu Guo <ftiasch0@gmail.com>2024-04-12 16:00:54 +0800
committerXiaoxu Guo <ftiasch0@gmail.com>2024-04-12 16:00:54 +0800
commitf88e38b0c3e6397bc313067f4572cf40bed90b69 (patch)
tree6dccea699452098967da5f151021400337f7fe67
parent8c1f654178f03334b9201a09ca66e3e2e93159bc (diff)
downloadshoka-f88e38b0c3e6397bc313067f4572cf40bed90b69.tar.gz
shoka-f88e38b0c3e6397bc313067f4572cf40bed90b69.zip
fixed smawk
-rw-r--r--smawk.h77
-rw-r--r--test/smawk.hpp18
2 files changed, 85 insertions, 10 deletions
diff --git a/smawk.h b/smawk.h
new file mode 100644
index 0000000..81ecacc
--- /dev/null
+++ b/smawk.h
@@ -0,0 +1,77 @@
+#include <concepts>
+#include <numeric>
+#include <ranges>
+#include <span>
+#include <type_traits>
+#include <vector>
+
+#include "debug.h"
+
+template <typename A>
+concept IsTM = requires(A m) {
+ {
+ m(std::declval<int>(), std::declval<int>())
+ } -> std::convertible_to<typename A::E>;
+ {
+ std::declval<typename A::E>() > std::declval<typename A::E>()
+ } -> std::convertible_to<bool>;
+};
+
+template <IsTM A> struct SMAWK {
+ explicit SMAWK(int n_, int m_, const A &a_)
+ : n{n_}, m{m_}, a{a_}, row_min(n), cols(m + 2 * n) {
+ stack.reserve(n);
+ std::iota(cols.begin(), cols.begin() + m, 0);
+ recur(0, 0, m);
+ }
+
+ void recur(int k, int begin, int end) {
+ if (n < (2 << k)) {
+ auto r = (1 << k) - 1;
+ row_min[r] = get(r + 1, cols[begin]);
+ for (int i = begin + 1; i < end; i++) {
+ check_min(row_min[r], get(r + 1, cols[i]));
+ }
+ } else {
+ stack.clear();
+ for (int i = begin; i < end; i++) {
+ auto c = cols[i];
+ while (!stack.empty() && stack.back() > get((stack.size() << k), c)) {
+ stack.pop_back();
+ }
+ auto r1 = (stack.size() + 1) << k;
+ if (r1 <= n) {
+ stack.push_back(get(r1, c));
+ }
+ }
+ for (int i = 0; i < stack.size(); i++) {
+ cols[end + i] = stack[i].second;
+ }
+ begin = end, end += stack.size();
+ recur(k + 1, begin, end);
+ auto offset = 1 << k;
+ for (int r = offset - 1, p = begin; r < n; r += offset << 1) {
+ auto high = r + offset < n ? row_min[r + offset].second + 1 : m;
+ row_min[r] = get(r + 1, cols[p]);
+ while (p + 1 < end && cols[p + 1] < high) {
+ check_min(row_min[r], get(r + 1, cols[++p]));
+ }
+ }
+ }
+ }
+
+ using TP = std::pair<typename A::E, int>;
+
+ TP get(int x1, int y) { return {a(x1 - 1, y), y}; }
+
+ static void check_min(TP &x, TP a) {
+ if (x > a) {
+ x = a;
+ }
+ }
+
+ int n, m;
+ const A &a;
+ std::vector<TP> row_min, stack;
+ std::vector<int> cols;
+};
diff --git a/test/smawk.hpp b/test/smawk.hpp
index 89d9eef..df43dc4 100644
--- a/test/smawk.hpp
+++ b/test/smawk.hpp
@@ -1,22 +1,20 @@
#include "smawk.h"
-#include "debug.h"
-
#include <bits/stdc++.h>
#include <catch2/catch_all.hpp>
-struct MongeCompare {
- int operator()(int r, int c0, int c1) const {
- return monge[r][c0] - monge[r][c1];
- }
+struct Monge {
+ using E = int;
+
+ E operator()(int x, int y) const { return monge.at(x).at(y); }
std::vector<std::vector<int>> monge;
};
TEST_CASE("smawk") {
- auto n = GENERATE(range(1, 100));
- auto m = GENERATE(range(1, 100));
+ auto n = GENERATE(range(1, 50));
+ auto m = GENERATE(range(1, 50));
std::mt19937 gen{Catch::getSeed()};
std::vector monge(n, std::vector<int>(m));
@@ -35,13 +33,13 @@ TEST_CASE("smawk") {
}
}
- auto row_min = smawk(n, m, MongeCompare{monge});
+ auto row_min = SMAWK(n, m, Monge{monge}).row_min;
for (int i = 0; i < n; i++) {
std::pair<int, int> best{INT_MAX, 0};
for (int j = 0; j < m; j++) {
best = std::min(best, {monge[i][j], j});
}
- REQUIRE(row_min[i] == best.second);
+ REQUIRE(row_min[i] == best);
}
}