Skip to content

Commit

Permalink
refactor: better conv_u64 & CRT
Browse files Browse the repository at this point in the history
  • Loading branch information
Tiphereth-A committed Feb 21, 2025
1 parent 67dbfee commit 0a0587a
Show file tree
Hide file tree
Showing 53 changed files with 380 additions and 183 deletions.
16 changes: 14 additions & 2 deletions config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,9 @@ notebook:
- gint: Gauss 整数
code_ext: hpp
test_ext: cpp
- eint: Eisenstein 整数
code_ext: hpp
test_ext: cpp
- mint: 模整数类基类
code_ext: hpp
test_ext: cpp
Expand Down Expand Up @@ -1056,7 +1059,10 @@ notebook:
code_ext: hpp
test_ext: cpp
conv:
- fft: 快速 Fourier 变换
- fft_r2: 快速 Fourier 变换(Radix 2)
code_ext: hpp
test_ext: cpp
- fft_r3: 快速 Fourier 变换(Radix 3)
code_ext: hpp
test_ext: cpp
- ntt: 数论变换
Expand All @@ -1068,9 +1074,15 @@ notebook:
- conv_naive: 卷积(暴力)
code_ext: hpp
test_ext: cpp
- conv_naive_mod: 卷积(暴力,mod)
code_ext: hpp
test_ext: cpp
- conv_dft: 卷积(FFT 或 NTT)
code_ext: hpp
test_ext: cpp
- conv_u64: 卷积(u64)
code_ext: hpp
test_ext: cpp
- conv_mtt: 卷积(MTT)
code_ext: hpp
test_ext: cpp
Expand All @@ -1086,7 +1098,7 @@ notebook:
- conv_czt: 卷积(Chirp-Z 变换)
code_ext: hpp
test_ext: cpp
- karatsuba: Karatsuba 乘法
- convcyc_naive: 循环卷积(暴力)
code_ext: hpp
test_ext: cpp
- wht: Walsh Hadamard 变换
Expand Down
7 changes: 4 additions & 3 deletions src/code/conv/conv_3ntt.hpp
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
#ifndef TIFALIBS_CONV_CONV_3NTT
#define TIFALIBS_CONV_CONV_3NTT

#include "../math/mul_mod.hpp"
#include "conv_dft.hpp"
#include "conv_naive_mod.hpp"
#include "ntt.hpp"

