[CUB] Fully support warp redux #9441
Conversation
SummaryThis pull request refactors and extends warp-level reduction (redux) support in CUB by consolidating redux implementations into a dedicated header and adding comprehensive SM100a/F support for floating-point min/max operations. Key ChangesNew Dedicated Redux Header (
|
| Layer / File(s) | Summary |
|---|---|
warp_redux.cuh: SM80/SM100af dispatch and implementations cub/cub/warp/specializations/warp_redux.cuh |
New file defines compile-time support traits (is_warp_redux_op_supported_sm80, is_warp_redux_op_supported_sm100af, is_warp_redux_op_supported), implements warp_redux_sm80 using CUDA _reduce*_sync intrinsics with input type promotion, implements warp_redux_sm100af using inline PTX float min/max helpers guarded by PTX ISA, and exports warp_redux dispatcher returning optional with NV_IF_TARGET backend selection. |
look_ahead.cuh: integrate warp_redux_sm80 cub/cub/detail/warpspeed/look_ahead.cuh |
Adds warp_redux.cuh and is_same.h includes. Updates warpIncrementalLookahead SM80 path to dispatch via is_warp_redux_op_supported_sm80 and call warp_redux_sm80 when supported, falling back to existing warp_reduce_t path for unsupported cases. |
warp_reduce_batched_wspro.cuh: integrate generic warp_redux cub/cub/warp/specializations/warp_reduce_batched_wspro.cuh |
Adds warp_redux.cuh include. ReduceRedux changes from void to [[nodiscard]] bool. Reduce dispatch via is_warp_redux_op_supported with early return on ReduceRedux bool result. Replaces reduce_op_sync with warp_redux call; checks optional result before dereferencing and writing intermediate_outputs, returns false on failure or true on success. |
warp_reduce_shfl.cuh: replace reduce_op_sync with warp_redux cub/cub/warp/specializations/warp_reduce_shfl.cuh |
Adds warp_redux.cuh include, removes internal reduce_op_sync helper. Updates ALL_LANES_VALID fast-path to dispatch via is_warp_redux_op_supported and call warp_redux, returning optional result when available. |
thread_operators.cuh: remove is_redux_enabled_cuda_operator cub/cub/thread/thread_operators.cuh |
Removes is_redux_enabled_cuda_operator trait and TODO comment; functionality replaced by architecture-specific is_warp_redux_op_supported predicates. |
catch2_test_warp_reduce.cu: floating-point redux tests cub/test/warp/catch2_test_warp_reduce.cu |
Adds floating_point_redux_type_list containing float and conditionally __half and __nv_bfloat16. New test case WarpReduce::Max/Min floating-point redux types runs over new type list, uses verify_results_exact for float and verify_results for other types. |
Suggested reviewers
- elstehle
- shwina
- andralex
- miscco
- NaderAlAwar
Comment @coderabbitai help to get the list of available commands and usage tips.
There was a problem hiding this comment.
Actionable comments posted: 6
🧹 Nitpick comments (4)
cub/cub/warp/specializations/warp_redux.cuh (4)
1-2: 📐 Maintainability & Code Quality | 💤 Low valuesuggestion: SPDX license identifier should be
BSD-3-Clausefor full compliance.
65-73: 📐 Maintainability & Code Quality | 💤 Low valuesuggestion: Per coding guidelines, unmodified parameters and local variables should be
const.maskparameter andvaluelocal are not modified.template <typename T, typename ReductionOp> [[nodiscard]] _CCCL_DEVICE_API _CCCL_FORCEINLINE T -warp_redux_sm80(const T input, const ::cuda::std::uint32_t mask, ReductionOp) +warp_redux_sm80(const T input, const ::cuda::std::uint32_t mask, ReductionOp) noexcept { static_assert(is_warp_redux_op_supported_sm80<ReductionOp, T>, "Reduction operator not supported"); _CCCL_ASSERT(mask != 0, "Mask must not be 0"); - using promotion_t = ::cuda::std::conditional_t<::cuda::std::is_signed_v<T>, int, unsigned>; - const auto value = static_cast<promotion_t>(input); + using promotion_t = ::cuda::std::conditional_t<::cuda::std::is_signed_v<T>, int, unsigned>; + const promotion_t value = static_cast<promotion_t>(input);Source: Coding guidelines
110-121: 🚀 Performance & Scalability | 💤 Low valuesuggestion: The macro-generated PTX functions use
asm volatilebut the reduction is a pure read operation with no memory side effects. Considerasmwithoutvolatileto allow compiler optimization, or add a brief comment explaining whyvolatileis necessary (e.g., if required for synchronization semantics).
131-149: 📐 Maintainability & Code Quality | 💤 Low valuesuggestion: Per guidelines,
valueandresultlocals should beconstwhere possible.valueis not modified after initialization.template <typename T, typename ReductionOp> [[nodiscard]] _CCCL_DEVICE_API -_CCCL_FORCEINLINE T warp_redux_sm100af(const T input, const ::cuda::std::uint32_t mask, ReductionOp) +_CCCL_FORCEINLINE T warp_redux_sm100af(const T input, const ::cuda::std::uint32_t mask, ReductionOp) noexcept { static_assert(is_warp_redux_op_supported_sm100af<ReductionOp, T>, "Reduction operator not supported"); _CCCL_ASSERT(mask != 0, "Mask must not be 0"); const float value = ::cuda::std::__fp_cast<float>(input); - float result; + const float result = [&]() { if constexpr (is_cuda_minimum_v<ReductionOp, T>) { - result = cub::detail::redux_sm100af_min_ptx(value, mask); + return ::cub::detail::redux_sm100af_min_ptx(value, mask); } else { - result = cub::detail::redux_sm100af_max_ptx(value, mask); + return ::cub::detail::redux_sm100af_max_ptx(value, mask); } + }(); return ::cuda::std::__fp_cast<T>(result); }Alternatively, keep the current structure but add
constwhere applicable and fully qualifycub::detail::as::cub::detail::.Source: Coding guidelines
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 66046a44-cb32-43c6-8ce3-c6b5f927aa61
📒 Files selected for processing (6)
cub/cub/detail/warpspeed/look_ahead.cuhcub/cub/thread/thread_operators.cuhcub/cub/warp/specializations/warp_reduce_batched_wspro.cuhcub/cub/warp/specializations/warp_reduce_shfl.cuhcub/cub/warp/specializations/warp_redux.cuhcub/test/warp/catch2_test_warp_reduce.cu
💤 Files with no reviewable changes (1)
- cub/cub/thread/thread_operators.cuh
| if constexpr (is_warp_redux_op_supported<ReductionOp, T> && redux_performs_better) | ||
| { | ||
| NV_IF_TARGET(NV_PROVIDES_SM_80, ({ | ||
| ReduceRedux<ToBlocked>(inputs, outputs, reduction_op); | ||
| return; | ||
| })) | ||
| if (ReduceRedux<ToBlocked>(inputs, outputs, reduction_op)) | ||
| { | ||
| return; | ||
| } | ||
| } |
There was a problem hiding this comment.
📐 Maintainability & Code Quality | 🟠 Major | ⚡ Quick win
suggestion: Per coding guidelines, trait should be fully qualified: ::cub::detail::is_warp_redux_op_supported.
- if constexpr (is_warp_redux_op_supported<ReductionOp, T> && redux_performs_better)
+ if constexpr (::cub::detail::is_warp_redux_op_supported<ReductionOp, T> && redux_performs_better)Source: Coding guidelines
| const auto result = cub::detail::warp_redux(inputs[i], reduce_mask, reduction_op); | ||
| if (!result) | ||
| { | ||
| return false; | ||
| } | ||
| const auto out_lane = ToBlocked ? i / max_out_per_thread : i % LogicalWarpThreads; | ||
| const auto out_idx = ToBlocked ? i % max_out_per_thread : i / LogicalWarpThreads; | ||
| if (logical_lane_id == out_lane) | ||
| { | ||
| intermediate_outputs[out_idx] = result; | ||
| intermediate_outputs[out_idx] = *result; | ||
| } |
There was a problem hiding this comment.
📐 Maintainability & Code Quality | 🟠 Major | ⚡ Quick win
suggestion: Per coding guidelines, unmodified variables should be const. Also fully qualify the function call.
- const auto result = cub::detail::warp_redux(inputs[i], reduce_mask, reduction_op);
+ const auto result = ::cub::detail::warp_redux(inputs[i], reduce_mask, reduction_op);
if (!result)
{
return false;
}
- const auto out_lane = ToBlocked ? i / max_out_per_thread : i % LogicalWarpThreads;
- const auto out_idx = ToBlocked ? i % max_out_per_thread : i / LogicalWarpThreads;
+ const int out_lane = ToBlocked ? i / max_out_per_thread : i % LogicalWarpThreads;
+ const int out_idx = ToBlocked ? i % max_out_per_thread : i / LogicalWarpThreads;The out_lane and out_idx are already const auto, which is fine, but explicit int type might be clearer. The main issue is the missing :: prefix on the function call.
Source: Coding guidelines
| if constexpr (ALL_LANES_VALID && is_warp_redux_op_supported<ReductionOp, T>) | ||
| { | ||
| NV_IF_TARGET(NV_PROVIDES_SM_80, (return reduce_op_sync(input, member_mask, reduction_op);)) | ||
| if (const auto output = cub::detail::warp_redux(input, member_mask, reduction_op)) | ||
| { | ||
| return *output; | ||
| } | ||
| } |
There was a problem hiding this comment.
📐 Maintainability & Code Quality | 🟠 Major | ⚡ Quick win
suggestion: Per coding guidelines, trait and function call should be fully qualified with :: prefix.
- if constexpr (ALL_LANES_VALID && is_warp_redux_op_supported<ReductionOp, T>)
+ if constexpr (ALL_LANES_VALID && ::cub::detail::is_warp_redux_op_supported<ReductionOp, T>)
{
- if (const auto output = cub::detail::warp_redux(input, member_mask, reduction_op))
+ if (const auto output = ::cub::detail::warp_redux(input, member_mask, reduction_op))
{
return *output;
}
}Source: Coding guidelines
| template <typename T, typename ReductionOp> | ||
| [[nodiscard]] _CCCL_DEVICE_API _CCCL_FORCEINLINE constexpr ::cuda::std::optional<T> | ||
| warp_redux(const T input, const ::cuda::std::uint32_t mask, ReductionOp reduction_op) | ||
| { | ||
| static_assert(is_warp_redux_op_supported<ReductionOp, T>, "Reduction operator not supported"); | ||
| if constexpr (is_warp_redux_op_supported_sm80<ReductionOp, T>) | ||
| { | ||
| NV_IF_TARGET(NV_PROVIDES_SM_80, (return cub::detail::warp_redux_sm80(input, mask, reduction_op);)) | ||
| } | ||
| #if __cccl_ptx_isa >= 860 | ||
| else if constexpr (is_warp_redux_op_supported_sm100af<ReductionOp, T>) | ||
| { | ||
| NV_IF_TARGET(NV_HAS_FEATURE_SM_100a, (return cub::detail::warp_redux_sm100af(input, mask, reduction_op);)) | ||
| # if __cccl_ptx_isa >= 880 | ||
| NV_IF_TARGET(NV_HAS_FEATURE_SM_100f, (return cub::detail::warp_redux_sm100af(input, mask, reduction_op);)) | ||
| # endif // __cccl_ptx_isa >= 880 | ||
| } | ||
| #endif // __cccl_ptx_isa >= 860 | ||
| return ::cuda::std::nullopt; | ||
| } |
There was a problem hiding this comment.
📐 Maintainability & Code Quality | 🟠 Major | ⚡ Quick win
important: Per coding guidelines, all calls to free functions must be fully qualified starting from the global namespace. The calls to warp_redux_sm80 and warp_redux_sm100af inside warp_redux should use ::cub::detail:: prefix.
if constexpr (is_warp_redux_op_supported_sm80<ReductionOp, T>)
{
- NV_IF_TARGET(NV_PROVIDES_SM_80, (return cub::detail::warp_redux_sm80(input, mask, reduction_op);))
+ NV_IF_TARGET(NV_PROVIDES_SM_80, (return ::cub::detail::warp_redux_sm80(input, mask, reduction_op);))
}
`#if` __cccl_ptx_isa >= 860
else if constexpr (is_warp_redux_op_supported_sm100af<ReductionOp, T>)
{
- NV_IF_TARGET(NV_HAS_FEATURE_SM_100a, (return cub::detail::warp_redux_sm100af(input, mask, reduction_op);))
+ NV_IF_TARGET(NV_HAS_FEATURE_SM_100a, (return ::cub::detail::warp_redux_sm100af(input, mask, reduction_op);))
# if __cccl_ptx_isa >= 880
- NV_IF_TARGET(NV_HAS_FEATURE_SM_100f, (return cub::detail::warp_redux_sm100af(input, mask, reduction_op);))
+ NV_IF_TARGET(NV_HAS_FEATURE_SM_100f, (return ::cub::detail::warp_redux_sm100af(input, mask, reduction_op);))
# endif // __cccl_ptx_isa >= 880
}
`#endif` // __cccl_ptx_isa >= 860Source: Coding guidelines
| NV_IF_TARGET(NV_PROVIDES_SM_80, (return cub::detail::warp_redux_sm80(input, mask, reduction_op);)) | ||
| } | ||
| #if __cccl_ptx_isa >= 860 | ||
| else if constexpr (is_warp_redux_op_supported_sm100af<ReductionOp, T>) | ||
| { | ||
| NV_IF_TARGET(NV_HAS_FEATURE_SM_100a, (return cub::detail::warp_redux_sm100af(input, mask, reduction_op);)) | ||
| # if __cccl_ptx_isa >= 880 | ||
| NV_IF_TARGET(NV_HAS_FEATURE_SM_100f, (return cub::detail::warp_redux_sm100af(input, mask, reduction_op);)) |
There was a problem hiding this comment.
📐 Maintainability & Code Quality | 🟠 Major | ⚡ Quick win
suggestion: Multiple call sites across warp_redux.cuh, look_ahead.cuh, warp_reduce_batched_wspro.cuh, and warp_reduce_shfl.cuh reference cub::detail:: functions/traits without the leading :: global namespace qualifier. Per coding guidelines, all free function calls and trait references must be fully qualified starting from global namespace (e.g., ::cub::detail::warp_redux(...)). This ensures ADL doesn't cause unexpected behavior and maintains consistency across the codebase.
Source: Coding guidelines
| if constexpr (cuda::std::is_same_v<T, float>) | ||
| { | ||
| verify_results_exact(h_out, d_out); | ||
| } | ||
| else | ||
| { | ||
| verify_results(h_out, d_out); |
There was a problem hiding this comment.
🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
important: Line 297 limits exact validation to float, and Lines 301-303 use tolerance-based verify_results for __half/__nv_bfloat16. For min/max redux, this can miss wrong-extrema regressions because approximate checks may still pass. Use verify_results_exact(h_out, d_out) for all types in this redux-focused test so the selected extrema are validated strictly.
| else | ||
| { | ||
| _CCCL_UNREACHABLE(); | ||
| return T{}; |
There was a problem hiding this comment.
Should this be an optional as well? Its packed into one after return anyway
😬 CI Workflow Results🟥 Finished in 2h 32m: Pass: 85%/287 | Total: 10d 06h | Max: 2h 31m | Hits: 13%/1114786See results here. |
Description
I recently saw at least two PRs related to the warp reduction with specific dispatch for
redux. Considering that the current dispatch logic is quite messy and there are additional optimizations, I create this PR:reduxcode to a dedicated header.sm100a/fredux.sync.min/max.f32forfloat__halfand__nv_bfloat16min/max are also handled by promoted them tofloat.warp_reduxfunction returnscuda::std::optionalto greatly simplify the calling dispatch.reduxis valid.Todo: