[CUB] Fully support warp redux #9441

Open
fbusato wants to merge 2 commits into NVIDIA:main from fbusato:redux-fp-max-min

@fbusato fbusato commented Jun 12, 2026

Contributor

Description

I recently saw at least two PRs related to warp reduction with specific dispatch for redux. Since the current dispatch logic is quite messy and additional optimizations are available, I created this PR:

  • Moved warp redux code to a dedicated header.
  • Added sm100a/f redux.sync.min/max.f32 for float
  • __half and __nv_bfloat16 min/max are also handled by promoting them to float.
  • The new warp_redux function returns cuda::std::optional to greatly simplify the calling dispatch.
  • Provided specific traits to understand when redux is valid.
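
To illustrate why an optional return value simplifies the calling dispatch, here is a minimal host-side sketch. It is not the code in this PR: `try_fast_max`, `fallback_max`, and `reduce_max` are hypothetical names, `std::optional` stands in for `cuda::std::optional`, and the `fast_path_available` flag is an assumed stand-in for the architecture check.

```cpp
#include <algorithm>
#include <cassert>
#include <optional>
#include <vector>

// Hypothetical stand-in for a hardware fast path: it yields a value only
// when the (simulated) target supports it, and std::nullopt otherwise.
// Precondition: values is non-empty.
inline std::optional<int> try_fast_max(const std::vector<int>& values, const bool fast_path_available)
{
  if (!fast_path_available)
  {
    return std::nullopt;
  }
  return *std::max_element(values.begin(), values.end());
}

// Hypothetical generic fallback, standing in for a shuffle-based reduction.
// Precondition: values is non-empty.
inline int fallback_max(const std::vector<int>& values)
{
  int result = values.front();
  for (const int v : values)
  {
    result = std::max(result, v);
  }
  return result;
}

// The caller needs a single check: use the fast result if present,
// otherwise fall through to the generic path.
inline int reduce_max(const std::vector<int>& values, const bool fast_path_available)
{
  if (const auto fast = try_fast_max(values, fast_path_available))
  {
    return *fast;
  }
  return fallback_max(values);
}
```

With this shape the caller does not need to know which backend, if any, produced the value; an empty optional simply means "use the fallback".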

Todo:

  • sass comparison/benchmarks

@fbusato fbusato self-assigned this Jun 12, 2026
@fbusato fbusato requested a review from a team as a code owner June 12, 2026 23:45
@fbusato fbusato added the cub For all items related to CUB label Jun 12, 2026
@fbusato fbusato added this to CCCL Jun 12, 2026
@fbusato fbusato requested a review from elstehle June 12, 2026 23:45
@github-project-automation github-project-automation Bot moved this to Todo in CCCL Jun 12, 2026
@cccl-authenticator-app cccl-authenticator-app Bot moved this from Todo to In Review in CCCL Jun 12, 2026

coderabbitai Bot commented Jun 12, 2026

Contributor

Note: CodeRabbit is enabled on this repository as a convenience for maintainers
and contributors. Use your best judgment when considering its review comments and
suggestions — a suggested change may be inadequate, unnecessary, or safe to ignore.
Contributors are not expected to address every comment. Human reviews are what
ultimately matter for merging.

Summary

This pull request refactors and extends warp-level reduction (redux) support in CUB by consolidating redux implementations into a dedicated header and adding comprehensive SM100a/F support for floating-point min/max operations.

Key Changes

New Dedicated Redux Header (warp_redux.cuh)

  • Introduces a centralized cub::detail::warp_redux dispatcher that returns ::cuda::std::optional<T>, simplifying optional reduction semantics across the codebase
  • Implements SM80 backend via warp_redux_sm80, which applies CUDA __reduce_*_sync primitives after promoting inputs to appropriate integer types
  • Adds SM100a/F backend support via warp_redux_sm100af with custom PTX-based float min/max helpers (redux_sm100af_min_ptx and redux_sm100af_max_ptx)
  • Provides compile-time trait predicates:
    • is_warp_redux_op_supported_sm80<Op, T, ReduceOp> for SM80 capability detection
    • is_warp_redux_op_supported_sm100af<Op, T, ReduceOp> for SM100a/F capability detection
    • is_warp_redux_op_supported<Op, T> for general redux support
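
As a rough illustration of how such compile-time predicates can gate a code path, the following host-side sketch uses variable templates and `if constexpr`. The names (`min_op`, `max_op`, `is_integer_redux_supported`, `is_float_redux_supported`, `is_any_redux_supported`, `pick_path`) and the exact support rules are assumptions made for the example and are not taken from the PR.

```cpp
#include <cassert>
#include <functional>
#include <string_view>
#include <type_traits>

// Hypothetical operator tags standing in for the real min/max functors.
struct min_op {};
struct max_op {};

// Assumed rule for illustration: the integer backend accepts integral types
// of at most 32 bits for plus, min, and max.
template <typename Op, typename T>
inline constexpr bool is_integer_redux_supported =
  std::is_integral_v<T> && sizeof(T) <= 4
  && (std::is_same_v<Op, min_op> || std::is_same_v<Op, max_op> || std::is_same_v<Op, std::plus<>>);

// Assumed rule for illustration: the float backend only covers min/max on float.
template <typename Op, typename T>
inline constexpr bool is_float_redux_supported =
  std::is_same_v<T, float> && (std::is_same_v<Op, min_op> || std::is_same_v<Op, max_op>);

// General predicate: true if any backend could handle the pair.
template <typename Op, typename T>
inline constexpr bool is_any_redux_supported =
  is_integer_redux_supported<Op, T> || is_float_redux_supported<Op, T>;

// Callers branch at compile time; the unsupported branch is discarded.
template <typename Op, typename T>
constexpr std::string_view pick_path()
{
  if constexpr (is_any_redux_supported<Op, T>)
  {
    return "redux";
  }
  else
  {
    return "shuffle";
  }
}
```

Splitting the predicates per backend keeps each rule small, while the combined predicate gives call sites one condition to test.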

SM100a/F Support for Floating-Point Min/Max

  • Adds native PTX-based implementations for float min/max reductions on SM100a/F architectures
  • Implements type promotion for __half and __nv_bfloat16 to float for reduction operations, allowing these types to leverage native hardware min/max support
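
The promotion approach works for min/max because widening a 16-bit float to `float` is lossless and min/max return one of their operands unchanged, so narrowing the result back is exact. The host-side sketch below demonstrates this for a bfloat16-like layout. It assumes IEEE-754 binary32 `float`, and `bf16_bits`, `promote`, `demote`, and `max_via_float` are hypothetical names, not the CUDA types or the PR's helpers.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical 16-bit storage type standing in for __nv_bfloat16: it holds
// the upper 16 bits of an IEEE-754 binary32 value.
struct bf16_bits
{
  std::uint16_t bits;
};

// Widening is lossless: the 16 stored bits become the high half of a float.
inline float promote(const bf16_bits v)
{
  const std::uint32_t wide = static_cast<std::uint32_t>(v.bits) << 16;
  float out;
  std::memcpy(&out, &wide, sizeof(out));
  return out;
}

// Narrowing by truncation; exact whenever the low 16 bits are zero, which
// holds for any float produced by promote().
inline bf16_bits demote(const float v)
{
  std::uint32_t wide;
  std::memcpy(&wide, &v, sizeof(wide));
  return bf16_bits{static_cast<std::uint16_t>(wide >> 16)};
}

// Max computed in float, then narrowed back. Because max returns one of its
// operands unchanged, the result is bit-identical to one of the inputs.
inline bf16_bits max_via_float(const bf16_bits a, const bf16_bits b)
{
  const float fa = promote(a);
  const float fb = promote(b);
  return demote(fa < fb ? fb : fa);
}
```

The same argument does not carry over to operations such as addition, where the float result generally needs rounding when narrowed.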

Integration Across Warp Reduction Modules

  • warp_reduce_shfl.cuh: Fast-path reduction now conditionally dispatches to warp_redux for supported operator/type pairs, with fallback to existing shuffle-based implementation
  • warp_reduce_batched_wspro.cuh: Updated ReduceRedux function to use warp_redux dispatcher; function signature changed from void to [[nodiscard]] bool to handle optional result checking
  • look_ahead.cuh: SM80-specific incremental lookahead reduction now uses warp_redux_sm80 when supported, with fallback to traditional warp reduce

Trait Cleanup

  • Removed is_redux_enabled_cuda_operator trait from cub/cub/thread/thread_operators.cuh in favor of the more granular backend-specific traits in the new header

Testing

  • Added comprehensive test coverage in catch2_test_warp_reduce.cu with a new floating_point_redux_type_list containing float, __half, and __nv_bfloat16
  • New test case WarpReduce::Max/Min, floating-point redux types validates floating-point redux functionality with exact verification for float and standard verification for promoted types

Impact

This refactoring consolidates redux logic, reduces code duplication across warp reduction modules, and enables efficient hardware-native floating-point min/max operations on SM100a/F architectures while maintaining backward compatibility through fallback mechanisms.

Walkthrough

Warp-level REDUX reduction dispatch for SM80 and SM100af architectures. New warp_redux.cuh provides trait-gated, architecture-specific reductions; look_ahead, batched_wspro, and shfl integrate the dispatcher. is_redux_enabled_cuda_operator is removed. Floating-point redux tests are added.

Changes

Warp REDUX dispatch and consumers

| Layer / File(s) | Summary |
|---|---|
| warp_redux.cuh: SM80/SM100af dispatch and implementations (cub/cub/warp/specializations/warp_redux.cuh) | New file defines compile-time support traits (is_warp_redux_op_supported_sm80, is_warp_redux_op_supported_sm100af, is_warp_redux_op_supported), implements warp_redux_sm80 using CUDA __reduce_*_sync intrinsics with input type promotion, implements warp_redux_sm100af using inline PTX float min/max helpers guarded by PTX ISA, and exports warp_redux dispatcher returning optional with NV_IF_TARGET backend selection. |
| look_ahead.cuh: integrate warp_redux_sm80 (cub/cub/detail/warpspeed/look_ahead.cuh) | Adds warp_redux.cuh and is_same.h includes. Updates warpIncrementalLookahead SM80 path to dispatch via is_warp_redux_op_supported_sm80 and call warp_redux_sm80 when supported, falling back to existing warp_reduce_t path for unsupported cases. |
| warp_reduce_batched_wspro.cuh: integrate generic warp_redux (cub/cub/warp/specializations/warp_reduce_batched_wspro.cuh) | Adds warp_redux.cuh include. ReduceRedux changes from void to [[nodiscard]] bool. Reduce dispatch via is_warp_redux_op_supported with early return on ReduceRedux bool result. Replaces reduce_op_sync with warp_redux call; checks optional result before dereferencing and writing intermediate_outputs, returns false on failure or true on success. |
| warp_reduce_shfl.cuh: replace reduce_op_sync with warp_redux (cub/cub/warp/specializations/warp_reduce_shfl.cuh) | Adds warp_redux.cuh include, removes internal reduce_op_sync helper. Updates ALL_LANES_VALID fast-path to dispatch via is_warp_redux_op_supported and call warp_redux, returning optional result when available. |
| thread_operators.cuh: remove is_redux_enabled_cuda_operator (cub/cub/thread/thread_operators.cuh) | Removes is_redux_enabled_cuda_operator trait and TODO comment; functionality replaced by architecture-specific is_warp_redux_op_supported predicates. |
| catch2_test_warp_reduce.cu: floating-point redux tests (cub/test/warp/catch2_test_warp_reduce.cu) | Adds floating_point_redux_type_list containing float and conditionally __half and __nv_bfloat16. New test case WarpReduce::Max/Min floating-point redux types runs over new type list, uses verify_results_exact for float and verify_results for other types. |

Suggested reviewers

  • elstehle
  • shwina
  • andralex
  • miscco
  • NaderAlAwar

@coderabbitai coderabbitai Bot left a comment

Contributor

Actionable comments posted: 6

🧹 Nitpick comments (4)
cub/cub/warp/specializations/warp_redux.cuh (4)

1-2: 📐 Maintainability & Code Quality | 💤 Low value

suggestion: SPDX license identifier should be BSD-3-Clause for full compliance.


65-73: 📐 Maintainability & Code Quality | 💤 Low value

suggestion: Per coding guidelines, unmodified parameters and local variables should be const. mask parameter and value local are not modified.

 template <typename T, typename ReductionOp>
 [[nodiscard]] _CCCL_DEVICE_API _CCCL_FORCEINLINE T
-warp_redux_sm80(const T input, const ::cuda::std::uint32_t mask, ReductionOp)
+warp_redux_sm80(const T input, const ::cuda::std::uint32_t mask, ReductionOp) noexcept
 {
   static_assert(is_warp_redux_op_supported_sm80<ReductionOp, T>, "Reduction operator not supported");
   _CCCL_ASSERT(mask != 0, "Mask must not be 0");
 
-  using promotion_t = ::cuda::std::conditional_t<::cuda::std::is_signed_v<T>, int, unsigned>;
-  const auto value  = static_cast<promotion_t>(input);
+  using promotion_t  = ::cuda::std::conditional_t<::cuda::std::is_signed_v<T>, int, unsigned>;
+  const promotion_t value = static_cast<promotion_t>(input);

Source: Coding guidelines
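
For reference, the signedness-based promotion in the snippet above can be tried on the host with the standard library equivalents. This is a sketch only: `promote_for_redux` is a hypothetical helper, `std::` traits replace `::cuda::std::`, and the integral/size restriction is an assumption added for the example.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

// Signed inputs widen to int, unsigned inputs widen to unsigned.
template <typename T>
using promotion_t = std::conditional_t<std::is_signed_v<T>, int, unsigned>;

// Hypothetical helper: widens a small integer to the 32-bit type that a
// 32-bit reduction primitive would operate on.
template <typename T>
constexpr promotion_t<T> promote_for_redux(const T input)
{
  static_assert(std::is_integral_v<T> && sizeof(T) <= sizeof(int), "only small integral types are handled here");
  const promotion_t<T> value = static_cast<promotion_t<T>>(input);
  return value;
}
```

Keeping the signedness of the source type means negative values survive the widening, and unsigned values are never sign-extended.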


110-121: 🚀 Performance & Scalability | 💤 Low value

suggestion: The macro-generated PTX functions use asm volatile but the reduction is a pure read operation with no memory side effects. Consider asm without volatile to allow compiler optimization, or add a brief comment explaining why volatile is necessary (e.g., if required for synchronization semantics).


131-149: 📐 Maintainability & Code Quality | 💤 Low value

suggestion: Per guidelines, value and result locals should be const where possible. value is not modified after initialization.

 template <typename T, typename ReductionOp>
 [[nodiscard]] _CCCL_DEVICE_API
-_CCCL_FORCEINLINE T warp_redux_sm100af(const T input, const ::cuda::std::uint32_t mask, ReductionOp)
+_CCCL_FORCEINLINE T warp_redux_sm100af(const T input, const ::cuda::std::uint32_t mask, ReductionOp) noexcept
 {
   static_assert(is_warp_redux_op_supported_sm100af<ReductionOp, T>, "Reduction operator not supported");
   _CCCL_ASSERT(mask != 0, "Mask must not be 0");
 
   const float value = ::cuda::std::__fp_cast<float>(input);
-  float result;
+  const float result = [&]() {
   if constexpr (is_cuda_minimum_v<ReductionOp, T>)
   {
-    result = cub::detail::redux_sm100af_min_ptx(value, mask);
+    return ::cub::detail::redux_sm100af_min_ptx(value, mask);
   }
   else
   {
-    result = cub::detail::redux_sm100af_max_ptx(value, mask);
+    return ::cub::detail::redux_sm100af_max_ptx(value, mask);
   }
+  }();
   return ::cuda::std::__fp_cast<T>(result);
 }

Alternatively, keep the current structure but add const where applicable and fully qualify cub::detail:: as ::cub::detail::.

Source: Coding guidelines
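
The immediately-invoked lambda used in the suggestion above is a general C++ idiom for initializing a const local from a branch. A small host-side sketch follows, with hypothetical names (`reduce_kind`, `select_extreme`) and `std::min`/`std::max` standing in for the PTX helpers.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical compile-time selector standing in for the operator trait.
enum class reduce_kind
{
  minimum,
  maximum
};

template <reduce_kind Kind>
float select_extreme(const float a, const float b)
{
  // The lambda is invoked immediately, so `result` can be const even though
  // its value depends on a compile-time branch.
  const float result = [&]() {
    if constexpr (Kind == reduce_kind::minimum)
    {
      return std::min(a, b);
    }
    else
    {
      return std::max(a, b);
    }
  }();
  return result;
}
```

Whether the extra lambda is clearer than a mutable local is a style question; both forms should compile to the same code once inlined.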


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 66046a44-cb32-43c6-8ce3-c6b5f927aa61

📥 Commits

Reviewing files that changed from the base of the PR and between 8c0e6cb and 2261919.

📒 Files selected for processing (6)
  • cub/cub/detail/warpspeed/look_ahead.cuh
  • cub/cub/thread/thread_operators.cuh
  • cub/cub/warp/specializations/warp_reduce_batched_wspro.cuh
  • cub/cub/warp/specializations/warp_reduce_shfl.cuh
  • cub/cub/warp/specializations/warp_redux.cuh
  • cub/test/warp/catch2_test_warp_reduce.cu
💤 Files with no reviewable changes (1)
  • cub/cub/thread/thread_operators.cuh

Comment on lines +89 to 95
+    if constexpr (is_warp_redux_op_supported<ReductionOp, T> && redux_performs_better)
     {
-      NV_IF_TARGET(NV_PROVIDES_SM_80, ({
-        ReduceRedux<ToBlocked>(inputs, outputs, reduction_op);
-        return;
-      }))
+      if (ReduceRedux<ToBlocked>(inputs, outputs, reduction_op))
+      {
+        return;
+      }
     }

Contributor

📐 Maintainability & Code Quality | 🟠 Major | ⚡ Quick win

suggestion: Per coding guidelines, trait should be fully qualified: ::cub::detail::is_warp_redux_op_supported.

-    if constexpr (is_warp_redux_op_supported<ReductionOp, T> && redux_performs_better)
+    if constexpr (::cub::detail::is_warp_redux_op_supported<ReductionOp, T> && redux_performs_better)

Source: Coding guidelines

Comment on lines +127 to 137
const auto result = cub::detail::warp_redux(inputs[i], reduce_mask, reduction_op);
if (!result)
{
return false;
}
const auto out_lane = ToBlocked ? i / max_out_per_thread : i % LogicalWarpThreads;
const auto out_idx = ToBlocked ? i % max_out_per_thread : i / LogicalWarpThreads;
if (logical_lane_id == out_lane)
{
-intermediate_outputs[out_idx] = result;
+intermediate_outputs[out_idx] = *result;
}

Contributor

📐 Maintainability & Code Quality | 🟠 Major | ⚡ Quick win

suggestion: Per coding guidelines, unmodified variables should be const. Also fully qualify the function call.

-        const auto result = cub::detail::warp_redux(inputs[i], reduce_mask, reduction_op);
+        const auto result = ::cub::detail::warp_redux(inputs[i], reduce_mask, reduction_op);
         if (!result)
         {
           return false;
         }
-        const auto out_lane = ToBlocked ? i / max_out_per_thread : i % LogicalWarpThreads;
-        const auto out_idx  = ToBlocked ? i % max_out_per_thread : i / LogicalWarpThreads;
+        const int out_lane = ToBlocked ? i / max_out_per_thread : i % LogicalWarpThreads;
+        const int out_idx  = ToBlocked ? i % max_out_per_thread : i / LogicalWarpThreads;

The out_lane and out_idx are already const auto, which is fine, but explicit int type might be clearer. The main issue is the missing :: prefix on the function call.

Source: Coding guidelines

Comment on lines +464 to 470
+    if constexpr (ALL_LANES_VALID && is_warp_redux_op_supported<ReductionOp, T>)
     {
-      NV_IF_TARGET(NV_PROVIDES_SM_80, (return reduce_op_sync(input, member_mask, reduction_op);))
+      if (const auto output = cub::detail::warp_redux(input, member_mask, reduction_op))
+      {
+        return *output;
+      }
     }

Contributor

📐 Maintainability & Code Quality | 🟠 Major | ⚡ Quick win

suggestion: Per coding guidelines, trait and function call should be fully qualified with :: prefix.

-    if constexpr (ALL_LANES_VALID && is_warp_redux_op_supported<ReductionOp, T>)
+    if constexpr (ALL_LANES_VALID && ::cub::detail::is_warp_redux_op_supported<ReductionOp, T>)
     {
-      if (const auto output = cub::detail::warp_redux(input, member_mask, reduction_op))
+      if (const auto output = ::cub::detail::warp_redux(input, member_mask, reduction_op))
       {
         return *output;
       }
     }

Source: Coding guidelines

Comment on lines +156 to +175
template <typename T, typename ReductionOp>
[[nodiscard]] _CCCL_DEVICE_API _CCCL_FORCEINLINE constexpr ::cuda::std::optional<T>
warp_redux(const T input, const ::cuda::std::uint32_t mask, ReductionOp reduction_op)
{
  static_assert(is_warp_redux_op_supported<ReductionOp, T>, "Reduction operator not supported");
  if constexpr (is_warp_redux_op_supported_sm80<ReductionOp, T>)
  {
    NV_IF_TARGET(NV_PROVIDES_SM_80, (return cub::detail::warp_redux_sm80(input, mask, reduction_op);))
  }
#if __cccl_ptx_isa >= 860
  else if constexpr (is_warp_redux_op_supported_sm100af<ReductionOp, T>)
  {
    NV_IF_TARGET(NV_HAS_FEATURE_SM_100a, (return cub::detail::warp_redux_sm100af(input, mask, reduction_op);))
#  if __cccl_ptx_isa >= 880
    NV_IF_TARGET(NV_HAS_FEATURE_SM_100f, (return cub::detail::warp_redux_sm100af(input, mask, reduction_op);))
#  endif // __cccl_ptx_isa >= 880
  }
#endif // __cccl_ptx_isa >= 860
  return ::cuda::std::nullopt;
}

Contributor

📐 Maintainability & Code Quality | 🟠 Major | ⚡ Quick win

important: Per coding guidelines, all calls to free functions must be fully qualified starting from the global namespace. The calls to warp_redux_sm80 and warp_redux_sm100af inside warp_redux should use ::cub::detail:: prefix.

   if constexpr (is_warp_redux_op_supported_sm80<ReductionOp, T>)
   {
-    NV_IF_TARGET(NV_PROVIDES_SM_80, (return cub::detail::warp_redux_sm80(input, mask, reduction_op);))
+    NV_IF_TARGET(NV_PROVIDES_SM_80, (return ::cub::detail::warp_redux_sm80(input, mask, reduction_op);))
   }
 #if __cccl_ptx_isa >= 860
   else if constexpr (is_warp_redux_op_supported_sm100af<ReductionOp, T>)
   {
-    NV_IF_TARGET(NV_HAS_FEATURE_SM_100a, (return cub::detail::warp_redux_sm100af(input, mask, reduction_op);))
+    NV_IF_TARGET(NV_HAS_FEATURE_SM_100a, (return ::cub::detail::warp_redux_sm100af(input, mask, reduction_op);))
 #  if __cccl_ptx_isa >= 880
-    NV_IF_TARGET(NV_HAS_FEATURE_SM_100f, (return cub::detail::warp_redux_sm100af(input, mask, reduction_op);))
+    NV_IF_TARGET(NV_HAS_FEATURE_SM_100f, (return ::cub::detail::warp_redux_sm100af(input, mask, reduction_op);))
 #  endif // __cccl_ptx_isa >= 880
   }
 #endif // __cccl_ptx_isa >= 860

Source: Coding guidelines

Comment on lines +163 to +170
NV_IF_TARGET(NV_PROVIDES_SM_80, (return cub::detail::warp_redux_sm80(input, mask, reduction_op);))
}
#if __cccl_ptx_isa >= 860
else if constexpr (is_warp_redux_op_supported_sm100af<ReductionOp, T>)
{
NV_IF_TARGET(NV_HAS_FEATURE_SM_100a, (return cub::detail::warp_redux_sm100af(input, mask, reduction_op);))
# if __cccl_ptx_isa >= 880
NV_IF_TARGET(NV_HAS_FEATURE_SM_100f, (return cub::detail::warp_redux_sm100af(input, mask, reduction_op);))

Contributor

📐 Maintainability & Code Quality | 🟠 Major | ⚡ Quick win

suggestion: Multiple call sites across warp_redux.cuh, look_ahead.cuh, warp_reduce_batched_wspro.cuh, and warp_reduce_shfl.cuh reference cub::detail:: functions/traits without the leading :: global namespace qualifier. Per coding guidelines, all free function calls and trait references must be fully qualified starting from global namespace (e.g., ::cub::detail::warp_redux(...)). This ensures ADL doesn't cause unexpected behavior and maintains consistency across the codebase.

Source: Coding guidelines
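
To make the ADL concern concrete, here is a small host-side sketch of how an unqualified call inside a library template can pick up an overload from the argument's namespace. The namespaces and functions (`user::combine`, `lib::detail::combine`, `call_unqualified`, `call_qualified`) are hypothetical and unrelated to the actual CUB code.

```cpp
#include <cassert>

namespace user
{
struct value
{
  int payload;
};

// Found through ADL whenever an unqualified call passes a user::value.
inline int combine(value v, int)
{
  return v.payload + 1000;
}
} // namespace user

namespace lib::detail
{
// Generic helper the library intends to call.
template <typename T>
int combine(T, long)
{
  return 1;
}

template <typename T>
int call_unqualified(T v)
{
  // Overload resolution also considers user::combine here, and it is the
  // better match for the int argument.
  return combine(v, 0);
}

template <typename T>
int call_qualified(T v)
{
  // Qualification disables ADL, so only lib::detail::combine is considered.
  return ::lib::detail::combine(v, 0);
}
} // namespace lib::detail
```

For built-in argument types such as `int` there is no associated namespace, so both forms resolve to the library helper; the difference only shows up with user-defined types.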

Comment on lines +297 to +303
if constexpr (cuda::std::is_same_v<T, float>)
{
verify_results_exact(h_out, d_out);
}
else
{
verify_results(h_out, d_out);

Contributor

🎯 Functional Correctness | 🟠 Major | ⚡ Quick win

important: Line 297 limits exact validation to float, and Lines 301-303 use tolerance-based verify_results for __half/__nv_bfloat16. For min/max redux, this can miss wrong-extrema regressions because approximate checks may still pass. Use verify_results_exact(h_out, d_out) for all types in this redux-focused test so the selected extrema are validated strictly.
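
The concern can be sketched on the host as follows. This assumes the non-exact check uses a relative tolerance, which may not match the real `verify_results`; `approx_equal` and `exact_equal` are hypothetical helpers written for the example.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Hypothetical tolerance-based comparison using a relative tolerance.
inline bool approx_equal(const float expected, const float actual, const float rel_tol)
{
  return std::fabs(expected - actual) <= rel_tol * std::fabs(expected);
}

// Bitwise comparison: passes only if both values have the same encoding.
inline bool exact_equal(const float expected, const float actual)
{
  std::uint32_t e = 0;
  std::uint32_t a = 0;
  std::memcpy(&e, &expected, sizeof(e));
  std::memcpy(&a, &actual, sizeof(a));
  return e == a;
}
```

If the true maximum is 1.0 and a buggy reduction returns the neighboring input 0.9990234375, a 1% relative tolerance accepts the result while the bitwise check rejects it. Note that a bitwise check also distinguishes +0 from -0 and different NaN encodings, which may or may not be desired in a test.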

else
{
_CCCL_UNREACHABLE();
return T{};

Contributor

Should this be an optional as well? It's packed into one after the return anyway.

@github-actions

Contributor

😬 CI Workflow Results

🟥 Finished in 2h 32m: Pass: 85%/287 | Total: 10d 06h | Max: 2h 31m | Hits: 13%/1114786

See results here.

Labels

cub For all items related to CUB

Projects

Status: In Review

2 participants