Skip to content

Commit

Permalink
feat: conv_mval
Browse files Browse the repository at this point in the history
  • Loading branch information
Tiphereth-A committed Mar 2, 2025
1 parent 9f28ae5 commit 2743d39
Show file tree
Hide file tree
Showing 8 changed files with 151 additions and 0 deletions.
3 changes: 3 additions & 0 deletions config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1098,6 +1098,9 @@ notebook:
- conv_dft: 卷积(FFT 或 NTT)
code_ext: hpp
test_ext: cpp
- conv_mval_dft: 卷积(多维,FFT 或 NTT)
code_ext: hpp
test_ext: cpp
- conv_u64: 卷积(u64)
code_ext: hpp
test_ext: cpp
Expand Down
51 changes: 51 additions & 0 deletions src/code/conv/conv_mval_dft.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#ifndef TIFALIBS_CONV_CONV_MVAL_DFT
#define TIFALIBS_CONV_CONV_MVAL_DFT

#include "../util/traits.hpp"

namespace tifa_libs::math {

// @param l l[j]=$\sum l_j \prod_i x_i^{j_i}$, where $j=\sum_i j_i n_i$
// @param r similar as {@code l} with same length
// @param base base[x]=$n_i$
template <dft_c DFT_t, std::same_as<TPN DFT_t::data_t> DFT_data_t>
CEXP vec<DFT_data_t> conv_mval_dft(DFT_t &dft, vec<DFT_data_t> CR l, vec<DFT_data_t> CR r, vecu CR base) NE {
assert(l.size() == r.size());
u32 k = (u32)base.size();
if (!k) return {l[0] * r[0]};
dft.bzr((u32)l.size() * 2 - 1);
u32 n = dft.size();
vecu chi(n);
flt_ (u32, i, 0, n, x) {
x = i;
for (auto b : base) chi[i] += (x /= b);
chi[i] %= k;
}
vvec<DFT_data_t> f(k, vec<DFT_data_t>(n)), g(k, vec<DFT_data_t>(n));
flt_ (u32, i, 0, (u32)l.size()) f[chi[i]][i] = l[i];
flt_ (u32, i, 0, (u32)l.size()) g[chi[i]][i] = r[i];
for (auto &i : f) dft.dif(i);
for (auto &i : g) dft.dif(i);
vec<DFT_data_t> _(k);
flt_ (u32, l, 0, n) {
fill(_, DFT_data_t{});
flt_ (u32, i, 0, k)
flt_ (u32, j, 0, k) _[i + j - (i + j >= k ? k : 0)] += f[i][l] * g[j][l];
flt_ (u32, i, 0, k) f[i][l] = _[i];
}
for (auto &i : f) dft.dit(i);
vec<DFT_data_t> ans(l.size());
flt_ (u32, i, 0, (u32)ans.size()) ans[i] = f[chi[i]][i];
return ans;
}
template <class DFT_t, class mint, class T = u64>
CEXP vec<mint> conv_mval_dft_um(DFT_t &dft, vec<T> CR l, vec<T> CR r, vecu CR base) NE {
vec<mint> l_, r_;
for (l_.reserve(l.size()); auto CR i : l) l_.push_back(i);
for (r_.reserve(r.size()); auto CR i : r) r_.push_back(i);
return conv_mval_dft(dft, l_, r_, base);
}

} // namespace tifa_libs::math

#endif
4 changes: 4 additions & 0 deletions src/doc_md/conv/conv_mval_dft.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
title: conv_mval_dft
documentation_of: //src/code/conv/conv_mval_dft.hpp
---
5 changes: 5 additions & 0 deletions src/doc_tex/conv/conv_mval_dft.tex
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
计算 \(f(x_1,\dots,x_k)g(x_1,\dots,x_k)\bmod{\left(x_1^{n_1},\dots,x_k^{n_k}\right)}\)

\paragraph{复杂度} \(O(kN\log N)\), 其中 \(N=\prod_i n_i\)

\paragraph{参考链接} \qrcode{https://rushcheyo.blog.uoj.ac/blog/6547}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#define PROBLEM "https://judge.yosupo.jp/problem/multivariate_convolution"

#include "../../code/conv/conv_mval_dft.hpp"
#include "../../code/conv/ntt.hpp"
#include "../../code/io/fastin.hpp"
#include "../../code/io/fastout.hpp"
#include "../../code/io/ios_container.hpp"

CEXP u32 MOD = 998244353;

// ---<GENTC>--- begin
// ---<GENTC>--- append mints
// ---<GENTC>--- end

using ntt_t = tifa_libs::math::NTT<mint>;
using vec_t = vec<mint>;

int main() {
u32 k, n = 1;
tifa_libs::fin >> k;
vecu base(k);
for (auto &i : base) tifa_libs::fin >> i, n *= i;
vec_t a(n), b(n);
tifa_libs::fin >> a >> b;
ntt_t ntt;
tifa_libs::fout << tifa_libs::math::conv_mval_dft(ntt, a, b, base) << '\n';
return 0;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#define AUTO_GENERATED
#define PROBLEM "https://judge.yosupo.jp/problem/multivariate_convolution"

#include "../../code/conv/conv_mval_dft.hpp"
#include "../../code/conv/ntt.hpp"
#include "../../code/io/fastin.hpp"
#include "../../code/io/fastout.hpp"
#include "../../code/io/ios_container.hpp"

CEXP u32 MOD = 998244353;

#include "../../code/math/mint.hpp"
#include "../../code/math/mint_s30.hpp"

using mint = tifa_libs::math::mint<tifa_libs::math::mint_s30, MOD>;

using ntt_t = tifa_libs::math::NTT<mint>;
using vec_t = vec<mint>;

int main() {
u32 k, n = 1;
tifa_libs::fin >> k;
vecu base(k);
for (auto &i : base) tifa_libs::fin >> i, n *= i;
vec_t a(n), b(n);
tifa_libs::fin >> a >> b;
ntt_t ntt;
tifa_libs::fout << tifa_libs::math::conv_mval_dft(ntt, a, b, base) << '\n';
return 0;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#define AUTO_GENERATED
#define PROBLEM "https://judge.yosupo.jp/problem/multivariate_convolution"

#include "../../code/conv/conv_mval_dft.hpp"
#include "../../code/conv/ntt.hpp"
#include "../../code/io/fastin.hpp"
#include "../../code/io/fastout.hpp"
#include "../../code/io/ios_container.hpp"

CEXP u32 MOD = 998244353;

#include "../../code/math/mint.hpp"
#include "../../code/math/mint_s63.hpp"

using mint = tifa_libs::math::mint<tifa_libs::math::mint_s63, MOD>;

using ntt_t = tifa_libs::math::NTT<mint>;
using vec_t = vec<mint>;

int main() {
u32 k, n = 1;
tifa_libs::fin >> k;
vecu base(k);
for (auto &i : base) tifa_libs::fin >> i, n *= i;
vec_t a(n), b(n);
tifa_libs::fin >> a >> b;
ntt_t ntt;
tifa_libs::fout << tifa_libs::math::conv_mval_dft(ntt, a, b, base) << '\n';
return 0;
}
Empty file.

0 comments on commit 2743d39

Please sign in to comment.