namespace tifa_libs::math {

// 167772161, 469762049, 754974721
template <class mint0, class mint1, class mint2>
CEXP vecuu conv_3ntt_u64(NTT<mint0> &ntt0, NTT<mint1> &ntt1, NTT<mint2> &ntt2, vecuu CR l, vecuu CR r, u64 mod, u32 ans_size = 0) NE {
if (!ans_size) ans_size = u32(l.size() + r.size() - 1);
if (min(l.size(), r.size()) < CONV_NAIVE_MOD_THRESHOLD) return conv_naive_mod(l, r, mod, ans_size);
CEXP u64 m0 = mint0::mod(), m1 = mint1::mod(), m2 = mint2::mod();
const u64 r01 = mint1(m0).inv().val(), r02 = mint2(m0).inv().val(), r12 = mint2(m1).inv().val(),
r02r12 = (u32)mul_mod_u(r02, r12, m2),
w1 = m0 % mod, w2 = mul_mod_u(m0, m1, mod);
if (!ans_size) ans_size = u32(l.size() + r.size() - 1);
const vec<mint0> d0 = conv_dft_um<NTT<mint0>, mint0>(ntt0, l, r, ans_size);
const vec<mint1> d1 = conv_dft_um<NTT<mint1>, mint1>(ntt1, l, r, ans_size);
const vec<mint2> d2 = conv_dft_um<NTT<mint2>, mint2>(ntt2, l, r, ans_size);
Expand All @@ -30,7 +31,7 @@ CEXP vecuu conv_3ntt_u64(NTT<mint0> &ntt0, NTT<mint1> &ntt1, NTT<mint2> &ntt2, v
template <class mint, class mint0, class mint1, class mint2>
CEXP vec<mint> conv_3ntt(NTT<mint0> &ntt0, NTT<mint1> &ntt1, NTT<mint2> &ntt2, vec<mint> CR l, vec<mint> CR r, u32 ans_size = 0) NE {
if (!ans_size) ans_size = u32(l.size() + r.size() - 1);
if (ans_size < 32) return conv_naive(l, r, ans_size);
if (min(l.size(), r.size()) < CONV_NAIVE_THRESHOLD) return conv_naive(l, r, ans_size);
vecuu l_(l.size()), r_(r.size());
flt_ (u32, i, 0, (u32)l.size()) l_[i] = l[i].val();
flt_ (u32, i, 0, (u32)r.size()) r_[i] = r[i].val();
Expand Down
2 changes: 1 addition & 1 deletion src/code/conv/conv_dft.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace tifa_libs::math {
template <dft_c DFT_t, std::same_as<TPN DFT_t::data_t> DFT_data_t>
CEXP vec<DFT_data_t> conv_dft(DFT_t &dft, vec<DFT_data_t> l, vec<DFT_data_t> r, u32 ans_size = 0) NE {
if (!ans_size) ans_size = u32(l.size() + r.size() - 1);
if (ans_size < 32) return conv_naive(l, r, ans_size);
if (min(l.size(), r.size()) < CONV_NAIVE_THRESHOLD) return conv_naive(l, r, ans_size);
dft.bzr(max({(u32)l.size(), (u32)r.size(), min(u32(l.size() + r.size() - 1), ans_size)}));
dft.dif(l), dft.dif(r);
flt_ (u32, i, 0, dft.size()) l[i] *= r[i];
Expand Down
8 changes: 4 additions & 4 deletions src/code/conv/conv_mtt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
#define TIFALIBS_CONV_CONV_MTT

#include "conv_naive.hpp"
#include "fft.hpp"
#include "fft_r2.hpp"

namespace tifa_libs::math {

template <class mint, class FP>
CEXP vec<mint> conv_mtt(FFT<FP> &fft, vec<mint> CR l, vec<mint> CR r, u32 ans_size = 0) NE {
using C = TPN FFT<FP>::C;
CEXP vec<mint> conv_mtt(FFT_R2<FP> &fft, vec<mint> CR l, vec<mint> CR r, u32 ans_size = 0) NE {
using C = TPN FFT_R2<FP>::data_t;
if (!ans_size) ans_size = u32(l.size() + r.size() - 1);
if (ans_size < 32) return conv_naive(l, r, ans_size);
if (min(l.size(), r.size()) < CONV_NAIVE_THRESHOLD) return conv_naive(l, r, ans_size);
if (l.size() == 1) {
vec<mint> ans = r;
for (ans.resize(ans_size); auto &i : ans) i *= l[0];
Expand Down
1 change: 1 addition & 0 deletions src/code/conv/conv_naive.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

namespace tifa_libs::math {

CEXP inline u32 CONV_NAIVE_THRESHOLD = 16;
template <class U, class T = U>
requires(sizeof(U) <= sizeof(T))
CEXP vec<T> conv_naive(vec<U> CR l, vec<U> CR r, u32 ans_size = 0) NE {
Expand Down
24 changes: 24 additions & 0 deletions src/code/conv/conv_naive_mod.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef TIFALIBS_CONV_CONV_NAIVE_MOD
#define TIFALIBS_CONV_CONV_NAIVE_MOD

#include "../math/mul_mod.hpp"

namespace tifa_libs::math {

// Size threshold used by callers: when the shorter operand has fewer elements
// than this, they dispatch to conv_naive_mod instead of a transform-based path.
CEXP inline u32 CONV_NAIVE_MOD_THRESHOLD = 16;
// Schoolbook convolution of l and r with every coefficient reduced modulo mod.
//
// l, r      operands; an empty result is returned if either one is empty
// mod       modulus applied to every product and every partial sum
// ans_size  number of leading coefficients to produce; 0 selects the full
//           length l.size() + r.size() - 1. Values larger than the full length
//           leave the extra trailing coefficients at 0.
//
// Returns a vector of ans_size coefficients, each kept in [0, mod).
CEXP vecuu conv_naive_mod(spnuu l, spnuu r, u64 mod, u32 ans_size = 0) NE {
  if (l.empty() || r.empty()) return {};
  if (!ans_size) ans_size = u32(l.size() + r.size() - 1);
  vecuu ans(ans_size);
  u32 n = (u32)l.size(), m = (u32)r.size();
  // Make l_ the longer operand so the inner loop runs over the shorter one.
  auto &&l_ = n < m ? r : l, &&r_ = n < m ? l : r;
  if (n < m) swap(n, m);
  // Clamp the outer bound to ans_size: for i >= ans_size no coefficient is
  // wanted, and the unsigned difference ans_size - i below would wrap around
  // and let the inner loop write past the end of ans.
  flt_ (u32, i, 0, min(n, ans_size)) {
    flt_ (u32, j, 0, min(m, ans_size - i)) {
      // Assumes mul_mod_u returns a value already reduced into [0, mod).
      const u64 t = mul_mod_u(l_[i], r_[j], mod);
      u64 &x = ans[i + j];
      // Modular add that stays correct even when x + t exceeds 2^64:
      // a wrapped sum is detected by x < t, and subtracting mod (mod 2^64)
      // then yields the true residue.
      x += t;
      if (x < t || x >= mod) x -= mod;
    }
  }
  return ans;
}

} // namespace tifa_libs::math

#endif
5 changes: 2 additions & 3 deletions src/code/conv/conv_u128.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ namespace tifa_libs::math {
// max = 167772161 * 469762049 * 754974721 \approx 5.95e25
template <class T>
vec<u128> conv_u128(vec<T> CR l, vec<T> CR r, u32 ans_size = 0) NE {
if (!ans_size) ans_size = u32(l.size() + r.size() - 1);
if (min(l.size(), r.size()) < CONV_NAIVE_THRESHOLD) return conv_naive<T, u128>(l, r, ans_size);
static CEXP u32 m0 = 167772161, m1 = 469762049, m2 = 754974721;
using mint0 = mint<mint_s30, m0>;
using mint1 = mint<mint_s30, m1>;
Expand All @@ -21,9 +23,6 @@ vec<u128> conv_u128(vec<T> CR l, vec<T> CR r, u32 ans_size = 0) NE {
r12 = inverse(m1, mint2::mod()),
r02r12 = (u64)r02 * r12 % m2;
static CEXP u64 w1 = m0, w2 = (u64)m0 * m1;
if (!ans_size) ans_size = u32(l.size() + r.size() - 1);
if (l.empty() && r.empty()) return {};
if (min(l.size(), r.size()) < 128) return conv_naive<T, u128>(l, r, ans_size);
static NTT<mint0> ntt0;
static NTT<mint1> ntt1;
static NTT<mint2> ntt2;
Expand Down
76 changes: 76 additions & 0 deletions src/code/conv/conv_u64.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#ifndef TIFALIBS_CONV_CONV_u64
#define TIFALIBS_CONV_CONV_u64

#include "conv_naive.hpp"
#include "fft_r3.hpp"

namespace tifa_libs::math {

// Convolution of a and b returned as vecuu.
//
// a, b      input sequences with element type T
// ans_size  number of coefficients to return; 0 selects n + m - 1
//
// Inputs whose shorter side is below CONV_NAIVE_THRESHOLD are delegated to
// conv_naive<T, u64>. Otherwise both inputs are packed into FFT_R3<T>::data_t
// values (first s coefficients in the real part, the next s in the imaginary
// part) and multiplied by the recursive lambda `mul` below, using a
// function-local static FFT_R3<T> instance.
// NOTE(review): the semantics of FFT_R3 (bzr/size/twiddle/dif/dit, EI::w,
// EI::w2) live in fft_r3.hpp, which is not visible here.
template <class T>
vecuu conv_u64(vec<T> CR a, vec<T> CR b, u32 ans_size = 0) NE {
const u32 n = (u32)a.size(), m = (u32)b.size();
// When both inputs are empty n + m - 1 wraps around, but the value is never
// used because of the early return on the next line.
if (!ans_size) ans_size = n + m - 1;
if (a.empty() && b.empty()) return {};
if (min(n, m) < CONV_NAIVE_THRESHOLD) return conv_naive<T, u64>(a, b, ans_size);
static FFT_R3<T> fft;
using EI = FFT_R3<T>::data_t;
// For an unsigned 64-bit T this evaluates to 0xAAAAAAAAAAAAAAAB, the
// multiplicative inverse of 3 modulo 2^64.
// NOTE(review): presumably T is meant to be an unsigned integer type — confirm.
CEXP static EI inv_3{-T(1) / 3 * 2 + 1, 0};
fft.bzr(n + m - 1);
u32 s = fft.size();
// Pack each input into s values: coefficients [0, s) go to the real part,
// coefficients [s, 2s) go to the imaginary part of the same slot.
vec<EI> pa(s), pb(s);
for (u32 i = 0; i < std::min(s, n); ++i) pa[i].real(a[i]);
for (u32 i = s; i < std::min(2 * s, n); ++i) pa[i - s].imag(a[i]);
for (u32 i = 0; i < std::min(s, m); ++i) pb[i].real(b[i]);
for (u32 i = s; i < std::min(2 * s, m); ++i) pb[i - s].imag(b[i]);
// Output plus scratch space: mul uses `to`, `to + n` and `to + 2 * n` as
// separate work areas.
vec<EI> pc(4 * s);
// Recursive length-n product of p and q written to to[0, n). The lambda takes
// itself as its first argument so that it can recurse. p and q are used as
// scratch and are overwritten.
auto mul = [](auto&& mul, EI* p, EI* q, EI* to, u32 n) {
// Base case: direct O(n^2) product; terms whose index wraps past n are
// multiplied by EI::w before being folded back.
if (n <= 27) {
std::fill_n(to, n, 0);
flt_ (u32, i, 0, n) {
flt_ (u32, j, 0, n - i) to[i + j] += p[i] * q[j];
flt_ (u32, j, n - i, n) to[i + j - n] += p[i] * q[j] * EI::w;
}
return;
}
// m: smallest power of 3 with m * m >= n; r = n / m blocks of length m.
u32 m = 1;
for (; m * m < n; m *= 3);
u32 r = n / m;
// inv = inv_3 raised to the number of factors of 3 in r; applied after each
// fft.dit call below.
EI inv{1};
for (u32 i = 1; i < r; i *= 3) inv *= inv_3;
// First pass: twiddle each block of p and q into to[0, n) and to[n, 2n),
// transform, multiply block-wise recursively, transform back, scale by inv
// and undo the twiddle into to[n, 2n).
flt_ (u32, i, 0, r) {
fft.twiddle(p + m * i, m, m / r * i, to + m * i);
fft.twiddle(q + m * i, m, m / r * i, to + n + m * i);
}
fft.dif(to, m, r), fft.dif(to + n, m, r);
flt_ (u32, i, 0, r) mul(mul, to + m * i, to + n + m * i, to + 2 * n + m * i, m);
fft.dit(to + 2 * n, m, r);
flt_ (u32, i, 0, n) to[2 * n + i] *= inv;
flt_ (u32, i, 0, r) fft.twiddle(to + 2 * n + m * i, m, 3 * m - m / r * i, to + n + m * i);
// Second pass: same pipeline on the conjugated inputs with doubled twiddle
// exponents; p is reused as scratch and the result ends up in q.
flt_ (u32, i, 0, r) {
flt_ (u32, j, 0, m) p[m * i + j] = conj(p[m * i + j]), q[m * i + j] = conj(q[m * i + j]);
fft.twiddle(p + m * i, m, 2 * m / r * i, to + m * i);
fft.twiddle(q + m * i, m, 2 * m / r * i, p + m * i);
}
fft.dif(to, m, r), fft.dif(p, m, r);
flt_ (u32, i, 0, r) mul(mul, to + m * i, p + m * i, to + 2 * n + m * i, m);
fft.dit(to + 2 * n, m, r);
flt_ (u32, i, 0, n) to[2 * n + i] *= inv;
flt_ (u32, i, 0, r) fft.twiddle(to + 2 * n + m * i, m, 3 * m - 2 * m / r * i, q + m * i);
// Combine the two partial results (to[n, 2n) and conj(q)) into to[0, n),
// then scale everything by inv_3.
std::fill_n(to, n, 0);
flt_ (u32, i, 0, n) {
to[i] += (1 - EI::w) * to[n + i] + (1 - EI::w2) * conj(q[i]);
if (i + m < n) to[i + m] += (EI::w2 - EI::w) * (to[n + i] - conj(q[i]));
else to[i + m - n] += (1 - EI::w2) * (to[n + i] - conj(q[i]));
}
flt_ (u32, i, 0, n) to[i] *= inv_3;
};
mul(mul, pa.data(), pb.data(), pc.data(), s);
// Unpack: real parts give coefficients [0, s), imaginary parts give [s, 2s).
// NOTE(review): ans is vec<T> while the declared return type is vecuu; this
// presumably only compiles when vec<T> and vecuu are the same type — confirm.
vec<T> ans(ans_size);
flt_ (u32, i, 0, std::min(s, ans_size)) ans[i] = pc[i].real();
flt_ (u32, i, s, std::min(2 * s, ans_size)) ans[i] = pc[i - s].imag();
return ans;
}

} // namespace tifa_libs::math

#endif
25 changes: 25 additions & 0 deletions src/code/conv/convcyc_naive.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#ifndef TIFALIBS_CONV_CONVCYC_NAIVE
#define TIFALIBS_CONV_CONVCYC_NAIVE

#include "../util/util.hpp"

namespace tifa_libs::math {

// Cyclic convolution of two equally sized sequences by direct O(n^2) summation.
//
// U  element type of the inputs
// T  element type of the result (at least as wide as U); every product is
//    formed after converting both factors to T
//
// Returns an empty vector if either input is empty. Otherwise asserts that the
// sizes match and returns ans with ans[k] = sum of l[i] * r[j] over all
// (i + j) mod n == k.
template <class U, class T = U>
requires(sizeof(U) <= sizeof(T))
CEXP vec<T> convcyc_naive(vec<U> CR l, vec<U> CR r) NE {
  if (l.empty() || r.empty()) return {};
  assert(l.size() == r.size());
  const u32 len = (u32)l.size();
  vec<T> res(len);
  for (u32 a = 0; a < len; ++a) {
    const T lhs = (T)l[a];
    for (u32 b = 0; b < len; ++b) {
      // Indices that run past the end wrap around to the front.
      const u32 pos = b < len - a ? a + b : a + b - len;
      res[pos] += lhs * (T)r[b];
    }
  }
  return res;
}

} // namespace tifa_libs::math

#endif
25 changes: 12 additions & 13 deletions src/code/conv/fft.hpp → src/code/conv/fft_r2.hpp
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
#ifndef TIFALIBS_CONV_FFT
#define TIFALIBS_CONV_FFT
#ifndef TIFALIBS_CONV_FFT_R2
#define TIFALIBS_CONV_FFT_R2

#include "../util/util.hpp"

namespace tifa_libs::math {

template <std::floating_point FP>
struct FFT {
class FFT_R2 {
using C = std::complex<FP>;
const FP TAU = std::acos((FP)-1.) * 2;
vecu rev;
vec<C> w;

public:
using data_t = C;

CEXPE FFT() NE : rev(), w() {}
CEXPE FFT_R2() NE : rev(), w() {}

CEXP u32 size() CNE { return (u32)rev.size(); }
CEXP void bzr(u32 len) NE {
Expand All @@ -22,7 +27,7 @@ struct FFT {
w.resize(n), w[0].real(1);
flt_ (u32, i, 1, n) w[i] = {std::cos(TAU * (FP)i / (FP)n), std::sin(TAU * (FP)i / (FP)n)};
}
CEXP void dif(vec<C> &f, u32 n = 0) CNE {
CEXP void dif(vec<data_t> &f, u32 n = 0) CNE {
if (!n) n = size();
if (f.size() < n) f.resize(n);
assert(n <= size());
Expand All @@ -34,23 +39,17 @@ struct FFT {
auto l = f.begin() + j, r = f.begin() + j + i / 2;
auto p = w.begin();
for (u32 k = 0; k < i / 2; ++k, ++l, ++r, p += d) {
const C _ = *r * *p;
const data_t _ = *r * *p;
*r = *l - _, *l = *l + _;
}
}
#pragma GCC diagnostic warning "-Wsign-conversion"
}
CEXP void dit(vec<C> &f, u32 n = 0) CNE {
CEXP void dit(vec<data_t> &f, u32 n = 0) CNE {
if (!n) n = size();
dif(f, n);
flt_ (u32, i, 0, n) f[i] /= (FP)n;
}

private:
const FP TAU = std::acos((FP)-1.) * 2;

vecu rev;
vec<C> w;
};

} // namespace tifa_libs::math
Expand Down
Loading

0 comments on commit 0a0587a

Please sign in to comment.