Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 50 additions & 100 deletions include/RAJA/policy/openmp_target/forall.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,75 +39,21 @@ template<size_t ThreadsPerTeam,
typename ForallParam>
RAJA_INLINE concepts::enable_if_t<
resources::EventProxy<resources::Omp>,
RAJA::expt::type_traits::is_ForallParamPack<ForallParam>,
concepts::negate<
RAJA::expt::type_traits::is_ForallParamPack_empty<ForallParam>>>
RAJA::expt::type_traits::is_ForallParamPack<ForallParam>>
forall_impl(resources::Omp omp_res,
const omp_target_parallel_for_exec<ThreadsPerTeam>& p,
Iterable&& iter,
Func&& loop_body,
ForallParam f_params)
{
using EXEC_POL = camp::decay<decltype(p)>;

RAJA::expt::ParamMultiplexer::parampack_init(p, f_params);
RAJA_OMP_DECLARE_REDUCTION_COMBINE;

using Body = typename std::remove_reference<decltype(loop_body)>::type;
Body body = loop_body;

RAJA_EXTRACT_BED_IT(iter);

// Reset if this exceeds the CUDA threads-per-block limit.
int tperteam = ThreadsPerTeam;
if (tperteam > omp::MAXNUMTHREADS)
{
tperteam = omp::MAXNUMTHREADS;
}

// calculate number of teams based on user defined threads per team
// datasize is distance between begin() and end() of iterable
auto numteams = RAJA_DIVIDE_CEILING_INT(distance_it, tperteam);
if (numteams > tperteam)
{
// Omp target reducers will write team # results into a Threads-sized array.
// Need to ensure NumTeams <= Threads to prevent out-of-bounds array access.
numteams = tperteam;
}

// thread_limit(tperteam) unused due to XL seg fault (when tperteam !=
// distance)
auto i = distance_it;

#pragma omp target teams distribute parallel for num_teams(numteams) \
schedule(static, 1) map(to \
: body, begin_it) reduction(combine \
: f_params)
for (i = 0; i < distance_it; ++i)
constexpr bool is_forall_param_empty =
RAJA::expt::type_traits::is_ForallParamPack_empty<ForallParam>::value;
if constexpr (!is_forall_param_empty)
{
Body ib = body;
RAJA::expt::invoke_body(f_params, ib, begin_it[i]);
RAJA::expt::ParamMultiplexer::parampack_init(p, f_params);
}

RAJA::expt::ParamMultiplexer::parampack_resolve(p, f_params);

return resources::EventProxy<resources::Omp>(omp_res);
}

