diff options
author | Xiaoxu Guo <ftiasch0@gmail.com> | 2024-04-12 16:00:54 +0800 |
---|---|---|
committer | Xiaoxu Guo <ftiasch0@gmail.com> | 2024-04-12 16:00:54 +0800 |
commit | f88e38b0c3e6397bc313067f4572cf40bed90b69 (patch) | |
tree | 6dccea699452098967da5f151021400337f7fe67 | |
parent | 8c1f654178f03334b9201a09ca66e3e2e93159bc (diff) | |
download | shoka-f88e38b0c3e6397bc313067f4572cf40bed90b69.tar.gz shoka-f88e38b0c3e6397bc313067f4572cf40bed90b69.zip |
fixed smawk
-rw-r--r-- | smawk.h | 77 | ||||
-rw-r--r-- | test/smawk.hpp | 18 |
2 files changed, 85 insertions, 10 deletions
@@ -0,0 +1,77 @@ +#include <concepts> +#include <numeric> +#include <ranges> +#include <span> +#include <type_traits> +#include <vector> + +#include "debug.h" + +template <typename A> +concept IsTM = requires(A m) { + { + m(std::declval<int>(), std::declval<int>()) + } -> std::convertible_to<typename A::E>; + { + std::declval<typename A::E>() > std::declval<typename A::E>() + } -> std::convertible_to<bool>; +}; + +template <IsTM A> struct SMAWK { + explicit SMAWK(int n_, int m_, const A &a_) + : n{n_}, m{m_}, a{a_}, row_min(n), cols(m + 2 * n) { + stack.reserve(n); + std::iota(cols.begin(), cols.begin() + m, 0); + recur(0, 0, m); + } + + void recur(int k, int begin, int end) { + if (n < (2 << k)) { + auto r = (1 << k) - 1; + row_min[r] = get(r + 1, cols[begin]); + for (int i = begin + 1; i < end; i++) { + check_min(row_min[r], get(r + 1, cols[i])); + } + } else { + stack.clear(); + for (int i = begin; i < end; i++) { + auto c = cols[i]; + while (!stack.empty() && stack.back() > get((stack.size() << k), c)) { + stack.pop_back(); + } + auto r1 = (stack.size() + 1) << k; + if (r1 <= n) { + stack.push_back(get(r1, c)); + } + } + for (int i = 0; i < stack.size(); i++) { + cols[end + i] = stack[i].second; + } + begin = end, end += stack.size(); + recur(k + 1, begin, end); + auto offset = 1 << k; + for (int r = offset - 1, p = begin; r < n; r += offset << 1) { + auto high = r + offset < n ? row_min[r + offset].second + 1 : m; + row_min[r] = get(r + 1, cols[p]); + while (p + 1 < end && cols[p + 1] < high) { + check_min(row_min[r], get(r + 1, cols[++p])); + } + } + } + } + + using TP = std::pair<typename A::E, int>; + + TP get(int x1, int y) { return {a(x1 - 1, y), y}; } + + static void check_min(TP &x, TP a) { + if (x > a) { + x = a; + } + } + + int n, m; + const A &a; + std::vector<TP> row_min, stack; + std::vector<int> cols; +}; diff --git a/test/smawk.hpp b/test/smawk.hpp index 89d9eef..df43dc4 100644 --- a/test/smawk.hpp +++ b/test/smawk.hpp @@ -1,22 +1,20 @@ #include "smawk.h" -#include "debug.h" - #include <bits/stdc++.h> #include <catch2/catch_all.hpp> -struct MongeCompare { - int operator()(int r, int c0, int c1) const { - return monge[r][c0] - monge[r][c1]; - } +struct Monge { + using E = int; + + E operator()(int x, int y) const { return monge.at(x).at(y); } std::vector<std::vector<int>> monge; }; TEST_CASE("smawk") { - auto n = GENERATE(range(1, 100)); - auto m = GENERATE(range(1, 100)); + auto n = GENERATE(range(1, 50)); + auto m = GENERATE(range(1, 50)); std::mt19937 gen{Catch::getSeed()}; std::vector monge(n, std::vector<int>(m)); @@ -35,13 +33,13 @@ TEST_CASE("smawk") { } } - auto row_min = smawk(n, m, MongeCompare{monge}); + auto row_min = SMAWK(n, m, Monge{monge}).row_min; for (int i = 0; i < n; i++) { std::pair<int, int> best{INT_MAX, 0}; for (int j = 0; j < m; j++) { best = std::min(best, {monge[i][j], j}); } - REQUIRE(row_min[i] == best.second); + REQUIRE(row_min[i] == best); } } |