From 38658c27b33251a6b47c0b0b7aef315332648500 Mon Sep 17 00:00:00 2001
From: Jiang Shao <91270701+StudyingShao@users.noreply.github.com>
Date: Mon, 27 Oct 2025 05:59:14 +0000
Subject: [PATCH 1/3] Remove pingpong scheduler from Hopper mixed dtype grouped gemm tuner

Signed-off-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com>
---
 .../cutlass_kernels/cutlass_heuristic.cpp | 43 ++++++++++---------
 1 file changed, 23 insertions(+), 20 deletions(-)

diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp
index 7c81a5d7b56..d3401aa4518 100644
--- a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp
+++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp
@@ -304,27 +304,30 @@ std::vector<CutlassGemmConfig> get_candidate_configs_sm90(CutlassGemmConfig::Can
         if (has_w4afp8)
         {
             bool const has_coop_supported = sm90_supports_coop(tile_config);
-            std::set<MainloopScheduleType> mainloop_schedules{MainloopScheduleType::PINGPONG};
-            if (has_coop_supported)
-            {
-                mainloop_schedules.insert(MainloopScheduleType::COOPERATIVE);
-            }
+
+            // The ping-pong scheduler never appears to be selected in practice.
+            // To shorten tactic profiling time, drop every candidate that uses the ping-pong scheduler.
+            if (!has_coop_supported)
+                continue;
+            // Due to the limited number of registers per SM,
+            // the cooperative scheduler does not support CtaShape128x128x128B.
+            if (tile_config == CutlassTileConfigSM90::CtaShape128x128x128B)
+                continue;
+            MainloopScheduleType mainloop_schedule = MainloopScheduleType::COOPERATIVE;
             auto const epilogue_schedule = EpilogueScheduleType::AUTO;
-            for (auto const& mainloop_schedule : mainloop_schedules)
-            {
-                CutlassGemmConfig candidate(
-                    tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_1x1x1);
-                candidate_configs.push_back(candidate);
-                candidate = CutlassGemmConfig(
-                    tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_2x1x1);
-                candidate_configs.push_back(candidate);
-                candidate = CutlassGemmConfig(
-                    tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_1x2x1);
-                candidate_configs.push_back(candidate);
-                candidate = CutlassGemmConfig(
-                    tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_2x2x1);
-                candidate_configs.push_back(candidate);
-            }
+
+            CutlassGemmConfig candidate(
+                tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_1x1x1);
+            candidate_configs.push_back(candidate);
+            candidate = CutlassGemmConfig(
+                tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_2x1x1);
+            candidate_configs.push_back(candidate);
+            candidate = CutlassGemmConfig(
+                tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_1x2x1);
+            candidate_configs.push_back(candidate);
+            candidate = CutlassGemmConfig(
+                tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_2x2x1);
+            candidate_configs.push_back(candidate);
         }
         else
         {

From 896424f463cc5b8b8f396182a6275a1eb790ac68 Mon Sep 17 00:00:00 2001
From: Jiang Shao <91270701+StudyingShao@users.noreply.github.com>
Date: Mon, 27 Oct 2025 06:00:50 +0000
Subject: [PATCH 2/3] Fix a perf regression caused by fp8->bf16 scale factor conversion for Hopper MXFP4 x BF16 Grouped GEMM.
Signed-off-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com>
---
 ...a_gmma_rs_warpspecialized_mixed_input_.hpp | 30 ++++++++++++++++---
 1 file changed, 26 insertions(+), 4 deletions(-)

diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp
index 0ce601d5b08..8e9b1faae90 100644
--- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp
+++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp
@@ -899,6 +899,28 @@ struct CollectiveMmaArrayMixedInput<
     }

     /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+    // Override the FP8 scale conversion in CUTLASS to enforce the intended compiler behavior.
+    template <class T>
+    CUTLASS_DEVICE float scale_convertor(T scale)
+    {
+        if constexpr (cute::is_same_v<T, cutlass::float_ue8m0_t>)
+        {
+
+            cutlass::float_ue8m0_t scale_ue8m0 = scale;
+
+            uint32_t temp = 0;
+            temp = (temp | *reinterpret_cast<uint8_t*>(&scale_ue8m0)) << 23;
+            return *reinterpret_cast<float*>(&temp);
+        }
+        else
+        {
+            return static_cast<float>(scale);
+        }
+    }
+
+    /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
     /// Perform a collective-scoped matrix multiply-accumulate
     /// Consumer Perspective
     template
@@ -1084,12 +1106,12 @@ struct CollectiveMmaArrayMixedInput<
                    if (chunk_id_ == 0)
                    {
                        accum(accum_coord) = intermediate_array[chunk_id_](accum_coord)
-                            * static_cast<float>(tCrS(scale_coord)[0]);
+                            * scale_convertor(tCrS(scale_coord)[0]);
                    }
                    else
                    {
                        accum(accum_coord) = fma(intermediate_array[chunk_id_](accum_coord),
-                            static_cast<float>(tCrS(scale_coord)[chunk_id_]), accum(accum_coord));
+                            scale_convertor(tCrS(scale_coord)[chunk_id_]), accum(accum_coord));
                    }
                }
            }
@@ -1186,7 +1208,7 @@ struct CollectiveMmaArrayMixedInput<
                        auto scale_coord = make_coord(make_tuple(0, m, 0), mma_m, 0);

                        accum(accum_coord) = fma(intermediate_array[chunk_id_](accum_coord),
-                            static_cast<float>(tCrS(scale_coord)[chunk_id_]), accum(accum_coord));
+                            scale_convertor(tCrS(scale_coord)[chunk_id_]), accum(accum_coord));
                    }
                }
            }
@@ -1275,7 +1297,7 @@ struct CollectiveMmaArrayMixedInput<
                    int scale_idx = k_block / NumMMAsPerChunk;

                    accum(accum_coord) = fma(intermediate(accum_coord),
-                        static_cast<float>(tCrS(scale_coord)[scale_idx]), accum(accum_coord));
+                        scale_convertor(tCrS(scale_coord)[scale_idx]), accum(accum_coord));
                }
            }
        }

From 5f1461919227a1979fa32af7fe546d9fd917a4be Mon Sep 17 00:00:00 2001
From: Jiang Shao <91270701+StudyingShao@users.noreply.github.com>
Date: Tue, 28 Oct 2025 08:59:10 +0000
Subject: [PATCH 3/3] Add weight interleaving to boost the conversion process

Signed-off-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com>
---
 .../detail/collective/mixed_input_utils.hpp  | 30 +++++++++-
 .../moe_gemm/moe_gemm_mixed_utils.cu         | 55 +++++++++++++++++++
 .../moe_gemm/moe_gemm_mixed_utils.h          | 26 +++++++++
 cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp  | 29 ++++++++++
 .../_torch/modules/fused_moe/quantization.py | 11 ++++
 5 files changed, 150 insertions(+), 1 deletion(-)
 create mode 100644 cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu
 create mode 100644 cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.h
diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp
index 53dc9e053ad..0626c226b8e 100644
--- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp
+++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp
@@ -99,6 +99,34 @@ __device__ __inline__ __nv_bf16x8_storage_t psx_cvt_lut_prmt_fp4x8_to_bf16x8(con
     return bf16x8_raw;
 }

+__device__ __inline__ __nv_bf16x8_storage_t psx_cvt_lut_prmt_fp4x8_to_bf16x8_interleaved(
+    const __nv_fp4x8_storage_t fp4x8)
+{
+    // Interleaved version:
+    // input fp4x8 nibble order:    7564 3120
+    // output bf16x8 element order: 7654 3210
+
+    __nv_bf16x8_storage_t bf16x8_raw;
+    __nv_bf16x2_storage_t* bf16x2_raw = reinterpret_cast<__nv_bf16x2_storage_t*>(&bf16x8_raw);
+
+    __nv_fp8x4_storage_t h_fp8x4_0to1_bits = (fp4x8 & 0xC0C0C0C0U) >> 6; // 7632
+    __nv_fp8x4_storage_t l_fp8x4_0to1_bits = (fp4x8 & 0x0C0C0C0CU) >> 2; // 5410
+
+    unsigned h4b_em_fp4x4 = (fp4x8 & 0x77770000U) >> 16U;
+    unsigned l4b_em_fp4x4 = (fp4x8 & 0x00007777U);
+
+    __nv_fp8x4_storage_t h4b_2to9_bits = cvt_lut_bf16(h4b_em_fp4x4); // 7564
+    __nv_fp8x4_storage_t l4b_2to9_bits = cvt_lut_bf16(l4b_em_fp4x4); // 3120
+
+    bf16x2_raw[0] = prmt(l_fp8x4_0to1_bits, l4b_2to9_bits, 0x5240U) << 6U; // 1 0
+    bf16x2_raw[1] = prmt(h_fp8x4_0to1_bits, l4b_2to9_bits, 0x5341U) << 6U; // 3 2
+
+    bf16x2_raw[2] = prmt(l_fp8x4_0to1_bits, h4b_2to9_bits, 0x7260U) << 6U; // 5 4
+    bf16x2_raw[3] = prmt(h_fp8x4_0to1_bits, h4b_2to9_bits, 0x7361U) << 6U; // 7 6
+
+    return bf16x8_raw;
+}
+
 template
 struct MixedGroupedGemmInputUtils
 {
@@ -330,7 +358,7 @@ struct MixedGroupedGemmInputUtils
        auto&& src_ = cute::recast<__nv_fp4x8_storage_t>(src)(0);
        auto&& dst_ = cute::recast<__nv_bf16x8_storage_t>(dst)(0);

-        dst_ = psx_cvt_lut_prmt_fp4x8_to_bf16x8(src_);
+        dst_ = psx_cvt_lut_prmt_fp4x8_to_bf16x8_interleaved(src_);
     }

     /// Utilities to dequantize A.
diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu
new file mode 100644
index 00000000000..7486f1523eb
--- /dev/null
+++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu
@@ -0,0 +1,55 @@
+/*
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "moe_gemm_mixed_utils.h"
+
+namespace tensorrt_llm::kernels::cutlass_kernels
+{
+
+__global__ void interleave_fp4_for_Hopper_mixed_gemm_kernel(
+    uint8_t* weight, uint8_t* weight_interleaved, int const rows, int const cols)
+{
+    for (int block_id = blockIdx.x; block_id < rows / 2; block_id += gridDim.x)
+    {
+        for (int col_id = threadIdx.x; col_id < cols / 2; col_id += blockDim.x)
+        {
+            int row_id = block_id / 8 * 16 + block_id % 8;
+
+            int index_a = row_id * cols / 2 + col_id;
+            int index_b = (row_id + 8) * cols / 2 + col_id;
+
+            uint8_t fp4x2_a = weight[index_a];
+            uint8_t fp4x2_b = weight[index_b];
+
+            uint8_t fp4_temp_a = (fp4x2_a & 0xF0U) >> 4;
+            uint8_t fp4_temp_b = (fp4x2_b & 0x0FU) << 4;
+
+            fp4x2_a = (fp4x2_a & 0x0FU) | fp4_temp_b;
+            fp4x2_b = (fp4x2_b & 0xF0U) | fp4_temp_a;
+
+            weight_interleaved[index_a] = fp4x2_a;
+            weight_interleaved[index_b] = fp4x2_b;
+        }
+    }
+}
+
+void interleave_fp4_for_Hopper_mixed_gemm(uint8_t* weight, uint8_t* weight_interleaved, int const rows, int const cols)
+{
+    // column-major input
+    interleave_fp4_for_Hopper_mixed_gemm_kernel<<<1024, 1024>>>(weight, weight_interleaved, rows, cols);
+}
+
+} // namespace tensorrt_llm::kernels::cutlass_kernels
diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.h
new file mode 100644
index 00000000000..0898dc75cd2
--- /dev/null
+++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.h
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <cstdint>
+
+namespace tensorrt_llm::kernels::cutlass_kernels
+{
+
+void interleave_fp4_for_Hopper_mixed_gemm(uint8_t* weight, uint8_t* weight_interleaved, int const rows, int const cols);
+
+}
diff --git a/cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp b/cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp
index b6feba15e6b..d1f6f61d033 100644
--- a/cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp
+++ b/cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp
@@ -16,6 +16,7 @@

 #include "tensorrt_llm/common/cudaBf16Wrapper.h"
 #include "tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h"
+#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.h"
 #include "tensorrt_llm/thop/thUtils.h"

 #if defined(TORCH_VERSION_MAJOR) \
@@ -398,6 +399,31 @@ Tensor mxfp4_dequantize_unswizzled(Tensor weight, Tensor scale, int64_t group_si
     return dequant_weight;
 }

+Tensor fp4_interleave_for_Hopper_mixed_gemm(Tensor weight)
+{
+    // weight (n, k / 2)
+    int const n = weight.size(0);
+    int const k = weight.size(1) * 2;
+
+    CHECK_TH_CUDA(weight);
+    CHECK_CONTIGUOUS(weight);
+
+    TORCH_CHECK(weight.numel() != 0, "weight should not be an empty tensor");
+    TORCH_CHECK(weight.dtype() == torch::kUInt8, "Weight must be a packed uint8 tensor");
+    TORCH_CHECK(n % 16 == 0);
+    TORCH_CHECK(k % 16 == 0);
+
+    Tensor weight_interleaved
+        = torch::empty({n, k / 2}, torch::dtype(torch::kUInt8).device(torch::kCUDA).requires_grad(false));
+
+    uint8_t* weight_ptr = get_ptr<uint8_t>(weight);
+    uint8_t* weight_interleaved_ptr = get_ptr<uint8_t>(weight_interleaved);
+
+    interleave_fp4_for_Hopper_mixed_gemm(weight_ptr, weight_interleaved_ptr, n, k);
+
+    return weight_interleaved;
+}
+
 } // namespace torch_ext

 // Utility methods that may be useful for preprocessing weights in torch.
@@ -432,3 +458,6 @@ static auto subbyte_transpose = torch::RegisterOperators("trtllm::_subbyte_trans

 static auto mxfp4_dequantize_unswizzled = torch::RegisterOperators(
     "trtllm::mxfp4_dequantize_unswizzled", &torch_ext::mxfp4_dequantize_unswizzled);
+
+static auto fp4_interleave_for_Hopper_mixed_gemm = torch::RegisterOperators(
+    "trtllm::fp4_interleave_for_Hopper_mixed_gemm", &torch_ext::fp4_interleave_for_Hopper_mixed_gemm);
diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py
index 1e56f90d5e9..158c80c1753 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py
@@ -1390,9 +1390,11 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module,
         pad_size_inter = module.intermediate_size_per_partition - w3_weight_shard.shape[
             0]
         if w3_weight_shard.ndim == 2:
+            # [intermediate_size, hidden_size]
             pad_size_hidden = module.hidden_size // 2 - w3_weight_shard.shape[1]
             pad_shape = (0, pad_size_hidden, 0, pad_size_inter)
         elif w3_weight_shard.ndim == 1:
+            # [intermediate_size]
             pad_shape = (0, pad_size_inter)
         else:
             raise NotImplementedError(
@@ -1404,6 +1406,10 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module,

         w31_weight_shard = torch.cat([w3_weight_shard, w1_weight_shard],
                                      dim=0)
+        if w3_weight_shard.ndim == 2:
+            w31_weight_shard = torch.ops.trtllm.fp4_interleave_for_Hopper_mixed_gemm(
+                w31_weight_shard)
+
         dst_w3_w1_weight.copy_(w31_weight_shard.view(dst_w3_w1_weight.dtype),
                                non_blocking=True)

@@ -1433,6 +1439,11 @@ def load_expert_w2_weight(self, module: torch.nn.Module,
                 f"Invalid shape of w2_weight_shard {w2_weight_shard.shape}")

         w2_weight_shard = torch.nn.functional.pad(w2_weight_shard, pad_shape)
+
+        if w2_weight_shard.ndim == 2:
+            w2_weight_shard = torch.ops.trtllm.fp4_interleave_for_Hopper_mixed_gemm(
+                w2_weight_shard)
+
         dst_w2_weight.copy_(w2_weight_shard.view(dst_w2_weight.dtype),
                             non_blocking=True)

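Note (editor, not part of the patches): patch 3's interleave_fp4_for_Hopper_mixed_gemm_kernel pairs row r with row r + 8 inside every 16-row block of the packed FP4 weight and swaps the high nibble of the former with the low nibble of the latter, which appears to set up the nibble order consumed by psx_cvt_lut_prmt_fp4x8_to_bf16x8_interleaved. Below is a minimal CPU-side sketch of the same transform, useful for checking the kernel's output; the helper name, shapes, and test values are hypothetical and not part of the series.

import torch

def interleave_fp4_reference(w: torch.Tensor) -> torch.Tensor:
    # w: packed FP4 weight of shape (n, k // 2), dtype uint8, with n % 16 == 0.
    out = w.clone()
    for base in range(0, w.shape[0], 16):
        a = w[base:base + 8]        # rows r .. r+7 of the 16-row block
        b = w[base + 8:base + 16]   # rows r+8 .. r+15
        # Row r keeps its low nibbles and receives row r+8's low nibbles in its high half.
        out[base:base + 8] = (a & 0x0F) | ((b & 0x0F) << 4)
        # Row r+8 keeps its high nibbles and receives row r's high nibbles in its low half.
        out[base + 8:base + 16] = (b & 0xF0) | ((a & 0xF0) >> 4)
    return out

# Hypothetical check against the CUDA path registered by patch 3:
# w = torch.randint(0, 256, (128, 64), dtype=torch.uint8, device="cuda")
# ref = interleave_fp4_reference(w.cpu())
# assert torch.equal(torch.ops.trtllm.fp4_interleave_for_Hopper_mixed_gemm(w).cpu(), ref)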