@@ -99,6 +99,34 @@ __device__ __inline__ __nv_bf16x8_storage_t psx_cvt_lut_prmt_fp4x8_to_bf16x8(con
return bf16x8_raw;
}

__device__ __inline__ __nv_bf16x8_storage_t psx_cvt_lut_prmt_fp4x8_to_bf16x8_interleaved(
const __nv_fp4x8_storage_t fp4x8)
{
// interleaved version
// input fp4x8: 7564 3120
// output bf16x8: 7654 3210

__nv_bf16x8_storage_t bf16x8_raw;
__nv_bf16x2_storage_t* bf16x2_raw = reinterpret_cast<__nv_bf16x2_storage_t*>(&bf16x8_raw);

__nv_fp8x4_storage_t h_fp8x4_0to1_bits = (fp4x8 & 0xC0C0C0C0U) >> 6; // 7632
__nv_fp8x4_storage_t l_fp8x4_0to1_bits = (fp4x8 & 0x0C0C0C0CU) >> 2; // 5410

unsigned h4b_em_fp4x4 = (fp4x8 & 0x77770000U) >> 16U;
unsigned l4b_em_fp4x4 = (fp4x8 & 0x00007777U);

__nv_fp8x4_storage_t h4b_2to9_bits = cvt_lut_bf16(h4b_em_fp4x4); // 7564
__nv_fp8x4_storage_t l4b_2to9_bits = cvt_lut_bf16(l4b_em_fp4x4); // 3120

bf16x2_raw[0] = prmt(l_fp8x4_0to1_bits, l4b_2to9_bits, 0x5240U) << 6U; // 1 0
bf16x2_raw[1] = prmt(h_fp8x4_0to1_bits, l4b_2to9_bits, 0x5341U) << 6U; // 3 2

bf16x2_raw[2] = prmt(l_fp8x4_0to1_bits, h4b_2to9_bits, 0x7260U) << 6U; // 5 4
bf16x2_raw[3] = prmt(h_fp8x4_0to1_bits, h4b_2to9_bits, 0x7361U) << 6U; // 7 6

return bf16x8_raw;
}
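
For readers following the nibble bookkeeping in the comments above, the sketch below is a plain scalar reference (ours, not part of the PR) that decodes the same interleaved fp4x8 word. It assumes FP4 is E2M1 and that the nibbles, from least to most significant, hold logical elements 0, 2, 1, 3, 4, 6, 5, 7 (i.e. the "7564 3120" input order):

```python
# Illustrative only: scalar reference for the interleaved fp4x8 -> 8 floats decode.
# Assumes E2M1 FP4; helper names are ours, not from the PR.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def decode_e2m1(nibble: int) -> float:
    sign = -1.0 if nibble & 0x8 else 1.0
    return sign * E2M1_MAGNITUDES[nibble & 0x7]

def decode_fp4x8_interleaved(word: int) -> list[float]:
    nibble_to_logical = [0, 2, 1, 3, 4, 6, 5, 7]   # input order 7564 3120 (MSB..LSB)
    out = [0.0] * 8
    for k in range(8):
        out[nibble_to_logical[k]] = decode_e2m1((word >> (4 * k)) & 0xF)
    return out                                      # logical order 0..7 ("7654 3210")
```

The LUT/PRMT path above produces the same logical ordering in bf16x8 form; the reference is only meant to pin down the index mapping.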

template <class Collective>
struct MixedGroupedGemmInputUtils
{
@@ -330,7 +358,7 @@ struct MixedGroupedGemmInputUtils
auto&& src_ = cute::recast<__nv_fp4x8_storage_t>(src)(0);
auto&& dst_ = cute::recast<__nv_bf16x8_storage_t>(dst)(0);

dst_ = psx_cvt_lut_prmt_fp4x8_to_bf16x8(src_);
dst_ = psx_cvt_lut_prmt_fp4x8_to_bf16x8_interleaved(src_);
}

/// Utilities to dequantize A.
@@ -899,6 +899,28 @@ struct CollectiveMmaArrayMixedInput<
}

/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

// Override the FP8 conversion in CUTLASS to enforce the intended compiler behavior.
template <class T>
CUTLASS_DEVICE float scale_convertor(T scale)
{
if constexpr (cute::is_same_v<ElementA, cutlass::float_e2m1_t>)
{

cutlass::float_ue8m0_t scale_ue8m0 = scale;

uint32_t temp = 0;
temp = (temp | *reinterpret_cast<uint8_t*>(&scale_ue8m0)) << 23;
return *reinterpret_cast<float*>(&temp);
}
else
{
return static_cast<float>(scale);
}
}

/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
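
As a side note on the bit trick in scale_convertor: a ue8m0 scale is just a biased float32 exponent byte, so shifting it into bits 30..23 yields 2^(e - 127). A minimal Python check (ours, illustrative only):

```python
import struct

def ue8m0_to_float(e: int) -> float:
    """Place the 8-bit biased exponent into the float32 exponent field: 2**(e - 127)."""
    bits = (e & 0xFF) << 23
    return struct.unpack("<f", struct.pack("<I", bits))[0]

assert ue8m0_to_float(127) == 1.0   # 2**0
assert ue8m0_to_float(130) == 8.0   # 2**3
assert ue8m0_to_float(0) == 0.0     # e == 0 maps to +0.0 rather than 2**-127
```

Note that e == 0 decodes to +0.0 rather than 2^-127, matching the behavior of the reinterpret-based C++ code above.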

/// Perform a collective-scoped matrix multiply-accumulate
/// Consumer Perspective
template <class FrgTensorC>
@@ -1084,12 +1106,12 @@ struct CollectiveMmaArrayMixedInput<
if (chunk_id_ == 0)
{
accum(accum_coord) = intermediate_array[chunk_id_](accum_coord)
* static_cast<float>(tCrS(scale_coord)[0]);
* scale_convertor(tCrS(scale_coord)[0]);
}
else
{
accum(accum_coord) = fma(intermediate_array[chunk_id_](accum_coord),
static_cast<float>(tCrS(scale_coord)[chunk_id_]), accum(accum_coord));
scale_convertor(tCrS(scale_coord)[chunk_id_]), accum(accum_coord));
}
}
}
@@ -1186,7 +1208,7 @@ struct CollectiveMmaArrayMixedInput<
auto scale_coord = make_coord(make_tuple(0, m, 0), mma_m, 0);

accum(accum_coord) = fma(intermediate_array[chunk_id_](accum_coord),
static_cast<float>(tCrS(scale_coord)[chunk_id_]), accum(accum_coord));
scale_convertor(tCrS(scale_coord)[chunk_id_]), accum(accum_coord));
}
}
}
@@ -1275,7 +1297,7 @@ struct CollectiveMmaArrayMixedInput<
int scale_idx = k_block / NumMMAsPerChunk;

accum(accum_coord) = fma(intermediate(accum_coord),
static_cast<float>(tCrS(scale_coord)[scale_idx]), accum(accum_coord));
scale_convertor(tCrS(scale_coord)[scale_idx]), accum(accum_coord));
}
}
}
43 changes: 23 additions & 20 deletions cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp
@@ -304,27 +304,30 @@ std::vector<CutlassGemmConfig> get_candidate_configs_sm90(CutlassGemmConfig::Can
if (has_w4afp8)
{
bool const has_coop_supported = sm90_supports_coop(tile_config);
std::set<MainloopScheduleType> mainloop_schedules{MainloopScheduleType::PINGPONG};
if (has_coop_supported)
{
mainloop_schedules.insert(MainloopScheduleType::COOPERATIVE);
}

// It seems that ping-pong scheduler will never be selected.
// To shorten the tactic time, remove all alternative options involving ping-pong scheduler.
Collaborator comment:

What models have you tested this with? I am hesitant to remove this without a comprehensive sweep of multiple model architectures like Mixtral, DeepSeek, Llama4 and GPT-OSS. It's hard to say what the next DeepSeek moment will look like.
I also don't think tactic selection time is actually a significant concern. There are lots of tactics, sure, but weight loading usually takes just as long. Maybe we should add a fast profile mode that users can opt into.

if (!has_coop_supported)
continue;
// Due to the limitation on the number of registers on SM,
// cooperative scheduler does not support CtaShape128x128x128B.
if (tile_config == CutlassTileConfigSM90::CtaShape128x128x128B)
continue;
Comment on lines +310 to +315 (Contributor):

🛠️ Refactor suggestion | 🟠 Major

Add braces around single-statement if bodies.

The coding guidelines require that if statements always be followed by brace-delimited statements. Both continue statements lack the required braces.

As per coding guidelines.

Apply this diff to add braces:

-            if (!has_coop_supported)
-                continue;
+            if (!has_coop_supported)
+            {
+                continue;
+            }
             // Due to the limitation on the number of registers on SM,
             // cooperative scheduler does not support CtaShape128x128x128B.
-            if (tile_config == CutlassTileConfigSM90::CtaShape128x128x128B)
-                continue;
+            if (tile_config == CutlassTileConfigSM90::CtaShape128x128x128B)
+            {
+                continue;
+            }

Collaborator comment:

How much performance are we leaving on the table here? Is there a way to reduce the number of stages or otherwise relieve register pressure?

MainloopScheduleType mainloop_schedule = MainloopScheduleType::COOPERATIVE;
auto const epilogue_schedule = EpilogueScheduleType::AUTO;
for (auto const& mainloop_schedule : mainloop_schedules)
{
CutlassGemmConfig candidate(
tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_1x1x1);
candidate_configs.push_back(candidate);
candidate = CutlassGemmConfig(
tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_2x1x1);
candidate_configs.push_back(candidate);
candidate = CutlassGemmConfig(
tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_1x2x1);
candidate_configs.push_back(candidate);
candidate = CutlassGemmConfig(
tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_2x2x1);
candidate_configs.push_back(candidate);
}

CutlassGemmConfig candidate(
tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_1x1x1);
candidate_configs.push_back(candidate);
candidate = CutlassGemmConfig(
tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_2x1x1);
candidate_configs.push_back(candidate);
candidate = CutlassGemmConfig(
tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_1x2x1);
candidate_configs.push_back(candidate);
candidate = CutlassGemmConfig(
tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_2x2x1);
candidate_configs.push_back(candidate);
}
else
{
cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.cu
@@ -0,0 +1,55 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
Comment on lines +1 to +15 (Contributor):

⚠️ Potential issue | 🟠 Major

Update copyright year to include 2025.

Header stops at 2023; repository guidelines require current year on source files. As per coding guidelines.

- * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
+ * Copyright (c) 2020-2025, NVIDIA CORPORATION.  All rights reserved.


#include "moe_gemm_mixed_utils.h"

namespace tensorrt_llm::kernels::cutlass_kernels
{

__global__ void interleave_fp4_for_Hopper_mixed_gemm_kernel(
uint8_t* weight, uint8_t* weight_interleaved, int const rows, int const cols)
{
for (int block_id = blockIdx.x; block_id < rows / 2; block_id += gridDim.x)
{
for (int col_id = threadIdx.x; col_id < cols / 2; col_id += blockDim.x)
{
int row_id = block_id / 8 * 16 + block_id % 8;

int index_a = row_id * cols / 2 + col_id;
int index_b = (row_id + 8) * cols / 2 + col_id;

Comment on lines +25 to +33 (Contributor):

⚠️ Potential issue | 🟠 Major

Enforce required shape preconditions and avoid repeated divisions.

Kernel assumes rows % 16 == 0 and cols % 2 == 0. Without checks, odd sizes can index OOB. Also factor rowPairs/colBytes once. As per coding guidelines.

 {
-    for (int block_id = blockIdx.x; block_id < rows / 2; block_id += gridDim.x)
+    // Preconditions: rows are multiples of 16; cols are even (#fp4 is even -> bytes).
+    if ((rows & 0xF) != 0 || (cols & 0x1) != 0)
+    {
+        return;
+    }
+
+    int const rowPairs = rows / 2;
+    int const colBytes = cols / 2;
+
+    for (int block_id = blockIdx.x; block_id < rowPairs; block_id += gridDim.x)
     {
-        for (int col_id = threadIdx.x; col_id < cols / 2; col_id += blockDim.x)
+        for (int col_id = threadIdx.x; col_id < colBytes; col_id += blockDim.x)
         {
-            int row_id = block_id / 8 * 16 + block_id % 8;
+            int const row_id = (block_id / 8) * 16 + (block_id % 8);

uint8_t fp4x2_a = weight[index_a];
uint8_t fp4x2_b = weight[index_b];

uint8_t fp4_temp_a = (fp4x2_a & 0xF0U) >> 4;
uint8_t fp4_temp_b = (fp4x2_b & 0x0FU) << 4;

fp4x2_a = (fp4x2_a & 0x0FU) | fp4_temp_b;
fp4x2_b = (fp4x2_b & 0xF0U) | fp4_temp_a;

weight_interleaved[index_a] = fp4x2_a;
weight_interleaved[index_b] = fp4x2_b;
}
}
}

void interleave_fp4_for_Hopper_mixed_gemm(uint8_t* weight, uint8_t* weight_interleaved, int const rows, int const cols)
{
// column-major input
interleave_fp4_for_Hopper_mixed_gemm_kernel<<<1024, 1024>>>(weight, weight_interleaved, rows, cols);
}

} // namespace tensorrt_llm::kernels::cutlass_kernels
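
For reference, a host-side NumPy sketch of the same row-pair nibble swap (ours, not part of the PR; it assumes the packed FP4 weight is a contiguous (rows, cols/2) uint8 view with rows divisible by 16 and an even number of FP4 columns):

```python
import numpy as np

def interleave_fp4_reference(weight: np.ndarray) -> np.ndarray:
    # weight: uint8 array of shape (rows, cols // 2), two FP4 values per byte
    rows = weight.shape[0]
    assert rows % 16 == 0
    out = weight.copy()
    for block_id in range(rows // 2):
        r = block_id // 8 * 16 + block_id % 8
        a, b = weight[r], weight[r + 8]
        out[r] = (a & 0x0F) | ((b & 0x0F) << 4)      # row r keeps its low nibbles, gains row r+8's low nibbles
        out[r + 8] = (b & 0xF0) | ((a & 0xF0) >> 4)  # row r+8 keeps its high nibbles, gains row r's high nibbles
    return out
```

Each 16-row group is processed as eight (r, r+8) pairs, which mirrors the kernel's block_id to row_id mapping.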
cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.h
@@ -0,0 +1,26 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cstdint>

namespace tensorrt_llm::kernels::cutlass_kernels
{

void interleave_fp4_for_Hopper_mixed_gemm(uint8_t* weight, uint8_t* weight_interleaved, int const rows, int const cols);

}
29 changes: 29 additions & 0 deletions cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp
@@ -16,6 +16,7 @@

#include "tensorrt_llm/common/cudaBf16Wrapper.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h"
#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_mixed_utils.h"
#include "tensorrt_llm/thop/thUtils.h"

#if defined(TORCH_VERSION_MAJOR) \
Expand Down Expand Up @@ -398,6 +399,31 @@ Tensor mxfp4_dequantize_unswizzled(Tensor weight, Tensor scale, int64_t group_si
return dequant_weight;
}

Tensor fp4_interleave_for_Hopper_mixed_gemm(Tensor weight)
{
// weight (n, k / 2)
int const n = weight.size(0);
int const k = weight.size(1) * 2;

CHECK_TH_CUDA(weight);
CHECK_CONTIGUOUS(weight);

TORCH_CHECK(weight.numel() != 0, "weight should not be empty tensor");
TORCH_CHECK(weight.dtype() == torch::kUInt8, "Weight must be a packed uint8 tensor");
TORCH_CHECK(n % 16 == 0);
TORCH_CHECK(k % 16 == 0);

Tensor weight_interleaved
= torch::empty({n, k / 2}, torch::dtype(torch::kUInt8).device(torch::kCUDA).requires_grad(false));

uint8_t* weight_ptr = get_ptr<uint8_t>(weight);
uint8_t* weight_interleaved_ptr = get_ptr<uint8_t>(weight_interleaved);

interleave_fp4_for_Hopper_mixed_gemm(weight_ptr, weight_interleaved_ptr, n, k);

return weight_interleaved;
}

} // namespace torch_ext

// Utility methods that may be useful for preprocessing weights in torch.
Expand Down Expand Up @@ -432,3 +458,6 @@ static auto subbyte_transpose = torch::RegisterOperators("trtllm::_subbyte_trans

static auto mxfp4_dequantize_unswizzled
= torch::RegisterOperators("trtllm::mxfp4_dequantize_unswizzled", &torch_ext::mxfp4_dequantize_unswizzled);

static auto fp4_interleave_for_Hopper_mixed_gemm = torch::RegisterOperators(
"trtllm::fp4_interleave_for_Hopper_mixed_gemm", &torch_ext::fp4_interleave_for_Hopper_mixed_gemm);
11 changes: 11 additions & 0 deletions tensorrt_llm/_torch/modules/fused_moe/quantization.py
@@ -1390,9 +1390,11 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module,
pad_size_inter = module.intermediate_size_per_partition - w3_weight_shard.shape[
0]
if w3_weight_shard.ndim == 2:
# [intermediate_size, hidden_size]
pad_size_hidden = module.hidden_size // 2 - w3_weight_shard.shape[1]
pad_shape = (0, pad_size_hidden, 0, pad_size_inter)
elif w3_weight_shard.ndim == 1:
# [intermediate_size]
pad_shape = (0, pad_size_inter)
else:
raise NotImplementedError(
@@ -1404,6 +1406,10 @@

w31_weight_shard = torch.cat([w3_weight_shard, w1_weight_shard], dim=0)

if w3_weight_shard.ndim == 2:
w31_weight_shard = torch.ops.trtllm.fp4_interleave_for_Hopper_mixed_gemm(
w31_weight_shard)

dst_w3_w1_weight.copy_(w31_weight_shard.view(dst_w3_w1_weight.dtype),
non_blocking=True)
Comment on lines +1409 to 1414 (Contributor):

⚠️ Potential issue | 🔴 Critical

Guard CUDA-only interleave and preserve original device

fp4_interleave_for_Hopper_mixed_gemm requires CUDA tensors (the C++ side checks this), but today dst_w3_w1_weight can live on the CPU. Move the tensor to CUDA for the op, then back, and assert 16-aligned shapes to avoid tripping the C++ TORCH_CHECK.

-        if w3_weight_shard.ndim == 2:
-            w31_weight_shard = torch.ops.trtllm.fp4_interleave_for_Hopper_mixed_gemm(
-                w31_weight_shard)
+        if w3_weight_shard.ndim == 2:
+            n, half_k = w31_weight_shard.shape
+            # C++ op requires n % 16 == 0 and (2*half_k) % 16 == 0
+            assert (n % 16 == 0) and ((2 * half_k) % 16 == 0), "FP4 interleave requires 16-aligned shapes"
+            src_dev = w31_weight_shard.device
+            if src_dev.type != "cuda":
+                w31_weight_shard = w31_weight_shard.cuda(non_blocking=True)
+                w31_weight_shard = torch.ops.trtllm.fp4_interleave_for_Hopper_mixed_gemm(w31_weight_shard)
+                w31_weight_shard = w31_weight_shard.to(src_dev, non_blocking=True)
+            else:
+                w31_weight_shard = torch.ops.trtllm.fp4_interleave_for_Hopper_mixed_gemm(w31_weight_shard)

Please confirm the module’s parameters are on CUDA at load time for this method; otherwise this fix is required to prevent runtime TORCH_CHECK failures from the C++ op.


Guard CUDA-only interleave and preserve original device

The review comment correctly identifies an issue. The fp4_interleave_for_Hopper_mixed_gemm C++ operation requires CUDA tensors, but the code at lines 1409–1411 does not guard this requirement. torch.empty creates the w3_w1_weight parameter without specifying a device, so it can be allocated on the CPU, and a code comment explicitly notes that tensors don't have to be CUDA, e.g., for "online EPLB".

Other similar methods in the same file (e.g., NVFP4TRTLLMGenFusedMoEMethod at line 1888) include assert device.type == "cuda" before calling torch.ops.trtllm operations, which is the pattern to follow here. The proposed fix (moving tensors to CUDA for the operation, then back, while asserting 16-aligned shapes) is appropriate.


@@ -1433,6 +1439,11 @@ def load_expert_w2_weight(self, module: torch.nn.Module,
f"Invalid shape of w2_weight_shard {w2_weight_shard.shape}")

w2_weight_shard = torch.nn.functional.pad(w2_weight_shard, pad_shape)

if w2_weight_shard.ndim == 2:
w2_weight_shard = torch.ops.trtllm.fp4_interleave_for_Hopper_mixed_gemm(
w2_weight_shard)

Comment on lines +1443 to +1446 (Contributor):

⚠️ Potential issue | 🔴 Critical

Same CUDA guard for w2 path

Mirror the device/shape guard to avoid failures on CPU tensors.

-        if w2_weight_shard.ndim == 2:
-            w2_weight_shard = torch.ops.trtllm.fp4_interleave_for_Hopper_mixed_gemm(
-                w2_weight_shard)
+        if w2_weight_shard.ndim == 2:
+            n, half_k = w2_weight_shard.shape
+            assert (n % 16 == 0) and ((2 * half_k) % 16 == 0), "FP4 interleave requires 16-aligned shapes"
+            src_dev = w2_weight_shard.device
+            if src_dev.type != "cuda":
+                w2_weight_shard = w2_weight_shard.cuda(non_blocking=True)
+                w2_weight_shard = torch.ops.trtllm.fp4_interleave_for_Hopper_mixed_gemm(w2_weight_shard)
+                w2_weight_shard = w2_weight_shard.to(src_dev, non_blocking=True)
+            else:
+                w2_weight_shard = torch.ops.trtllm.fp4_interleave_for_Hopper_mixed_gemm(w2_weight_shard)

dst_w2_weight.copy_(w2_weight_shard.view(dst_w2_weight.dtype),
non_blocking=True)