template<size_t ThreadsPerTeam,
typename Iterable,
typename Func,
typename ForallParam>
RAJA_INLINE concepts::enable_if_t<
resources::EventProxy<resources::Omp>,
RAJA::expt::type_traits::is_ForallParamPack<ForallParam>,
RAJA::expt::type_traits::is_ForallParamPack_empty<ForallParam>>
forall_impl(resources::Omp omp_res,
const omp_target_parallel_for_exec<ThreadsPerTeam>&,
Iterable&& iter,
Func&& loop_body,
ForallParam)
{
using Body = typename std::remove_reference<decltype(loop_body)>::type;
Body body = loop_body;

Expand All @@ -133,14 +79,31 @@ forall_impl(resources::Omp omp_res,
// thread_limit(tperteam) unused due to XL seg fault (when tperteam !=
// distance)
auto i = distance_it;

if constexpr (is_forall_param_empty)
{
#pragma omp target teams distribute parallel for num_teams(numteams) \
schedule(static, 1) map(to \
: body, begin_it)
for (i = 0; i < distance_it; ++i)
for (i = 0; i < distance_it; ++i)
{
Body ib = body;
ib(begin_it[i]);
}
}
else
{
Body ib = body;
ib(begin_it[i]);
RAJA_OMP_DECLARE_REDUCTION_COMBINE
#pragma omp target teams distribute parallel for num_teams(numteams) \
schedule(static, 1) map(to \
: body, begin_it) reduction(combine \
: f_params)
for (i = 0; i < distance_it; ++i)
{
Body ib = body;
RAJA::expt::invoke_body(f_params, ib, begin_it[i]);
}

RAJA::expt::ParamMultiplexer::parampack_resolve(p, f_params);
}

return resources::EventProxy<resources::Omp>(omp_res);
Expand All @@ -149,63 +112,50 @@ forall_impl(resources::Omp omp_res,
template<typename Iterable, typename Func, typename ForallParam>
RAJA_INLINE concepts::enable_if_t<
resources::EventProxy<resources::Omp>,
RAJA::expt::type_traits::is_ForallParamPack<ForallParam>,
concepts::negate<
RAJA::expt::type_traits::is_ForallParamPack_empty<ForallParam>>>
RAJA::expt::type_traits::is_ForallParamPack<ForallParam>>
forall_impl(resources::Omp omp_res,
const omp_target_parallel_for_exec_nt& p,
Iterable&& iter,
Func&& loop_body,
ForallParam f_params)
{
using EXEC_POL = camp::decay<decltype(p)>;

RAJA::expt::ParamMultiplexer::parampack_init(p, f_params);
RAJA_OMP_DECLARE_REDUCTION_COMBINE;
constexpr bool is_forall_param_empty =
RAJA::expt::type_traits::is_ForallParamPack_empty<ForallParam>::value;
if constexpr (!is_forall_param_empty)
{
RAJA::expt::ParamMultiplexer::parampack_init(p, f_params);
}

using Body = typename std::remove_reference<decltype(loop_body)>::type;
Body body = loop_body;

RAJA_EXTRACT_BED_IT(iter);

if constexpr (!is_forall_param_empty)
{
RAJA_OMP_DECLARE_REDUCTION_COMBINE;
#pragma omp target teams distribute parallel for schedule(static, 1) \
firstprivate(body, begin_it) reduction(combine \
: f_params)
for (decltype(distance_it) i = 0; i < distance_it; ++i)
{
Body ib = body;
RAJA::expt::invoke_body(f_params, ib, begin_it[i]);
}

RAJA::expt::ParamMultiplexer::parampack_resolve(p, f_params);

return resources::EventProxy<resources::Omp>(omp_res);
}

template<typename Iterable, typename Func, typename ForallParam>
RAJA_INLINE concepts::enable_if_t<
resources::EventProxy<resources::Omp>,
RAJA::expt::type_traits::is_ForallParamPack<ForallParam>,
RAJA::expt::type_traits::is_ForallParamPack_empty<ForallParam>>
forall_impl(resources::Omp omp_res,
const omp_target_parallel_for_exec_nt&,
Iterable&& iter,
Func&& loop_body,
ForallParam)
{
using Body = typename std::remove_reference<decltype(loop_body)>::type;
Body body = loop_body;

RAJA_EXTRACT_BED_IT(iter);
for (decltype(distance_it) i = 0; i < distance_it; ++i)
{
Body ib = body;
RAJA::expt::invoke_body(f_params, ib, begin_it[i]);
}

RAJA::expt::ParamMultiplexer::parampack_resolve(p, f_params);
}
else
{
#pragma omp target teams distribute parallel for schedule(static, 1) \
firstprivate(body, begin_it)
for (decltype(distance_it) i = 0; i < distance_it; ++i)
{
Body ib = body;
ib(begin_it[i]);
for (decltype(distance_it) i = 0; i < distance_it; ++i)
{
Body ib = body;
ib(begin_it[i]);
}
}

return resources::EventProxy<resources::Omp>(omp_res);
}

Expand Down
52 changes: 19 additions & 33 deletions include/RAJA/policy/sequential/launch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,57 +39,43 @@ struct LaunchExecute<RAJA::null_launch_t>
template<>
struct LaunchExecute<RAJA::seq_launch_t>
{

// Sequential launch overload constrained on ReduceParams being a
// ForallParamPack that is also reported empty by is_ForallParamPack_empty.
//
// res    - resource handle; wrapped in the returned EventProxy.
// params - launch parameters; only shared_mem_size is read here.
// body   - callable invoked exactly once with the LaunchContext.
// The trailing ReduceParams& argument is not referenced in the body.
//
// Returns an EventProxy constructed from res.
template<typename BODY, typename ReduceParams>
static concepts::enable_if_t<
    resources::EventProxy<resources::Resource>,
    RAJA::expt::type_traits::is_ForallParamPack<ReduceParams>,
    RAJA::expt::type_traits::is_ForallParamPack_empty<ReduceParams>>
exec(RAJA::resources::Resource res,
     LaunchParams const& params,
     BODY const& body,
     ReduceParams& RAJA_UNUSED_ARG(ReduceParams))
{
  LaunchContext launch_ctx;

  // Scratch buffer of shared_mem_size bytes that the body reaches through
  // the context's shared_mem_ptr.
  char* scratch_mem = new char[params.shared_mem_size];
  launch_ctx.shared_mem_ptr = scratch_mem;

  body(launch_ctx);

  // Free the scratch buffer, then clear the context's pointer so it no
  // longer refers to freed storage.
  delete[] scratch_mem;
  launch_ctx.shared_mem_ptr = nullptr;

  return resources::EventProxy<resources::Resource>(res);
}

template<typename BODY, typename ReduceParams>
static concepts::enable_if_t<
resources::EventProxy<resources::Resource>,
RAJA::expt::type_traits::is_ForallParamPack<ReduceParams>,
concepts::negate<
RAJA::expt::type_traits::is_ForallParamPack_empty<ReduceParams>>>
RAJA::expt::type_traits::is_ForallParamPack<ReduceParams>>
exec(RAJA::resources::Resource res,
LaunchParams const& launch_params,
BODY const& body,
ReduceParams& launch_reducers)
{
using EXEC_POL = RAJA::seq_exec;
EXEC_POL pol {};

expt::ParamMultiplexer::parampack_init(pol, launch_reducers);
constexpr bool is_parampack_empty =
RAJA::expt::type_traits::is_ForallParamPack_empty<ReduceParams>::value;
if constexpr (!is_parampack_empty)
{
expt::ParamMultiplexer::parampack_init(pol, launch_reducers);
}

LaunchContext ctx;
char* kernel_local_mem = new char[launch_params.shared_mem_size];
ctx.shared_mem_ptr = kernel_local_mem;

expt::invoke_body(launch_reducers, body, ctx);
if constexpr (!is_parampack_empty)
{
expt::invoke_body(launch_reducers, body, ctx);
}
else
{
body(ctx);
}

delete[] kernel_local_mem;
ctx.shared_mem_ptr = nullptr;

expt::ParamMultiplexer::parampack_resolve(pol, launch_reducers);
if constexpr (!is_parampack_empty)
{
expt::ParamMultiplexer::parampack_resolve(pol, launch_reducers);
}

return resources::EventProxy<resources::Resource>(res);
}
Expand Down
Loading