-
Notifications
You must be signed in to change notification settings - Fork 1.8k
MXFP4 x BF16 CUTLASS MoE backend perf and profiling improvement on Hopper #8721
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -304,27 +304,30 @@ std::vector<CutlassGemmConfig> get_candidate_configs_sm90(CutlassGemmConfig::Can | |||||||||||||||||||||||||||||||||
| if (has_w4afp8) | ||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||
| bool const has_coop_supported = sm90_supports_coop(tile_config); | ||||||||||||||||||||||||||||||||||
| std::set<MainloopScheduleType> mainloop_schedules{MainloopScheduleType::PINGPONG}; | ||||||||||||||||||||||||||||||||||
| if (has_coop_supported) | ||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||
| mainloop_schedules.insert(MainloopScheduleType::COOPERATIVE); | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| // It seems that ping-pong scheduler will never be selected. | ||||||||||||||||||||||||||||||||||
| // To shorten the tactic time, remove all alternative options involving ping-pong scheduler. | ||||||||||||||||||||||||||||||||||
| if (!has_coop_supported) | ||||||||||||||||||||||||||||||||||
| continue; | ||||||||||||||||||||||||||||||||||
| // Due to the limitation on the number of registers on SM, | ||||||||||||||||||||||||||||||||||
| // cooperative scheduler does not support CtaShape128x128x128B. | ||||||||||||||||||||||||||||||||||
| if (tile_config == CutlassTileConfigSM90::CtaShape128x128x128B) | ||||||||||||||||||||||||||||||||||
| continue; | ||||||||||||||||||||||||||||||||||
|
Comment on lines
+310
to
+315
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion | 🟠 Major Add braces around single-statement if bodies. The coding guidelines require that if statements always be followed by brace-delimited statements. Both continue statements lack the required braces. As per coding guidelines. Apply this diff to add braces: - if (!has_coop_supported)
- continue;
+ if (!has_coop_supported)
+ {
+ continue;
+ }
// Due to the limitation on the number of registers on SM,
// cooperative scheduler does not support CtaShape128x128x128B.
- if (tile_config == CutlassTileConfigSM90::CtaShape128x128x128B)
- continue;
+ if (tile_config == CutlassTileConfigSM90::CtaShape128x128x128B)
+ {
+ continue;
+ }📝 Committable suggestion
Suggested change
🤖 Prompt for AI AgentsThere was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How much performance are we leaving on the table here? Is there a way to reduce the number of stages or otherwise relieve register pressure |
||||||||||||||||||||||||||||||||||
| MainloopScheduleType mainloop_schedule = MainloopScheduleType::COOPERATIVE; | ||||||||||||||||||||||||||||||||||
| auto const epilogue_schedule = EpilogueScheduleType::AUTO; | ||||||||||||||||||||||||||||||||||
| for (auto const& mainloop_schedule : mainloop_schedules) | ||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||
| CutlassGemmConfig candidate( | ||||||||||||||||||||||||||||||||||
| tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_1x1x1); | ||||||||||||||||||||||||||||||||||
| candidate_configs.push_back(candidate); | ||||||||||||||||||||||||||||||||||
| candidate = CutlassGemmConfig( | ||||||||||||||||||||||||||||||||||
| tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_2x1x1); | ||||||||||||||||||||||||||||||||||
| candidate_configs.push_back(candidate); | ||||||||||||||||||||||||||||||||||
| candidate = CutlassGemmConfig( | ||||||||||||||||||||||||||||||||||
| tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_1x2x1); | ||||||||||||||||||||||||||||||||||
| candidate_configs.push_back(candidate); | ||||||||||||||||||||||||||||||||||
| candidate = CutlassGemmConfig( | ||||||||||||||||||||||||||||||||||
| tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_2x2x1); | ||||||||||||||||||||||||||||||||||
| candidate_configs.push_back(candidate); | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| CutlassGemmConfig candidate( | ||||||||||||||||||||||||||||||||||
| tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_1x1x1); | ||||||||||||||||||||||||||||||||||
| candidate_configs.push_back(candidate); | ||||||||||||||||||||||||||||||||||
| candidate = CutlassGemmConfig( | ||||||||||||||||||||||||||||||||||
| tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_2x1x1); | ||||||||||||||||||||||||||||||||||
| candidate_configs.push_back(candidate); | ||||||||||||||||||||||||||||||||||
| candidate = CutlassGemmConfig( | ||||||||||||||||||||||||||||||||||
| tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_1x2x1); | ||||||||||||||||||||||||||||||||||
| candidate_configs.push_back(candidate); | ||||||||||||||||||||||||||||||||||
| candidate = CutlassGemmConfig( | ||||||||||||||||||||||||||||||||||
| tile_config, mainloop_schedule, epilogue_schedule, ClusterShape::ClusterShape_2x2x1); | ||||||||||||||||||||||||||||||||||
| candidate_configs.push_back(candidate); | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,55 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| /* | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * you may not use this file except in compliance with the License. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * You may obtain a copy of the License at | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * Unless required by applicable law or agreed to in writing, software | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * See the License for the specific language governing permissions and | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| * limitations under the License. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| */ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+1
to
+15
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update copyright year to include 2025. Header stops at 2023; repository guidelines require current year on source files. As per coding guidelines. - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved.📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include "moe_gemm_mixed_utils.h" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| namespace tensorrt_llm::kernels::cutlass_kernels | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| __global__ void interleave_fp4_for_Hopper_mixed_gemm_kernel( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uint8_t* weight, uint8_t* weight_interleaved, int const rows, int const cols) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (int block_id = blockIdx.x; block_id < rows / 2; block_id += gridDim.x) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (int col_id = threadIdx.x; col_id < cols / 2; col_id += blockDim.x) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int row_id = block_id / 8 * 16 + block_id % 8; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int index_a = row_id * cols / 2 + col_id; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int index_b = (row_id + 8) * cols / 2 + col_id; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+25
to
+33
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Enforce required shape preconditions and avoid repeated divisions. Kernel assumes {
- for (int block_id = blockIdx.x; block_id < rows / 2; block_id += gridDim.x)
+ // Preconditions: rows are multiples of 16; cols are even (#fp4 is even -> bytes).
+ if ((rows & 0xF) != 0 || (cols & 0x1) != 0)
+ {
+ return;
+ }
+
+ int const rowPairs = rows / 2;
+ int const colBytes = cols / 2;
+
+ for (int block_id = blockIdx.x; block_id < rowPairs; block_id += gridDim.x)
{
- for (int col_id = threadIdx.x; col_id < cols / 2; col_id += blockDim.x)
+ for (int col_id = threadIdx.x; col_id < colBytes; col_id += blockDim.x)
{
- int row_id = block_id / 8 * 16 + block_id % 8;
+ int const row_id = (block_id / 8) * 16 + (block_id % 8);🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uint8_t fp4x2_a = weight[index_a]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uint8_t fp4x2_b = weight[index_b]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uint8_t fp4_temp_a = (fp4x2_a & 0xF0U) >> 4; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uint8_t fp4_temp_b = (fp4x2_b & 0x0FU) << 4; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| fp4x2_a = (fp4x2_a & 0x0FU) | fp4_temp_b; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| fp4x2_b = (fp4x2_b & 0xF0U) | fp4_temp_a; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weight_interleaved[index_a] = fp4x2_a; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weight_interleaved[index_b] = fp4x2_b; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| void interleave_fp4_for_Hopper_mixed_gemm(uint8_t* weight, uint8_t* weight_interleaved, int const rows, int const cols) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // column-major input | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| interleave_fp4_for_Hopper_mixed_gemm_kernel<<<1024, 1024>>>(weight, weight_interleaved, rows, cols); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } // namespace tensorrt_llm::kernels::cutlass_kernels | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| /* | ||
| * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <cstdint> | ||
|
|
||
| namespace tensorrt_llm::kernels::cutlass_kernels | ||
| { | ||
|
|
||
| void interleave_fp4_for_Hopper_mixed_gemm(uint8_t* weight, uint8_t* weight_interleaved, int const rows, int const cols); | ||
|
|
||
| } |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -1390,9 +1390,11 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module, | |||||||||||||||||||||||||||
| pad_size_inter = module.intermediate_size_per_partition - w3_weight_shard.shape[ | ||||||||||||||||||||||||||||
| 0] | ||||||||||||||||||||||||||||
| if w3_weight_shard.ndim == 2: | ||||||||||||||||||||||||||||
| # [intermediate_size, hidden_size] | ||||||||||||||||||||||||||||
| pad_size_hidden = module.hidden_size // 2 - w3_weight_shard.shape[1] | ||||||||||||||||||||||||||||
| pad_shape = (0, pad_size_hidden, 0, pad_size_inter) | ||||||||||||||||||||||||||||
| elif w3_weight_shard.ndim == 1: | ||||||||||||||||||||||||||||
| # [intermediate_size] | ||||||||||||||||||||||||||||
| pad_shape = (0, pad_size_inter) | ||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||
| raise NotImplementedError( | ||||||||||||||||||||||||||||
|
|
@@ -1404,6 +1406,10 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module, | |||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| w31_weight_shard = torch.cat([w3_weight_shard, w1_weight_shard], dim=0) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| if w3_weight_shard.ndim == 2: | ||||||||||||||||||||||||||||
| w31_weight_shard = torch.ops.trtllm.fp4_interleave_for_Hopper_mixed_gemm( | ||||||||||||||||||||||||||||
| w31_weight_shard) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| dst_w3_w1_weight.copy_(w31_weight_shard.view(dst_w3_w1_weight.dtype), | ||||||||||||||||||||||||||||
| non_blocking=True) | ||||||||||||||||||||||||||||
|
Comment on lines
+1409
to
1414
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chainGuard CUDA-only interleave and preserve original device fp4_interleave_for_Hopper_mixed_gemm requires CUDA tensors (C++ side checks). Today dst_w3_w1_weight can be CPU. Move to CUDA for the op, then back, and assert 16-aligned shapes to avoid C++ TORCH_CHECK. - if w3_weight_shard.ndim == 2:
- w31_weight_shard = torch.ops.trtllm.fp4_interleave_for_Hopper_mixed_gemm(
- w31_weight_shard)
+ if w3_weight_shard.ndim == 2:
+ n, half_k = w31_weight_shard.shape
+ # C++ op requires n % 16 == 0 and (2*half_k) % 16 == 0
+ assert (n % 16 == 0) and ((2 * half_k) % 16 == 0), "FP4 interleave requires 16-aligned shapes"
+ src_dev = w31_weight_shard.device
+ if src_dev.type != "cuda":
+ w31_weight_shard = w31_weight_shard.cuda(non_blocking=True)
+ w31_weight_shard = torch.ops.trtllm.fp4_interleave_for_Hopper_mixed_gemm(w31_weight_shard)
+ w31_weight_shard = w31_weight_shard.to(src_dev, non_blocking=True)
+ else:
+ w31_weight_shard = torch.ops.trtllm.fp4_interleave_for_Hopper_mixed_gemm(w31_weight_shard)Please confirm the module’s parameters are on CUDA at load time for this method; otherwise this fix is required to prevent runtime TORCH_CHECK failures from the C++ op. I'll verify this review comment by examining the code context, parameter initialization, and device handling. Let me search the codebase for relevant information. Length of output: 2811 🏁 Script executed: #!/bin/bash
# Search for dst_w3_w1_weight initialization and device logic
rg -n -B5 -A5 'dst_w3_w1_weight' tensorrt_llm/_torch/modules/fused_moe/quantization.pyLength of output: 31254 🏁 Script executed: #!/bin/bash
# Find the method definition containing this code
rg -n -B20 'fp4_interleave_for_Hopper_mixed_gemm' tensorrt_llm/_torch/modules/fused_moe/quantization.py | head -50Length of output: 2226 🏁 Script executed: #!/bin/bash
# Search for how this method is called and what device assumptions exist
rg -n 'def.*quantize' tensorrt_llm/_torch/modules/fused_moe/quantization.pyLength of output: 71 🏁 Script executed: #!/bin/bash
# Look for CUDA-related assertions or device checks in similar operations
rg -n 'cuda\(\|device\|\.to\(' tensorrt_llm/_torch/modules/fused_moe/quantization.py | head -30Length of output: 0 Length of output: 636 🏁 Script executed: #!/bin/bash
# Find class definition to see where tensor allocation happens
rg -n 'class.*MoE.*:' tensorrt_llm/_torch/modules/fused_moe/quantization.pyLength of output: 1754 🏁 Script executed: #!/bin/bash
# Check if there's an __init__ or allocation method that sets up these tensors
rg -n -A10 'def.*__init__' tensorrt_llm/_torch/modules/fused_moe/quantization.py | head -40Length of output: 0 🏁 Script executed: #!/bin/bash
# Search for other similar ops to see if they have device handling
rg -n 'torch.ops.trtllm' tensorrt_llm/_torch/modules/fused_moe/quantization.py | head -20Length of output: 1498 🏁 Script executed: #!/bin/bash
# Look for where load_expert_w3_w1_weight is called from
rg -n 'load_expert_w3_w1_weight' tensorrt_llm/_torch/modules/fused_moe/quantization.py | head -10Length of output: 623 🏁 Script executed: #!/bin/bash
# Examine another class's load_expert_w3_w1_weight that has cuda assert (line 1887)
sed -n '1883,1925p' tensorrt_llm/_torch/modules/fused_moe/quantization.pyLength of output: 2085 Length of output: 364 🏁 Script executed: #!/bin/bash
# Look at the load_expert_weights_to_dst method that calls load_expert_w3_w1_weight
sed -n '210,270p' tensorrt_llm/_torch/modules/fused_moe/quantization.pyLength of output: 3091 🏁 Script executed: #!/bin/bash
# Check how torch.ops.trtllm.fp4_interleave_for_Hopper_mixed_gemm is defined/documented
find tensorrt_llm -name "*.py" -o -name "*.cu" -o -name "*.cpp" | xargs grep -l 'fp4_interleave_for_Hopper_mixed_gemm' 2>/dev/null | head -5Length of output: 53 🏁 Script executed: #!/bin/bash
# Compare all torch.ops.trtllm usages and check which ones have cuda assertions before them
rg -n -B5 'torch.ops.trtllm.(shuffle_matrix|fp4_interleave|block_scale_interleave)' tensorrt_llm/_torch/modules/fused_moe/quantization.py | grep -E '(assert|cuda|device|torch.ops.trtllm)' | head -40Length of output: 1583 Length of output: 758 🏁 Script executed: #!/bin/bash
# Search for where module.w3_w1_weight tensor is allocated
rg -n 'w3_w1_weight\s*=' tensorrt_llm/_torch/modules/fused_moe/ | head -20Length of output: 577 🏁 Script executed: #!/bin/bash
# Search for module setup or initialization that might set device
rg -n 'def.*__init__\|w3_w1_weight.*torch\.' tensorrt_llm/_torch/modules/fused_moe/ | head -20Length of output: 0 🏁 Script executed: #!/bin/bash
# Check if there are any comments or checks about device in the FusedMoE module
rg -n 'cpu\|CUDA\|device.*cuda' tensorrt_llm/_torch/modules/fused_moe/ | head -15Length of output: 0 Length of output: 1222 🏁 Script executed: #!/bin/bash
# Search for the function/method that contains line 178
sed -n '153,185p' tensorrt_llm/_torch/modules/fused_moe/quantization.pyLength of output: 1234 🏁 Script executed: #!/bin/bash
# Check what arguments are passed to torch.empty
rg -n -A2 'torch.empty.*w3_w1_weight_shape' tensorrt_llm/_torch/modules/fused_moe/quantization.pyLength of output: 205 Guard CUDA-only interleave and preserve original device The review comment correctly identifies an issue. The fp4_interleave_for_Hopper_mixed_gemm C++ operation requires CUDA tensors, but the code at lines 1409–1411 does not guard this requirement. While torch.empty creates the w3_w1_weight parameter without specifying a device, allowing it to be allocated on CPU, a code comment explicitly notes that tensors don't have to be CUDA, e.g., for "online EPLB". Other similar methods in the same file (e.g., NVFP4TRTLLMGenFusedMoEMethod at line 1888) include |
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
@@ -1433,6 +1439,11 @@ def load_expert_w2_weight(self, module: torch.nn.Module, | |||||||||||||||||||||||||||
| f"Invalid shape of w2_weight_shard {w2_weight_shard.shape}") | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| w2_weight_shard = torch.nn.functional.pad(w2_weight_shard, pad_shape) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| if w2_weight_shard.ndim == 2: | ||||||||||||||||||||||||||||
| w2_weight_shard = torch.ops.trtllm.fp4_interleave_for_Hopper_mixed_gemm( | ||||||||||||||||||||||||||||
| w2_weight_shard) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
Comment on lines
+1443
to
+1446
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same CUDA guard for w2 path Mirror the device/shape guard to avoid failures on CPU tensors. - if w2_weight_shard.ndim == 2:
- w2_weight_shard = torch.ops.trtllm.fp4_interleave_for_Hopper_mixed_gemm(
- w2_weight_shard)
+ if w2_weight_shard.ndim == 2:
+ n, half_k = w2_weight_shard.shape
+ assert (n % 16 == 0) and ((2 * half_k) % 16 == 0), "FP4 interleave requires 16-aligned shapes"
+ src_dev = w2_weight_shard.device
+ if src_dev.type != "cuda":
+ w2_weight_shard = w2_weight_shard.cuda(non_blocking=True)
+ w2_weight_shard = torch.ops.trtllm.fp4_interleave_for_Hopper_mixed_gemm(w2_weight_shard)
+ w2_weight_shard = w2_weight_shard.to(src_dev, non_blocking=True)
+ else:
+ w2_weight_shard = torch.ops.trtllm.fp4_interleave_for_Hopper_mixed_gemm(w2_weight_shard)📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||
| dst_w2_weight.copy_(w2_weight_shard.view(dst_w2_weight.dtype), | ||||||||||||||||||||||||||||
| non_blocking=True) | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What models have you tested this with? I am hesitant to remove this without a comprehensive sweep of multiple model architectures like Mixtral, DeepSeek, Llama4 and GPT-OSS. Its hard to say what the next DeepSeek moment will look like.
I also dont think tactic selection time is actually a significant concern. There are lots of tactics sure, but weight loading is usually just as long. Maybe we should add a fast profile mode that users can opt into