Commit 313cfc3

NVFP4CuteDslFusedMoEMethod

Signed-off-by: Enwei Zhu <[email protected]>
1 parent 98948a4

File tree

3 files changed: +42 additions, -44 deletions

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py
Lines changed: 9 additions & 8 deletions

@@ -10,7 +10,7 @@
 from ...model_config import ModelConfig
 from ...utils import AuxStreamType, Fp4QuantizedTensor, ceil_div
 from .fused_moe_cutlass import CutlassFusedMoE
-from .quantization import MoEWeightLoadingMode
+from .quantization import MoEWeightLoadingMode, NVFP4CuteDslFusedMoEMethod
 from .routing import BaseMoeRoutingMethod


@@ -150,8 +150,7 @@ def cute_dsl_nvfp4_grouped_gemm_ref(


 class CuteDslFusedMoE(CutlassFusedMoE):
-    """
-    Python Flow of Fused Mixture of Experts (MoE) Layer.
+    """CuteDSL flow of fused mixture of experts (MoE) Layer.

     Args:
         num_experts (int): Number of experts in the MoE layer.
@@ -162,11 +161,6 @@ class CuteDslFusedMoE(CutlassFusedMoE):
         dtype (Optional[torch.dtype]): Data type for the weights.
         reduce_results (bool): Whether to reduce the results across devices.
         model_config (ModelConfig): Configuration object for the model.
-
-    This backend is composed of multiple custom ops:
-    1. moe_permute_op: permute the input tensor and the expert selected tensor.
-    2. cute_dsl_fp8_group_blockwise_gemm_ref: a reference implementation of the cute_dsl_fp8_group_blockwise_gemm.
-    3. moe_finalize_scale_op: finalize the scale of the output tensor.
     """

     def __init__(
@@ -201,6 +195,13 @@ def __init__(
             layer_idx=layer_idx,
         )

+    def _get_quant_method(self):
+        if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant(
+                exclude_kv_cache=True):
+            if self.quant_config.layer_quant_mode.has_nvfp4():
+                return NVFP4CuteDslFusedMoEMethod()
+        return super()._get_quant_method()
+
     def forward_chunk_unquantized(
         self,
         x: Union[torch.Tensor, Fp4QuantizedTensor],
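
For orientation, here is a minimal sketch of where the new override is expected to take effect. The construction arguments and the surrounding create/load lifecycle belong to the CutlassFusedMoE base class and are not part of this diff, so they are assumptions here:

# Hedged usage sketch; exact constructor arguments and call sites are
# assumptions based on the CutlassFusedMoE base class, not this diff.
moe = CuteDslFusedMoE(num_experts=..., dtype=..., model_config=...)
quant_method = moe._get_quant_method()
# With an NVFP4 quant config this now returns NVFP4CuteDslFusedMoEMethod,
# so its post_load_weights() applies the CuteDSL-specific FC1 interleaving
# after checkpoint weights are loaded; any other quant mode falls through
# to the Cutlass behavior via super()._get_quant_method().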

tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
Lines changed: 1 addition & 7 deletions

@@ -348,8 +348,7 @@ def set_strides(workspace: torch.Tensor, g: int, m: int, k: int):


 class DeepGemmFusedMoE(CutlassFusedMoE):
-    """
-    Python Flow of Fused Mixture of Experts (MoE) Layer.
+    """DeepGEMM flow of fused mixture of experts (MoE) Layer.

     Args:
         num_experts (int): Number of experts in the MoE layer.
@@ -360,11 +359,6 @@ class DeepGemmFusedMoE(CutlassFusedMoE):
         dtype (Optional[torch.dtype]): Data type for the weights.
         reduce_results (bool): Whether to reduce the results across devices.
         model_config (ModelConfig): Configuration object for the model.
-
-    This backend is composed of multiple custom ops:
-    1. moe_permute_op: permute the input tensor and the expert selected tensor.
-    2. cute_dsl_fp8_group_blockwise_gemm_ref: a reference implementation of the cute_dsl_fp8_group_blockwise_gemm.
-    3. moe_finalize_scale_op: finalize the scale of the output tensor.
     """

     # To reuse pytorch memory segments allocated during graph capture.

tensorrt_llm/_torch/modules/fused_moe/quantization.py
Lines changed: 32 additions & 29 deletions

@@ -1828,35 +1828,6 @@ def setup_quant_scales(self, module: torch.nn.Module):
             fc2_global=module.fc2_alpha,
         )

-    def post_load_weights(self, module: torch.nn.Module):
-        super().post_load_weights(module)
-        if module.moe_backend == "CUTEDSL":
-            # Interleave FC1 weight and scales for GEMM1 + SwiGLU fusion.
-            w3_w1_weight = module.w3_w1_weight.data.view(float4_e2m1x2)
-            m = w3_w1_weight.size(1)
-            n = w3_w1_weight.size(2) * 2
-            w3_w1_weight_interleaved = interleave_linear_and_gate(w3_w1_weight,
-                                                                  group_size=64,
-                                                                  dim=1)
-            w3_w1_weight_interleaved = w3_w1_weight_interleaved.view(
-                module.w3_w1_weight.data.dtype)
-            module.w3_w1_weight.data.copy_(w3_w1_weight_interleaved)
-
-            w3_w1_weight_scale = module.quant_scales.fc1_weight_block.data.view(
-                float4_sf_dtype)
-            w3_w1_weight_scale_unswizzled = unswizzle_sf(
-                w3_w1_weight_scale, m, n).view(-1, m,
-                                               n // module.scaling_vector_size)
-            w3_w1_weight_scale_unswizzled_interleaved = interleave_linear_and_gate(
-                w3_w1_weight_scale_unswizzled, group_size=64, dim=1)
-            w3_w1_weight_scale_interleaved = swizzle_sf(
-                w3_w1_weight_scale_unswizzled_interleaved, m,
-                n).view(-1, m, n // module.scaling_vector_size)
-            w3_w1_weight_scale_interleaved = w3_w1_weight_scale_interleaved.view(
-                module.quant_scales.fc1_weight_block.data.dtype)
-            module.quant_scales.fc1_weight_block.data.copy_(
-                w3_w1_weight_scale_interleaved)
-

 class NVFP4CutlassFusedMoEMethod(NVFP4FusedMoEMethod):
     weight_dtype = FUSED_MOE_NVFP4_WEIGHT_DTYPE
@@ -1935,6 +1906,38 @@ def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
         dst_w2_weight_scale.copy_(dst_w2_weight_scale_interleaved)


+class NVFP4CuteDslFusedMoEMethod(NVFP4CutlassFusedMoEMethod):
+
+    def post_load_weights(self, module: torch.nn.Module):
+        super().post_load_weights(module)
+
+        # Interleave FC1 weight and scales for GEMM1 + SwiGLU fusion.
+        w3_w1_weight = module.w3_w1_weight.data.view(float4_e2m1x2)
+        m = w3_w1_weight.size(1)
+        n = w3_w1_weight.size(2) * 2
+        w3_w1_weight_interleaved = interleave_linear_and_gate(w3_w1_weight,
+                                                              group_size=64,
+                                                              dim=1)
+        w3_w1_weight_interleaved = w3_w1_weight_interleaved.view(
+            module.w3_w1_weight.data.dtype)
+        module.w3_w1_weight.data.copy_(w3_w1_weight_interleaved)
+
+        w3_w1_weight_scale = module.quant_scales.fc1_weight_block.data.view(
+            float4_sf_dtype)
+        w3_w1_weight_scale_unswizzled = unswizzle_sf(
+            w3_w1_weight_scale, m, n).view(-1, m,
+                                           n // module.scaling_vector_size)
+        w3_w1_weight_scale_unswizzled_interleaved = interleave_linear_and_gate(
+            w3_w1_weight_scale_unswizzled, group_size=64, dim=1)
+        w3_w1_weight_scale_interleaved = swizzle_sf(
+            w3_w1_weight_scale_unswizzled_interleaved, m,
+            n).view(-1, m, n // module.scaling_vector_size)
+        w3_w1_weight_scale_interleaved = w3_w1_weight_scale_interleaved.view(
+            module.quant_scales.fc1_weight_block.data.dtype)
+        module.quant_scales.fc1_weight_block.data.copy_(
+            w3_w1_weight_scale_interleaved)
+
+
 class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod):
     weight_dtype = float4_sf_dtype
     block_scales_dtype = torch.float8_e4m3fn
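
The helper interleave_linear_and_gate is not defined in this diff. The sketch below is an illustrative reference only, assuming the helper splits the concatenated [w3; w1] tensor into two halves along `dim` and interleaves them in `group_size` chunks so a fused GEMM1 + SwiGLU kernel can read matching gate/linear groups contiguously; the chunk ordering and the real helper's layout are assumptions:

import torch

def interleave_linear_and_gate_ref(w: torch.Tensor, group_size: int,
                                   dim: int) -> torch.Tensor:
    """Illustrative reference only; the real helper lives in TensorRT-LLM.

    Assumed behavior: split `w` into two halves along `dim` (the stacked
    w3/w1 projections) and interleave them in group_size-sized chunks:
    [a0, a1, ..., b0, b1, ...] -> [a0, b0, a1, b1, ...].
    """
    half = w.size(dim) // 2
    first, second = w.narrow(dim, 0, half), w.narrow(dim, half, half)
    # View each half as [..., num_groups, group_size, ...] along `dim`.
    first = first.unflatten(dim, (half // group_size, group_size))
    second = second.unflatten(dim, (half // group_size, group_size))
    # Pair chunk g of each half, then flatten back to the original extent.
    paired = torch.stack((first, second), dim=dim + 1)
    return paired.flatten(dim, dim + 2)

Note the round trip on the scales in the diff above: the block scales are stored swizzled, so fc1_weight_block is first unswizzled with unswizzle_sf, interleaved in the row-major view, and re-swizzled with swizzle_sf so the scale layout keeps matching the interleaved weights.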
