diff --git a/include/boost/multi/adaptors/cufft.hpp b/include/boost/multi/adaptors/cufft.hpp index b3ac00327..0b765e74b 100644 --- a/include/boost/multi/adaptors/cufft.hpp +++ b/include/boost/multi/adaptors/cufft.hpp @@ -13,6 +13,7 @@ #include #include +#include #include #include #include @@ -261,7 +262,7 @@ class plan { std::sort(which_iodims_.begin() + first_howmany_, which_iodims_.begin() + D, [](auto const& alpha, auto const& omega) { return get<1>(alpha).n > get<1>(omega).n; }); - if(first_howmany_ <= D - 1) { + if(first_howmany_ == D - 1) { if constexpr(std::is_same_v) { // NOLINT(bugprone-branch-clone) workaround bug in DeepSource cufftSafeCall(::cufftPlanMany( /*cufftHandle *plan*/ &h_, @@ -303,12 +304,69 @@ class plan { ++first_howmany_; return; } + + if(first_howmany_ <= D - 2) { + + int nstreams = which_iodims_[first_howmany_].second.n; + std::vector streams(nstreams); + for(auto& s : streams) { + cudaStreamCreate(&s) == cudaSuccess ?0:throw std::runtime_error{"Failed to create CUDA stream"}; + } + std::vector plans(nstreams); + + std::vector<::size_t> worksizes(nstreams); + std::vector workareas(nstreams); + + for(int idx = 0; idx != nstreams; ++idx) { + if constexpr(std::is_same_v) { // NOLINT(bugprone-branch-clone) workaround bug in DeepSource + std::terminate(); + cufftSafeCall(::cufftPlanMany( + /*cufftHandle *plan*/ &plans[idx], + /*int rank*/ dims_end - dims.begin(), + /*int *n*/ ion.data(), + /*int *inembed*/ inembed.data(), + /*int istride*/ istride, + /*int idist*/ which_iodims_[first_howmany_].second.is, + /*int *onembed*/ onembed.data(), + /*int ostride*/ ostride, + /*int odist*/ which_iodims_[first_howmany_].second.os, + /*cufftType type*/ CUFFT_Z2Z, + /*int batch*/ which_iodims_[first_howmany_].second.n + )); + } else { + std::terminate(); + cufftSafeCall(cufftCreate(&plans[idx])); + cufftSafeCall(cufftSetAutoAllocation(plans[idx], false)); + cufftSafeCall(cufftMakePlanMany( + /*cufftHandle *plan*/ plans[idx], + /*int rank*/ dims_end - dims.begin(), + /*int *n*/ ion.data(), + /*int *inembed*/ inembed.data(), + /*int istride*/ istride, + /*int idist*/ which_iodims_[first_howmany_].second.is, + /*int *onembed*/ onembed.data(), + /*int ostride*/ ostride, + /*int odist*/ which_iodims_[first_howmany_].second.os, + /*cufftType type*/ CUFFT_Z2Z, + /*int batch*/ which_iodims_[first_howmany_].second.n, + /*size_t **/ &workSize_ + )); + cufftSafeCall(cufftGetSize(plans[idx], &worksizes[idx])); + workareas[idx] = ::thrust::raw_pointer_cast(alloc_.allocate(worksizes[idx])); + cufftSafeCall(cufftSetWorkArea(plans[idx], workareas[idx])); + } + if(!plans[idx]) { throw std::runtime_error{"cufftPlanMany null"}; } + } + ++first_howmany_; + return; + } + // throw std::runtime_error{"cufft not implemented yet"}; } private: template - void ExecZ2Z_(complex_type const* idata, complex_type* odata, int direction) const { + void ExecZ2Z_(complex_type const* idata, complex_type* odata, int direction) { // used_ = true; cufftSafeCall(cufftExecZ2Z(h_, const_cast(idata), odata, direction)); // NOLINT(cppcoreguidelines-pro-type-const-cast) wrap legacy interface // cudaDeviceSynchronize(); @@ -353,6 +411,7 @@ class plan { for(int idx = 0; idx != which_iodims_[first_howmany_].second.n; ++idx) { // NOLINT(altera-unroll-loops,altera-unroll-loops,altera-id-dependent-backward-branch) TODO(correaa) use an algorithm for(int jdx = 0; jdx != which_iodims_[first_howmany_ + 1].second.n; ++jdx) { // NOLINT(altera-unroll-loops,altera-unroll-loops,altera-id-dependent-backward-branch) TODO(correaa) use an algorithm + throw std::runtime_error{"complicated loop"}; cufftExecZ2Z( h_, const_cast(reinterpret_cast(::thrust::raw_pointer_cast(idata + idx * which_iodims_[first_howmany_].second.is + jdx * which_iodims_[first_howmany_ + 1].second.is))), // NOLINT(cppcoreguidelines-pro-type-const-cast,cppcoreguidelines-pro-type-reinterpret-cast) legacy interface @@ -366,17 +425,11 @@ class plan { throw std::runtime_error{"error2"}; } - template - void execute_forward(IPtr idata, OPtr odata) { // TODO(correaa) make const - execute(idata, odata, cufft::forward); - } - template - void execute_backward(IPtr idata, OPtr odata) { // TODO(correaa) make const - execute(idata, odata, cufft::backward); - } + template void execute_forward(IPtr idata, OPtr odata) { execute(idata, odata, cufft::forward); } + template void execute_backward(IPtr idata, OPtr odata) { execute(idata, odata, cufft::backward); } template - void operator()(IPtr idata, OPtr odata, int direction) const { + void operator()(IPtr idata, OPtr odata, int direction) { // used_ = true; ExecZ2Z_(reinterpret_cast(::thrust::raw_pointer_cast(idata)), reinterpret_cast(::thrust::raw_pointer_cast(odata)), direction); // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast) legacy interface }