@@ -228,6 +228,8 @@ def torch_quant_fp8_moe(
     w1_weight_scale: List[torch.Tensor],
     w2_weight_scale: List[torch.Tensor],
     w3_weight_scale: List[torch.Tensor],
+    mlp_style: str = "gated_mlp",  # "gated_mlp" (default) or "mlp"
+    act_fn: str = "silu",  # "silu" or "relu2"
 ) -> torch.Tensor:
     """
     FP8 MoE op using quantized linear operations.
@@ -239,40 +241,91 @@ def torch_quant_fp8_moe(
         x: Input tensor of shape (B, H) or (B, S, H).
         selected_experts: Tensor (B, TOP_K) or (B*S, TOP_K) containing expert indices.
         routing_weights: Tensor of normalized routing weights.
-        w1_weight, w2_weight, w3_weight: Lists of pre-quantized weight tensors for the three linear ops.
+        w1_weight:
+            List of per-expert weight tensors:
+            • mlp_style=="gated_mlp": W1 with shape (I, H), the "gate" projection.
+            • mlp_style=="mlp": W_up with shape (I, H), the up projection.
+        w2_weight:
+            List of per-expert weight tensors:
+            • gated_mlp: W2 with shape (H, I), the down projection.
+            • mlp: W_down with shape (H, I), the down projection.
+        w3_weight:
+            List of per-expert weight tensors:
+            • gated_mlp: W3 with shape (I, H), the "up" (second) projection in the gated MLP.
+            • mlp: pass an empty list []; ignored.
         w1_input_scale, w2_input_scale, w3_input_scale: Lists of input scale tensors for the corresponding ops.
         w1_weight_scale, w2_weight_scale, w3_weight_scale: Lists of weight scale tensors for the corresponding ops.
-
+        mlp_style:
+            Selects the per-expert MLP computation:
+            • "gated_mlp" (default, Mixtral/DeepSeek-style):
+                  y = W2( act(W1 x) * (W3 x) )
+            • "mlp" (NemotronH-style 2-layer MLP):
+                  y = W_down( act(W_up x) )
+        act_fn:
+            Elementwise activation applied inside the expert MLP.
+            Supported: "silu" (default), "relu2" (ReLU then square).
     """
 
-    def make_fp8_mlp(i):
-        def mlp(inp):
-            gate_out = torch.ops.auto_deploy.torch_quant_fp8_linear(
-                inp,
-                w1_weight[i],
-                bias=None,
-                input_scale=w1_input_scale[i],
-                weight_scale=w1_weight_scale[i],
-            )
-            up_out = torch.ops.auto_deploy.torch_quant_fp8_linear(
-                inp,
-                w3_weight[i],
-                bias=None,
-                input_scale=w3_input_scale[i],
-                weight_scale=w3_weight_scale[i],
-            )
-            prod = F.silu(gate_out) * up_out
-            return torch.ops.auto_deploy.torch_quant_fp8_linear(
-                prod,
-                w2_weight[i],
-                bias=None,
-                input_scale=w2_input_scale[i],
-                weight_scale=w2_weight_scale[i],
-            )
-
-        return mlp
-
-    mlps = [make_fp8_mlp(i) for i in range(len(w1_weight))]
+    act_fn = _resolve_activation(act_fn)
+    style = mlp_style.lower()
+
+    if style == "gated_mlp":
+
+        def make_fp8_mlp(i):
+            def mlp(inp):
+                gate_out = torch.ops.auto_deploy.torch_quant_fp8_linear(
+                    inp,
+                    w1_weight[i],
+                    bias=None,
+                    input_scale=w1_input_scale[i],
+                    weight_scale=w1_weight_scale[i],
+                )
+                up_out = torch.ops.auto_deploy.torch_quant_fp8_linear(
+                    inp,
+                    w3_weight[i],
+                    bias=None,
+                    input_scale=w3_input_scale[i],
+                    weight_scale=w3_weight_scale[i],
+                )
+                prod = act_fn(gate_out) * up_out
+                return torch.ops.auto_deploy.torch_quant_fp8_linear(
+                    prod,
+                    w2_weight[i],
+                    bias=None,
+                    input_scale=w2_input_scale[i],
+                    weight_scale=w2_weight_scale[i],
+                )
+
+            return mlp
+
+        mlps = [make_fp8_mlp(i) for i in range(len(w1_weight))]
+
+    elif style == "mlp":
+
+        def make_fp8_mlp(i):
+            def mlp(inp):
+                up_out = torch.ops.auto_deploy.torch_quant_fp8_linear(
+                    inp,
+                    w1_weight[i],
+                    bias=None,
+                    input_scale=w1_input_scale[i],
+                    weight_scale=w1_weight_scale[i],
+                )
+                return torch.ops.auto_deploy.torch_quant_fp8_linear(
+                    act_fn(up_out),
+                    w2_weight[i],
+                    bias=None,
+                    input_scale=w2_input_scale[i],
+                    weight_scale=w2_weight_scale[i],
+                )
+
+            return mlp
+
+        mlps = [make_fp8_mlp(i) for i in range(len(w1_weight))]
+
+    else:
+        raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")
+
     return _template_moe(x, selected_experts, routing_weights, mlps)
 
 
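For readers skimming the diff, a minimal usage sketch of the new "mlp" style follows. It assumes the op is exposed as torch.ops.auto_deploy.torch_quant_fp8_moe (consistent with the fp8 linear calls above) and that the unused w3 list and its scale lists may be passed empty; shapes, dtypes, and the unit scales are illustrative, not taken from this PR.

import torch

E, TOP_K, B, H, I = 4, 2, 8, 64, 128  # experts, top-k, tokens, hidden, intermediate

x = torch.randn(B, H, dtype=torch.bfloat16)
selected_experts = torch.randint(0, E, (B, TOP_K))
routing_weights = torch.softmax(torch.randn(B, TOP_K), dim=-1)

# Per-expert FP8 weights with unit scales (illustrative values only).
w_up = [torch.randn(I, H).to(torch.float8_e4m3fn) for _ in range(E)]
w_down = [torch.randn(H, I).to(torch.float8_e4m3fn) for _ in range(E)]
ones = [torch.tensor(1.0) for _ in range(E)]

out = torch.ops.auto_deploy.torch_quant_fp8_moe(
    x, selected_experts, routing_weights,
    w_up, w_down, [],    # w1 (up), w2 (down); w3 unused for mlp_style="mlp"
    ones, ones, [],      # input scales
    ones, ones, [],      # weight scales
    mlp_style="mlp",
    act_fn="relu2",      # ReLU then square, as documented above
)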
@@ -290,6 +343,8 @@ def torch_quant_fp8_moe_fake(
     w1_weight_scale: List[torch.Tensor],
     w2_weight_scale: List[torch.Tensor],
     w3_weight_scale: List[torch.Tensor],
+    mlp_style: str = "gated_mlp",
+    act_fn: str = "silu",
 ) -> torch.Tensor:
     return torch.empty_like(x)
 
@@ -311,6 +366,8 @@ def torch_quant_nvfp4_moe(
     w1_alpha: List[torch.Tensor],
     w2_alpha: List[torch.Tensor],
     w3_alpha: List[torch.Tensor],
+    mlp_style: str = "gated_mlp",  # "gated_mlp" (default) or "mlp"
+    act_fn: str = "silu",  # "silu" or "relu2"
 ) -> torch.Tensor:
     """
     FP4 MoE op using quantized linear operations.
@@ -322,45 +379,101 @@ def torch_quant_nvfp4_moe(
         x: Input tensor of shape (B, H) or (B, S, H).
         selected_experts: Tensor (B, TOP_K) or (B*S, TOP_K) containing expert indices.
         routing_weights: Tensor of normalized routing weights.
-        w1_weight, w2_weight, w3_weight: Lists of pre-quantized weight tensors for the three linear ops.
+        w1_weight:
+            List of per-expert weight tensors:
+            • mlp_style=="gated_mlp": W1 with shape (I, H), the "gate" projection.
+            • mlp_style=="mlp": W_up with shape (I, H), the up projection.
+        w2_weight:
+            List of per-expert weight tensors:
+            • gated_mlp: W2 with shape (H, I), the down projection.
+            • mlp: W_down with shape (H, I), the down projection.
+        w3_weight:
+            List of per-expert weight tensors:
+            • gated_mlp: W3 with shape (I, H), the "up" (second) projection in the gated MLP.
+            • mlp: pass an empty list []; ignored.
         w1_input_scale, w2_input_scale, w3_input_scale: Lists of input scale tensors.
         w1_weight_scale, w2_weight_scale, w3_weight_scale: Lists of weight scale tensors.
         w1_alpha, w2_alpha, w3_alpha: Lists of alpha scale tensors for FP4 quantization.
+        mlp_style:
+            Selects the per-expert MLP computation:
+            • "gated_mlp" (default, Mixtral/DeepSeek-style):
+                  y = W2( act(W1 x) * (W3 x) )
+            • "mlp" (NemotronH-style 2-layer MLP):
+                  y = W_down( act(W_up x) )
+        act_fn:
+            Elementwise activation applied inside the expert MLP.
+            Supported: "silu" (default), "relu2" (ReLU then square).
     """
 
-    def make_fp4_mlp(i):
-        def mlp(inp):
-            if inp.shape[0] == 0:
-                return torch.zeros_like(inp)
-            gate_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
-                inp,
-                w1_weight[i],
-                bias=None,
-                input_scale=w1_input_scale[i],
-                weight_scale=w1_weight_scale[i],
-                alpha=w1_alpha[i],
-            )
-            up_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
-                inp,
-                w3_weight[i],
-                bias=None,
-                input_scale=w3_input_scale[i],
-                weight_scale=w3_weight_scale[i],
-                alpha=w3_alpha[i],
-            )
-            prod = F.silu(gate_out) * up_out
-            return torch.ops.auto_deploy.torch_quant_nvfp4_linear(
-                prod,
-                w2_weight[i],
-                bias=None,
-                input_scale=w2_input_scale[i],
-                weight_scale=w2_weight_scale[i],
-                alpha=w2_alpha[i],
-            )
-
-        return mlp
-
-    mlps = [make_fp4_mlp(i) for i in range(len(w1_weight))]
+    act_fn = _resolve_activation(act_fn)
+    style = mlp_style.lower()
+
+    if style == "gated_mlp":
+
+        def make_fp4_mlp(i):
+            def mlp(inp):
+                if inp.shape[0] == 0:
+                    return torch.zeros_like(inp)
+                gate_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
+                    inp,
+                    w1_weight[i],
+                    bias=None,
+                    input_scale=w1_input_scale[i],
+                    weight_scale=w1_weight_scale[i],
+                    alpha=w1_alpha[i],
+                )
+                up_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
+                    inp,
+                    w3_weight[i],
+                    bias=None,
+                    input_scale=w3_input_scale[i],
+                    weight_scale=w3_weight_scale[i],
+                    alpha=w3_alpha[i],
+                )
+                prod = act_fn(gate_out) * up_out
+                return torch.ops.auto_deploy.torch_quant_nvfp4_linear(
+                    prod,
+                    w2_weight[i],
+                    bias=None,
+                    input_scale=w2_input_scale[i],
+                    weight_scale=w2_weight_scale[i],
+                    alpha=w2_alpha[i],
+                )
+
+            return mlp
+
+        mlps = [make_fp4_mlp(i) for i in range(len(w1_weight))]
+
+    elif style == "mlp":
+
+        def make_fp4_mlp(i):
+            def mlp(inp):
+                if inp.shape[0] == 0:
+                    return torch.zeros_like(inp)
+                up_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
+                    inp,
+                    w1_weight[i],
+                    bias=None,
+                    input_scale=w1_input_scale[i],
+                    weight_scale=w1_weight_scale[i],
+                    alpha=w1_alpha[i],
+                )
+                return torch.ops.auto_deploy.torch_quant_nvfp4_linear(
+                    act_fn(up_out),
+                    w2_weight[i],
+                    bias=None,
+                    input_scale=w2_input_scale[i],
+                    weight_scale=w2_weight_scale[i],
+                    alpha=w2_alpha[i],
+                )
+
+            return mlp
+
+        mlps = [make_fp4_mlp(i) for i in range(len(w1_weight))]
+
+    else:
+        raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.")
+
     return _template_moe(x, selected_experts, routing_weights, mlps)
 
 
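Both ops above call _resolve_activation, which this diff does not show (it is defined elsewhere in the file). A minimal sketch of what such a helper could look like, matching the documented "silu" / "relu2" semantics, is below; it is an assumption for illustration, not the PR's actual implementation.

import torch
import torch.nn.functional as F

def _resolve_activation(name: str):
    """Map an activation name to an elementwise callable (illustrative sketch)."""
    name = name.lower()
    if name == "silu":
        return F.silu
    if name == "relu2":
        # "relu2": ReLU followed by squaring, per the docstrings above.
        return lambda x: torch.square(F.relu(x))
    raise ValueError(f"Unknown activation '{name}'. Use 'silu' or 'relu2'.")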
@@ -381,6 +494,8 @@ def torch_quant_nvfp4_moe_fake(
     w1_alpha: List[torch.Tensor],
     w2_alpha: List[torch.Tensor],
     w3_alpha: List[torch.Tensor],
+    mlp_style: str = "gated_mlp",
+    act_fn: str = "silu",
 ) -> torch.Tensor:
     return torch.empty_like(x)
 