aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorXiaoxu Guo <shimo11370@proton.me>2024-07-26 10:15:24 +0800
committerXiaoxu Guo <shimo11370@proton.me>2024-07-26 10:15:24 +0800
commitfc1c261147275c9c2c3931077a362a20fac49020 (patch)
treef55ff7962b238e6c2a23c6ac578a1651eae60179
parentb1e82d19303ae440a7f9ff2989582424735f797e (diff)
downloadshoka-fc1c261147275c9c2c3931077a362a20fac49020.tar.gz
shoka-fc1c261147275c9c2c3931077a362a20fac49020.zip
added NELatticPath
-rw-r--r--north_east_lattic_path.h87
-rw-r--r--poly_conv.h56
2 files changed, 143 insertions, 0 deletions
diff --git a/north_east_lattic_path.h b/north_east_lattic_path.h
new file mode 100644
index 0000000..fe5f1ad
--- /dev/null
+++ b/north_east_lattic_path.h
@@ -0,0 +1,87 @@
+#include "binom.h"
+#include "poly_conv.h"
+
+#include <algorithm>
+#include <span>
+#include <vector>
+
+/*
+ * Compute R[j] = sum_i U[i] * W(i, j) for j = 0, ..., m - 1.
+ * Here, W(i, j) is the number of north-east walks from (i, 0) to (n - 1, j)
+ * on which y < A[x] holds (presumably A is the constructor's a_ and n = a_.size() -- TODO confirm).
+ */
+template <typename Mod> struct NorthEastLatticePath {
+ explicit NorthEastLatticePath(const std::vector<int> &a_, // a_ is bound by reference (member a), not copied
+ const std::vector<Mod> &U_, int m) // U_ is copied into member U; m is the length of R
+ : a{a_}, U{U_}, R(m), binom(a.size() + m) { // binom is constructed with the argument a.size() + m
+ recur(0, a_.size(), 0, m); // the whole computation runs here, inside the constructor
+ }
+
+ const std::vector<Mod> &result() const { return R; } // read-only access to R as the constructor left it
+
+private:
+ void recur(int l, int r, int lo, int hi) { // works on indices [l, r) of a/U and [lo, hi) of R
+ // Invariant as written by the author: U[lo, hi) = 0. NOTE(review): lo/hi index R everywhere below, so R[lo, hi) may be meant -- confirm.
+ if (l + 1 == r) { // single index l: no further splitting
+ for (int j = lo; j < hi && j < a[l]; j++) { // j covers [lo, min(hi, a[l]))
+ R[j] = U[l];
+ }
+ } else {
+ int m = (l + r) >> 1; // split point of [l, r); a local, unrelated to the constructor's m
+ int mi = a[m]; // a[m] is used as the split point of [lo, hi)
+ if (lo < mi) {
+ recur(l, m, lo, mi); // must precede rect: rect snapshots U[m, r) and R[lo, mi) as this call leaves them
+ rect(std::span{U.begin() + m, U.begin() + r}, // rewrites U[m, r) and R[lo, mi) in place
+ std::span(R.begin() + lo, R.begin() + mi));
+ }
+ if (mi < hi) {
+ recur(m, r, mi, hi); // runs after rect, so it sees the rewritten U[m, r)
+ }
+ }
+ }
+
+ void rect(std::span<Mod> U, std::span<Mod> R) { // parameters U and R shadow the members of the same name
+ tmp_U.assign(U.begin(), U.end()); // snapshot both spans before they are overwritten
+ tmp_R.assign(R.begin(), R.end());
+ std::ranges::fill(U, Mod{0}); // both spans are zeroed, then rebuilt by the four accumulating calls below
+ std::ranges::fill(R, Mod{0});
+ rect_adj(tmp_U, R); // R += rect_adj transform of the old U
+ rect_adj(tmp_R, U); // U += rect_adj transform of the old R (same routine, roles swapped)
+ rect_op(tmp_U, U, R.size()); // U += rect_op transform of the old U with h = R.size()
+ rect_op(tmp_R, R, U.size()); // R += rect_op transform of the old R with h = U.size()
+ }
+
+ void rect_adj(const std::vector<Mod> &U, std::span<Mod> R) { // adds into R; uses lhs/rhs/out as scratch
+ int n = U.size();
+ int m = R.size();
+ lhs.resize(n);
+ for (int i = 0; i < n; i++) {
+ lhs[i] = U[i] * binom.inv_fact[n - 1 - i];
+ }
+ rhs.resize(n + m - 1);
+ for (int i = 0; i < n + m - 1; i++) {
+ rhs[i] = binom.fact[i];
+ }
+ conv(out, lhs, rhs, n + m); // limit n + m: only out[n - 1 .. n + m - 2] is read below
+ for (int i = 0; i < m; i++) {
+ R[i] += out[n - 1 + i] * binom.inv_fact[i]; // if fact[t] = t! and inv_fact[t] = 1/t! (Binom not shown): sum_k U[k] * C(n - 1 - k + i, i)
+ }
+ }
+
+ void rect_op(const std::vector<Mod> &U, std::span<Mod> UU, int h) { // adds into UU; reads binom.inv_fact[h - 1], so h >= 1 is required
+ int n = U.size();
+ rhs.resize(n);
+ for (int i = 0; i < n; i++) {
+ rhs[i] = binom.fact[h - 1 + i] * binom.inv_fact[i];
+ }
+ conv(out, U, rhs); // no limit passed; only out[0 .. n - 1] is read below
+ for (int i = 0; i < n; i++) {
+ UU[i] += out[i] * binom.inv_fact[h - 1]; // under the same Binom assumption: sum_{k <= i} U[k] * C(h - 1 + i - k, h - 1)
+ }
+ }
+
+ const std::vector<int> &a; // refers to the caller's vector; only read during construction (recur)
+ std::vector<Mod> U, R, tmp_U, tmp_R, lhs, rhs, out; // U/R hold the state; the rest are scratch buffers reused across calls
+ Binom<Mod> binom; // supplies the fact[] / inv_fact[] tables indexed above
+ PolyConv<Mod> conv; // convolution functor from poly_conv.h
+};
diff --git a/poly_conv.h b/poly_conv.h
new file mode 100644
index 0000000..ed1dcce
--- /dev/null
+++ b/poly_conv.h
@@ -0,0 +1,56 @@
+#pragma once
+
+#include "ntt.h"
+#include "singleton.h"
+#include "snippets/min_pow_of_two.h"
+
+#include <cstdlib>
+#include <vector>
+
+/// Polynomial multiplication functor over the coefficient type Mod.
+///
+/// Small products are computed with the schoolbook double loop; larger ones
+/// go through the shared NttT<Mod> instance returned by singleton<Ntt>().
+template <typename Mod> struct PolyConv {
+  using Vector = std::vector<Mod>;
+
+  /// Multiplies lhs by rhs and stores the coefficients in out.
+  ///
+  /// @param out   receives min(limit, lhs.size() + rhs.size() - 1)
+  ///              coefficients; it is cleared when that count is not positive.
+  ///              out must not alias lhs or rhs, because the schoolbook path
+  ///              overwrites out before reading the inputs.
+  /// @param lhs   coefficients of the first polynomial.
+  /// @param rhs   coefficients of the second polynomial.
+  /// @param limit upper bound on the number of coefficients produced.
+  ///
+  /// NOTE(review): in the NTT path the transform length is derived from the
+  /// truncated size, and inputs longer than that length are cut off. If NttT
+  /// is a cyclic transform (not visible here), the low coefficients of a
+  /// truncated product may contain wrapped-around terms -- confirm before
+  /// relying on them.
+  void operator()(Vector &out, const Vector &lhs, const Vector &rhs,
+                  int limit = std::numeric_limits<int>::max()) const {
+    // Cache the sizes as int so the loops below do not compare signed
+    // against unsigned.
+    const int lhs_size = static_cast<int>(lhs.size());
+    const int rhs_size = static_cast<int>(rhs.size());
+    const int deg_plus_1 = std::min(limit, lhs_size + rhs_size - 1);
+    if (deg_plus_1 <= 0) {
+      // Both inputs empty, or a non-positive limit: a negative count must not
+      // reach assign()/resize(), where it would turn into a huge unsigned size.
+      out.clear();
+      return;
+    }
+    if (deg_plus_1 <= 16) {
+      out.assign(deg_plus_1, Mod{0});
+      for (int i = 0; i < lhs_size; ++i) {
+        for (int j = 0; j < rhs_size && i + j < limit; ++j) {
+          out[i + j] += lhs[i] * rhs[j];
+        }
+      }
+      // The schoolbook result is complete; without this return the NTT path
+      // below would recompute and overwrite it.
+      return;
+    }
+    const int n = min_pow_of_two(deg_plus_1);
+    ntt().reserve(n);
+    Mod *b0 = ntt().template raw_buffer<0>();
+    Mod *b1 = ntt().template raw_buffer<1>();
+    copy_and_fill0(n, b0, lhs);
+    ntt().dif(n, b0);
+    copy_and_fill0(n, b1, rhs);
+    ntt().dif(n, b1);
+    out.resize(n);
+    // The 1/n factor is folded into the pointwise product.
+    const auto inv_n = ntt().power_of_two_inv(n);
+    for (int i = 0; i < n; ++i) {
+      out[i] = inv_n * b0[i] * b1[i];
+    }
+    ntt().dit(n, out.data());
+    out.resize(deg_plus_1);
+  }
+
+private:
+  using Ntt = NttT<Mod>;
+
+  /// Shared transform object obtained from singleton<Ntt>().
+  static Ntt &ntt() { return singleton<Ntt>(); }
+
+  /// Copies the first min(n, m) elements of src into dst and zero-fills the
+  /// rest of dst[0, n).
+  static void copy_and_fill0(int n, Mod *dst, int m, const Mod *src) {
+    m = std::min(n, m);
+    std::copy(src, src + m, dst);
+    std::fill(dst + m, dst + n, Mod{0});
+  }
+
+  /// Vector overload: forwards src's size and data to the pointer version.
+  static void copy_and_fill0(int n, Mod *dst, const std::vector<Mod> &src) {
+    copy_and_fill0(n, dst, src.size(), src.data());
+  }
+};