Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 64 additions & 11 deletions include/boost/multi/adaptors/cufft.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include <array>
#include <cstddef>
#include <exception>
#include <map>
#include <stdexcept>
#include <tuple>
Expand Down Expand Up @@ -261,7 +262,7 @@ class plan {

std::sort(which_iodims_.begin() + first_howmany_, which_iodims_.begin() + D, [](auto const& alpha, auto const& omega) { return get<1>(alpha).n > get<1>(omega).n; });

if(first_howmany_ <= D - 1) {
if(first_howmany_ == D - 1) {
if constexpr(std::is_same_v<Alloc, void*>) { // NOLINT(bugprone-branch-clone) workaround bug in DeepSource
cufftSafeCall(::cufftPlanMany(
/*cufftHandle *plan*/ &h_,
Expand Down Expand Up @@ -303,12 +304,69 @@ class plan {
++first_howmany_;
return;
}

if(first_howmany_ <= D - 2) {

int nstreams = which_iodims_[first_howmany_].second.n;
std::vector<cudaStream_t> streams(nstreams);
for(auto& s : streams) {
cudaStreamCreate(&s) == cudaSuccess ?0:throw std::runtime_error{"Failed to create CUDA stream"};
}
std::vector<cufftHandle> plans(nstreams);

std::vector<::size_t> worksizes(nstreams);
std::vector<void*> workareas(nstreams);

for(int idx = 0; idx != nstreams; ++idx) {
if constexpr(std::is_same_v<Alloc, void*>) { // NOLINT(bugprone-branch-clone) workaround bug in DeepSource
std::terminate();
cufftSafeCall(::cufftPlanMany(
/*cufftHandle *plan*/ &plans[idx],
/*int rank*/ dims_end - dims.begin(),
/*int *n*/ ion.data(),
/*int *inembed*/ inembed.data(),
/*int istride*/ istride,
/*int idist*/ which_iodims_[first_howmany_].second.is,
/*int *onembed*/ onembed.data(),
/*int ostride*/ ostride,
/*int odist*/ which_iodims_[first_howmany_].second.os,
/*cufftType type*/ CUFFT_Z2Z,
/*int batch*/ which_iodims_[first_howmany_].second.n
));
} else {
std::terminate();
cufftSafeCall(cufftCreate(&plans[idx]));
cufftSafeCall(cufftSetAutoAllocation(plans[idx], false));
cufftSafeCall(cufftMakePlanMany(
/*cufftHandle *plan*/ plans[idx],
/*int rank*/ dims_end - dims.begin(),
/*int *n*/ ion.data(),
/*int *inembed*/ inembed.data(),
/*int istride*/ istride,
/*int idist*/ which_iodims_[first_howmany_].second.is,
/*int *onembed*/ onembed.data(),
/*int ostride*/ ostride,
/*int odist*/ which_iodims_[first_howmany_].second.os,
/*cufftType type*/ CUFFT_Z2Z,
/*int batch*/ which_iodims_[first_howmany_].second.n,
/*size_t **/ &workSize_
));
cufftSafeCall(cufftGetSize(plans[idx], &worksizes[idx]));
workareas[idx] = ::thrust::raw_pointer_cast(alloc_.allocate(worksizes[idx]));
cufftSafeCall(cufftSetWorkArea(plans[idx], workareas[idx]));
}
if(!plans[idx]) { throw std::runtime_error{"cufftPlanMany null"}; }
}
++first_howmany_;
return;
}

// throw std::runtime_error{"cufft not implemented yet"};
}

private:
template<typename = void>
void ExecZ2Z_(complex_type const* idata, complex_type* odata, int direction) const {
void ExecZ2Z_(complex_type const* idata, complex_type* odata, int direction) {
// used_ = true;
cufftSafeCall(cufftExecZ2Z(h_, const_cast<complex_type*>(idata), odata, direction)); // NOLINT(cppcoreguidelines-pro-type-const-cast) wrap legacy interface
// cudaDeviceSynchronize();
Expand Down Expand Up @@ -353,6 +411,7 @@ class plan {

for(int idx = 0; idx != which_iodims_[first_howmany_].second.n; ++idx) { // NOLINT(altera-unroll-loops,altera-unroll-loops,altera-id-dependent-backward-branch) TODO(correaa) use an algorithm
for(int jdx = 0; jdx != which_iodims_[first_howmany_ + 1].second.n; ++jdx) { // NOLINT(altera-unroll-loops,altera-unroll-loops,altera-id-dependent-backward-branch) TODO(correaa) use an algorithm
throw std::runtime_error{"complicated loop"};
cufftExecZ2Z(
h_,
const_cast<complex_type*>(reinterpret_cast<complex_type const*>(::thrust::raw_pointer_cast(idata + idx * which_iodims_[first_howmany_].second.is + jdx * which_iodims_[first_howmany_ + 1].second.is))), // NOLINT(cppcoreguidelines-pro-type-const-cast,cppcoreguidelines-pro-type-reinterpret-cast) legacy interface
Expand All @@ -366,17 +425,11 @@ class plan {
throw std::runtime_error{"error2"};
}

template<class IPtr, class OPtr>
void execute_forward(IPtr idata, OPtr odata) { // TODO(correaa) make const
execute(idata, odata, cufft::forward);
}
template<class IPtr, class OPtr>
void execute_backward(IPtr idata, OPtr odata) { // TODO(correaa) make const
execute(idata, odata, cufft::backward);
}
template<class IPtr, class OPtr> void execute_forward(IPtr idata, OPtr odata) { execute(idata, odata, cufft::forward); }
template<class IPtr, class OPtr> void execute_backward(IPtr idata, OPtr odata) { execute(idata, odata, cufft::backward); }

template<class IPtr, class OPtr>
void operator()(IPtr idata, OPtr odata, int direction) const {
void operator()(IPtr idata, OPtr odata, int direction) {
// used_ = true;
ExecZ2Z_(reinterpret_cast<complex_type const*>(::thrust::raw_pointer_cast(idata)), reinterpret_cast<complex_type*>(::thrust::raw_pointer_cast(odata)), direction); // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast) legacy interface
}
Expand Down