Commit a6d20f6

Authored by nvchenghaoz, suyoggupta, and Fridah-nv
[None][feat] AutoDeploy: Add FP8 MOE for Nemotron (#8599)
Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: Fridah-nv <[email protected]>
Signed-off-by: nvchenghaoz <[email protected]>
Co-authored-by: Suyog Gupta <[email protected]>
Co-authored-by: Fridah-nv <[email protected]>
1 parent 95be56e commit a6d20f6

File tree

7 files changed: +902 / -179 lines changed


tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 3 additions & 0 deletions
@@ -116,6 +116,9 @@ transforms:
   fuse_moe:
     stage: post_load_fusion
     enabled: true
+  fuse_fp8_moe:
+    stage: post_load_fusion
+    enabled: true
   fuse_allreduce_residual_rmsnorm:
     stage: post_load_fusion
     # TODO (lucaslie): add backend selection as part of configurable inference optimizers

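For orientation, the new fuse_fp8_moe entry mirrors the existing fuse_moe transform in the same post_load_fusion stage. A minimal Python mirror of the merged entries is sketched below; it is purely illustrative and not an API call (the shipped configuration is the YAML above):

# Illustrative mirror of the transforms entries above; not an actual API call.
post_load_fusion_transforms = {
    "fuse_moe": {"stage": "post_load_fusion", "enabled": True},
    # New in this commit: FP8 MoE fusion transform.
    "fuse_fp8_moe": {"stage": "post_load_fusion", "enabled": True},
}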
tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py

Lines changed: 179 additions & 64 deletions
@@ -228,6 +228,8 @@ def torch_quant_fp8_moe(
     w1_weight_scale: List[torch.Tensor],
     w2_weight_scale: List[torch.Tensor],
     w3_weight_scale: List[torch.Tensor],
+    mlp_style: str = "gated_mlp",  # "gated_mlp" (default) or "mlp"
+    act_fn: str = "silu",  # silu or relu2
 ) -> torch.Tensor:
     """
     FP8 MoE op using quantized linear operations.
@@ -239,40 +241,91 @@ def torch_quant_fp8_moe(
         x: Input tensor of shape (B, H) or (B, S, H).
         selected_experts: Tensor (B, TOP_K) or (B*S, TOP_K) containing expert indices.
         routing_weights: Tensor of normalized routing weights.
-        w1_weight, w2_weight, w3_weight: Lists of pre-quantized weight tensors for the three linear ops.
+        w1_weight:
+            List of per-expert weight tensors:
+            • mlp_style=="gated_mlp": W1 with shape (I, H) — "gate" projection.
+            • mlp_style=="mlp": W_up with shape (I, H) — up projection.
+        w2_weight:
+            List of per-expert weight tensors:
+            • gated_mlp: W2 with shape (H, I) — down projection.
+            • mlp: W_down with shape (H, I) — down projection.
+        w3_weight:
+            List of per-expert weight tensors:
+            • gated_mlp: W3 with shape (I, H) — "up" (second) projection in gated MLP.
+            • mlp: pass an empty list []; ignored.
         w1_input_scale, w2_input_scale, w3_input_scale: Lists of input scale tensors for the corresponding ops.
         w1_weight_scale, w2_weight_scale, w3_weight_scale: Lists of weight scale tensors for the corresponding ops.
-
+        mlp_style:
+            Selects the per-expert MLP computation:
+            • "gated_mlp" (default, Mixtral/DeepSeek-style):
+                y = W2( act(W1 x) * (W3 x) )
+            • "mlp" (NemotronH-style 2-layer MLP):
+                y = W_down( act(W_up x) )
+        act_fn:
+            Elementwise activation applied inside the expert MLP.
+            Supported: "silu" (default), "relu2" (ReLU then square).
     """
 
-    def make_fp8_mlp(i):
-        def mlp(inp):
-            gate_out = torch.ops.auto_deploy.torch_quant_fp8_linear(
-                inp,
-                w1_weight[i],
-                bias=None,
-                input_scale=w1_input_scale[i],
-                weight_scale=w1_weight_scale[i],
-            )
-            up_out = torch.ops.auto_deploy.torch_quant_fp8_linear(
-                inp,
-                w3_weight[i],
-                bias=None,
-                input_scale=w3_input_scale[i],
-                weight_scale=w3_weight_scale[i],
-            )
-            prod = F.silu(gate_out) * up_out
-            return torch.ops.auto_deploy.torch_quant_fp8_linear(
-                prod,
-                w2_weight[i],
-                bias=None,
-                input_scale=w2_input_scale[i],
-                weight_scale=w2_weight_scale[i],
-            )
-
-        return mlp
-
-    mlps = [make_fp8_mlp(i) for i in range(len(w1_weight))]
+    act_fn = _resolve_activation(act_fn)
+    style = mlp_style.lower()
+
+    if style == "gated_mlp":
+
+        def make_fp8_mlp(i):
+            def mlp(inp):
+                gate_out = torch.ops.auto_deploy.torch_quant_fp8_linear(
+                    inp,
+                    w1_weight[i],
+                    bias=None,
+                    input_scale=w1_input_scale[i],
+                    weight_scale=w1_weight_scale[i],
+                )
+                up_out = torch.ops.auto_deploy.torch_quant_fp8_linear(
+                    inp,
+                    w3_weight[i],
+                    bias=None,
+                    input_scale=w3_input_scale[i],
+                    weight_scale=w3_weight_scale[i],
+                )
+                prod = act_fn(gate_out) * up_out
+                return torch.ops.auto_deploy.torch_quant_fp8_linear(
+                    prod,
+                    w2_weight[i],
+                    bias=None,
+                    input_scale=w2_input_scale[i],
+                    weight_scale=w2_weight_scale[i],
+                )
+
+            return mlp
+
+        mlps = [make_fp8_mlp(i) for i in range(len(w1_weight))]
+
+    elif style == "mlp":
+
+        def make_fp8_mlp(i):
+            def mlp(inp):
+                up_out = torch.ops.auto_deploy.torch_quant_fp8_linear(
+                    inp,
+                    w1_weight[i],
+                    bias=None,
+                    input_scale=w1_input_scale[i],
+                    weight_scale=w1_weight_scale[i],
+                )
+                return torch.ops.auto_deploy.torch_quant_fp8_linear(
+                    act_fn(up_out),
+                    w2_weight[i],
+                    bias=None,
+                    input_scale=w2_input_scale[i],
+                    weight_scale=w2_weight_scale[i],
+                )
+
+            return mlp
+
+        mlps = [make_fp8_mlp(i) for i in range(len(w1_weight))]
+
+    else:
+        raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")
+
     return _template_moe(x, selected_experts, routing_weights, mlps)
 
 

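Both branches above resolve the activation through _resolve_activation(act_fn), a helper that is referenced but not shown in this diff. Below is a minimal sketch of what such a helper plausibly looks like, based only on the docstring's supported values ("silu", and "relu2" meaning ReLU then square); the actual helper in torch_moe.py may differ:

import torch
import torch.nn.functional as F

# Hypothetical sketch of _resolve_activation; the real helper is defined
# elsewhere in torch_moe.py and is not part of this diff.
def _resolve_activation(act_fn: str):
    act = act_fn.lower()
    if act == "silu":
        return F.silu
    if act == "relu2":
        # "relu2" per the docstring: ReLU followed by squaring.
        return lambda x: torch.square(F.relu(x))
    raise ValueError(f"Unsupported act_fn '{act_fn}'. Use 'silu' or 'relu2'.")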
@@ -290,6 +343,8 @@ def torch_quant_fp8_moe_fake(
     w1_weight_scale: List[torch.Tensor],
     w2_weight_scale: List[torch.Tensor],
     w3_weight_scale: List[torch.Tensor],
+    mlp_style: str = "gated_mlp",
+    act_fn: str = "silu",
 ) -> torch.Tensor:
     return torch.empty_like(x)
 
@@ -311,6 +366,8 @@ def torch_quant_nvfp4_moe(
     w1_alpha: List[torch.Tensor],
     w2_alpha: List[torch.Tensor],
     w3_alpha: List[torch.Tensor],
+    mlp_style: str = "gated_mlp",  # "gated_mlp" (default) or "mlp"
+    act_fn: str = "silu",  # silu or relu2
 ) -> torch.Tensor:
     """
     FP4 MoE op using quantized linear operations.
@@ -322,45 +379,101 @@ def torch_quant_nvfp4_moe(
         x: Input tensor of shape (B, H) or (B, S, H).
         selected_experts: Tensor (B, TOP_K) or (B*S, TOP_K) containing expert indices.
         routing_weights: Tensor of normalized routing weights.
-        w1_weight, w2_weight, w3_weight: Lists of pre-quantized weight tensors for the three linear ops.
+        w1_weight:
+            List of per-expert weight tensors:
+            • mlp_style=="gated_mlp": W1 with shape (I, H) — "gate" projection.
+            • mlp_style=="mlp": W_up with shape (I, H) — up projection.
+        w2_weight:
+            List of per-expert weight tensors:
+            • gated_mlp: W2 with shape (H, I) — down projection.
+            • mlp: W_down with shape (H, I) — down projection.
+        w3_weight:
+            List of per-expert weight tensors:
+            • gated_mlp: W3 with shape (I, H) — "up" (second) projection in gated MLP.
+            • mlp: pass an empty list []; ignored.
         w1_input_scale, w2_input_scale, w3_input_scale: Lists of input scale tensors.
         w1_weight_scale, w2_weight_scale, w3_weight_scale: Lists of weight scale tensors.
         w1_alpha, w2_alpha, w3_alpha: Lists of alpha scale tensors for FP4 quantization.
+        mlp_style:
+            Selects the per-expert MLP computation:
+            • "gated_mlp" (default, Mixtral/DeepSeek-style):
+                y = W2( act(W1 x) * (W3 x) )
+            • "mlp" (NemotronH-style 2-layer MLP):
+                y = W_down( act(W_up x) )
+        act_fn:
+            Elementwise activation applied inside the expert MLP.
+            Supported: "silu" (default), "relu2" (ReLU then square).
     """
 
-    def make_fp4_mlp(i):
-        def mlp(inp):
-            if inp.shape[0] == 0:
-                return torch.zeros_like(inp)
-            gate_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
-                inp,
-                w1_weight[i],
-                bias=None,
-                input_scale=w1_input_scale[i],
-                weight_scale=w1_weight_scale[i],
-                alpha=w1_alpha[i],
-            )
-            up_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
-                inp,
-                w3_weight[i],
-                bias=None,
-                input_scale=w3_input_scale[i],
-                weight_scale=w3_weight_scale[i],
-                alpha=w3_alpha[i],
-            )
-            prod = F.silu(gate_out) * up_out
-            return torch.ops.auto_deploy.torch_quant_nvfp4_linear(
-                prod,
-                w2_weight[i],
-                bias=None,
-                input_scale=w2_input_scale[i],
-                weight_scale=w2_weight_scale[i],
-                alpha=w2_alpha[i],
-            )
-
-        return mlp
-
-    mlps = [make_fp4_mlp(i) for i in range(len(w1_weight))]
+    act_fn = _resolve_activation(act_fn)
+    style = mlp_style.lower()
+
+    if style == "gated_mlp":
+
+        def make_fp4_mlp(i):
+            def mlp(inp):
+                if inp.shape[0] == 0:
+                    return torch.zeros_like(inp)
+                gate_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
+                    inp,
+                    w1_weight[i],
+                    bias=None,
+                    input_scale=w1_input_scale[i],
+                    weight_scale=w1_weight_scale[i],
+                    alpha=w1_alpha[i],
+                )
+                up_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
+                    inp,
+                    w3_weight[i],
+                    bias=None,
+                    input_scale=w3_input_scale[i],
+                    weight_scale=w3_weight_scale[i],
+                    alpha=w3_alpha[i],
+                )
+                prod = act_fn(gate_out) * up_out
+                return torch.ops.auto_deploy.torch_quant_nvfp4_linear(
+                    prod,
+                    w2_weight[i],
+                    bias=None,
+                    input_scale=w2_input_scale[i],
+                    weight_scale=w2_weight_scale[i],
+                    alpha=w2_alpha[i],
+                )
+
+            return mlp
+
+        mlps = [make_fp4_mlp(i) for i in range(len(w1_weight))]
+
+    elif style == "mlp":
+
+        def make_fp4_mlp(i):
+            def mlp(inp):
+                if inp.shape[0] == 0:
+                    return torch.zeros_like(inp)
+                up_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
+                    inp,
+                    w1_weight[i],
+                    bias=None,
+                    input_scale=w1_input_scale[i],
+                    weight_scale=w1_weight_scale[i],
+                    alpha=w1_alpha[i],
+                )
+                return torch.ops.auto_deploy.torch_quant_nvfp4_linear(
+                    act_fn(up_out),
+                    w2_weight[i],
+                    bias=None,
+                    input_scale=w2_input_scale[i],
+                    weight_scale=w2_weight_scale[i],
+                    alpha=w2_alpha[i],
+                )
+
+            return mlp
+
+        mlps = [make_fp4_mlp(i) for i in range(len(w1_weight))]
+
+    else:
+        raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")
+
     return _template_moe(x, selected_experts, routing_weights, mlps)
 
 

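For reference, the two mlp_style variants described in the docstrings above reduce to the following unquantized, single-expert computation; the function name and signature here are illustrative and do not appear in the file:

import torch
import torch.nn.functional as F

def expert_mlp_reference(x, w1, w2, w3=None, mlp_style="gated_mlp", act=F.silu):
    # Unquantized single-expert reference mirroring the docstring math.
    # Shapes: x (B, H), w1/w3 (I, H), w2 (H, I).
    if mlp_style == "gated_mlp":
        # y = W2( act(W1 x) * (W3 x) )
        return F.linear(act(F.linear(x, w1)) * F.linear(x, w3), w2)
    if mlp_style == "mlp":
        # y = W_down( act(W_up x) )
        return F.linear(act(F.linear(x, w1)), w2)
    raise ValueError(f"Unknown mlp_style '{mlp_style}'.")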
@@ -381,6 +494,8 @@ def torch_quant_nvfp4_moe_fake(
     w1_alpha: List[torch.Tensor],
     w2_alpha: List[torch.Tensor],
     w3_alpha: List[torch.Tensor],
+    mlp_style: str = "gated_mlp",
+    act_fn: str = "silu",
 ) -> torch.Tensor:
     return torch.empty_like(x)
 

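Putting it together, here is a hedged smoke-test sketch of calling the extended FP8 MoE op in the NemotronH-style configuration. It assumes the auto_deploy custom ops have been registered (for example by importing tensorrt_llm._torch.auto_deploy.custom_ops), that per-tensor scales are scalar float tensors, and that empty lists are accepted for the unused w3 arguments, as the docstring states for w3_weight; shapes, dtypes, and values are illustrative only:

import torch

# Assumption: importing tensorrt_llm._torch.auto_deploy.custom_ops registers the ops.
num_experts, top_k, hidden, inter, tokens = 4, 2, 64, 128, 8

x = torch.randn(tokens, hidden, dtype=torch.bfloat16, device="cuda")
probs = torch.softmax(torch.randn(tokens, num_experts, device="cuda"), dim=-1)
routing_weights, selected_experts = torch.topk(probs, top_k)
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

# Per-expert FP8 weights and per-tensor scales (random, smoke test only).
w_up = [torch.randn(inter, hidden, device="cuda").to(torch.float8_e4m3fn) for _ in range(num_experts)]
w_down = [torch.randn(hidden, inter, device="cuda").to(torch.float8_e4m3fn) for _ in range(num_experts)]
scale = [torch.tensor(1.0, device="cuda") for _ in range(num_experts)]

out = torch.ops.auto_deploy.torch_quant_fp8_moe(
    x, selected_experts, routing_weights,
    w_up, w_down, [],          # w1, w2, w3 (w3 unused when mlp_style="mlp")
    scale, scale, [],          # w1/w2/w3 input scales
    scale, scale, [],          # w1/w2/w3 weight scales
    mlp_style="mlp",           # NemotronH-style 2-layer expert MLP
    act_fn="relu2",            # ReLU-then-square activation
)
print(out.shape)  # expected: (tokens, hidden)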