@@ -478,6 +478,8 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames
 
     if input_quantizer is not None and hasattr(input_quantizer, "_pre_quant_scale"):
         return QUANTIZATION_NVFP4_AWQ
+    if getattr(layer, "fused_with_prequant", False):
+        return QUANTIZATION_NVFP4_AWQ
     assert input_quantizer is not None, (
         f"input_quantizer is None for {quantizer_attr_names}"
    )
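
The added branch handles layers whose pre_quant_scale was already folded into an upstream projection by pattern_fuse_prequant below: such layers no longer carry the scale on their input quantizer and are instead (presumably, at fusion time) tagged with a fused_with_prequant attribute. A minimal sketch of the resulting detection order, using toy stand-ins rather than the modelopt API:

import torch.nn as nn

QUANTIZATION_NVFP4_AWQ = "nvfp4_awq"  # stand-in constant, not the library's

def detect_awq(layer: nn.Module) -> str | None:
    input_quantizer = getattr(layer, "input_quantizer", None)
    # Case 1: the pre-quant scale is still attached to the input quantizer.
    if input_quantizer is not None and hasattr(input_quantizer, "_pre_quant_scale"):
        return QUANTIZATION_NVFP4_AWQ
    # Case 2 (new): the scale was fused away; only the marker flag remains.
    if getattr(layer, "fused_with_prequant", False):
        return QUANTIZATION_NVFP4_AWQ
    return None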
@@ -937,7 +939,7 @@ def all_items_same(item_list):
 
 
 # TODO: make this more general instead of rule based
-def pattern_fuse_prequant(model: torch.nn.Module):
+def pattern_fuse_prequant(model: torch.nn.Module, fuse_mismatch_dim=False):
     """Fuse pre_quant_scale to the linear weights.
 
     For example, we can fuse the pre_quant_scale of o_proj to the output_dimension of v_proj, such that
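
The identity behind this fusion can be checked outside the library: multiplying the output channels (rows) of v_proj's weight by o_proj's pre_quant_scale is equivalent to scaling v_proj's output before o_proj consumes it, so the scale can be dropped from o_proj's input quantizer. A self-contained numerical sketch with plain tensors (no quantizers involved):

import torch

torch.manual_seed(0)
hidden = 8
x = torch.randn(2, hidden)
W_v = torch.randn(hidden, hidden)  # v_proj weight (out_features, in_features)
W_o = torch.randn(hidden, hidden)  # o_proj weight
s = torch.rand(hidden) + 0.5       # o_proj pre_quant_scale, one per input channel

# Unfused: o_proj's input quantizer multiplies its input by s.
y_ref = ((x @ W_v.T) * s) @ W_o.T
# Fused: fold s into v_proj's rows (output channels) and delete the scale.
W_v_fused = W_v * s.unsqueeze(1)
y_fused = (x @ W_v_fused.T) @ W_o.T
assert torch.allclose(y_ref, y_fused, atol=1e-5)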
@@ -951,10 +953,28 @@ def pattern_fuse_prequant(model: torch.nn.Module):
     the pre_quant_scale is averaged across the repeated head groups and then the
     o_proj's pre_quant_scale is updated to maintain mathematical equivalence.
 
+    Args:
+        model: The model to fuse pre_quant_scale into.
+        fuse_mismatch_dim: If True, fuse the pre_quant_scale even when its length does not
+            match the output dimension of the linear weight it is fused into. This is useful
+            for GQA/MQA models but may lead to an accuracy drop.
+
     Note:
         This is an experimental feature, and it might mess up the quantization errors
         of fused linear modules.
     """
+    # For MoE models, first resmooth the experts' up_proj and gate_proj (w1/w3) to get an averaged pre_quant_scale
+    for _, module in model.named_modules():
+        if (
+            hasattr(module, "experts")
+            and "Qwen3MoeSparseMoeBlock".lower() in type(module).__name__.lower()
+        ):
+            linear_list = []
+            linear_list.extend([getattr(expert, "up_proj") for expert in module.experts])
+            linear_list.extend([getattr(expert, "gate_proj") for expert in module.experts])
+            preprocess_linear_fusion(linear_list, resmooth_only=True)
+
+    # Fuse pre_quant_scale to the linear weights
     for _, module in model.named_modules():
         for module_map in PQS_FUSE_MODULE_MAPPING:
             target_module_list = module_map[0]
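
For the resmoothing step, preprocess_linear_fusion(..., resmooth_only=True) has to leave every expert with the same pre_quant_scale while keeping each expert's output unchanged. A conceptual sketch of that averaging-plus-compensation (a hypothetical resmooth_to_average, not the real modelopt implementation), mirroring the weight * old_scale / new_scale update used later in this diff:

import torch
import torch.nn as nn

def resmooth_to_average(linears: list[nn.Linear], scales: list[torch.Tensor]) -> torch.Tensor:
    """Give all linears one shared (averaged) input scale; compensate their weights."""
    avg = torch.stack(scales).mean(dim=0)
    for linear, old in zip(linears, scales):
        # Effective math is W @ (s * x); keeping it fixed under s_old -> s_avg
        # requires W_new = W * (s_old / s_avg), broadcast over input channels.
        linear.weight = nn.Parameter(linear.weight * (old / avg))
    return avg

# Toy usage: three "experts" with per-expert input scales.
experts = [nn.Linear(4, 6) for _ in range(3)]
scales = [torch.rand(4) + 0.5 for _ in range(3)]
shared_scale = resmooth_to_average(experts, scales)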
@@ -967,52 +987,58 @@ def pattern_fuse_prequant(model: torch.nn.Module):
             ):
                 pre_quant_scale = linear_pqs_from.input_quantizer._pre_quant_scale
 
-                # for GQA/MQA models, we apply averaging to the pre_quant_scale
-                if pre_quant_scale.numel() != linear_fuse_into.weight.shape[0]:
-                    if "attention" not in type(module).__name__.lower():
-                        continue
-                    else:
-                        config = module.config
-                        num_kv_heads = config.num_key_value_heads
-                        kv_head_dim = linear_fuse_into.weight.shape[0] // num_kv_heads
-                        n_rep = pre_quant_scale.numel() // num_kv_heads // kv_head_dim
-
-                        # Reshape: (num_kv_heads, n_rep, kv_head_dim)
-                        averaged_scale = pre_quant_scale.view(
-                            num_kv_heads, n_rep, kv_head_dim
-                        ).mean(dim=1)
-
-                        # To update o_proj, we need to repeat back to original shape
-                        repeated_scale = (
-                            averaged_scale.unsqueeze(1)
-                            .expand(num_kv_heads, n_rep, kv_head_dim)
-                            .reshape(-1)
+                # for GQA/MQA models, we apply averaging to the pre_quant_scale for shared head groups
+                if pre_quant_scale.numel() != linear_fuse_into.weight.shape[-2]:
+                    if (
+                        not fuse_mismatch_dim
+                        or "attention" not in type(module).__name__.lower()
+                    ):
+                        warn(
+                            f"Skipping pattern fuse prequant for {type(module).__name__}: "
+                            f"pqs dim {pre_quant_scale.numel()} != out_ch dim {linear_fuse_into.weight.shape[-2]}"
                         )
+                        continue
+                    config = module.config
+                    num_kv_heads = config.num_key_value_heads
+                    kv_head_dim = linear_fuse_into.weight.shape[0] // num_kv_heads
+                    n_rep = pre_quant_scale.numel() // num_kv_heads // kv_head_dim
+
+                    # Reshape: (num_kv_heads, n_rep, kv_head_dim)
+                    averaged_scale = pre_quant_scale.view(
+                        num_kv_heads, n_rep, kv_head_dim
+                    ).mean(dim=1)
+
+                    # To update o_proj, we need to repeat back to the original shape
+                    repeated_scale = (
+                        averaged_scale.unsqueeze(1)
+                        .expand(num_kv_heads, n_rep, kv_head_dim)
+                        .reshape(-1)
+                    )
 
-                    def _update_pre_quant_scale(module, new_pre_quant_scale):
-                        old_pre_quant_scale = module.input_quantizer._pre_quant_scale
-                        module.weight = nn.Parameter(
-                            module.weight
-                            * old_pre_quant_scale.to(
-                                dtype=module.weight.dtype, device=module.weight.device
-                            )
-                            / new_pre_quant_scale.to(
-                                dtype=module.weight.dtype, device=module.weight.device
-                            )
+                    def _update_pre_quant_scale(module, new_pre_quant_scale):
+                        old_pre_quant_scale = module.input_quantizer._pre_quant_scale
+                        module.weight = nn.Parameter(
+                            module.weight
+                            * old_pre_quant_scale.to(
+                                dtype=module.weight.dtype, device=module.weight.device
+                            )
+                            / new_pre_quant_scale.to(
+                                dtype=module.weight.dtype, device=module.weight.device
                             )
-                        module.input_quantizer.pre_quant_scale = new_pre_quant_scale
+                        )
+                        module.input_quantizer.pre_quant_scale = new_pre_quant_scale
 
-                        # Redo weights collection
-                        module.weight_quantizer.reset_amax()
-                        enable_stats_collection(module.weight_quantizer)
-                        module.weight_quantizer(module.weight)
-                        finish_stats_collection(module.weight_quantizer)
+                        # Redo weights collection
+                        module.weight_quantizer.reset_amax()
+                        enable_stats_collection(module.weight_quantizer)
+                        module.weight_quantizer(module.weight)
+                        finish_stats_collection(module.weight_quantizer)
 
-                    # Update o_proj's pre_quant_scale
-                    _update_pre_quant_scale(linear_pqs_from, repeated_scale)
+                    # Update o_proj's pre_quant_scale
+                    _update_pre_quant_scale(linear_pqs_from, repeated_scale)
 
-                    # Use averaged scale (flattened) for v_proj fusion
-                    pre_quant_scale = averaged_scale.reshape(-1)
+                    # Use averaged scale (flattened) for v_proj fusion
+                    pre_quant_scale = averaged_scale.reshape(-1)
 
                 # Fuse the pre_quant_scale to v_proj weight
                 linear_fuse_into.weight = torch.nn.Parameter(
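
In the mismatch branch, the view/mean/expand round trip is the whole trick: o_proj's pre_quant_scale (one entry per query-head input channel) is averaged within each group of n_rep query heads that share a KV head, and the tiled copy keeps o_proj mathematically consistent once its weights are compensated by old_scale / new_scale. Any accuracy impact then comes from quantizing with the averaged, no longer per-channel-optimal scale, not from the float math. A small numeric illustration with assumed toy shapes:

import torch

num_kv_heads, n_rep, kv_head_dim = 2, 2, 3
# o_proj's pre_quant_scale: one value per query-head input channel (12 here).
pre_quant_scale = torch.arange(12, dtype=torch.float32) + 1.0

averaged_scale = pre_quant_scale.view(num_kv_heads, n_rep, kv_head_dim).mean(dim=1)
repeated_scale = (
    averaged_scale.unsqueeze(1)
    .expand(num_kv_heads, n_rep, kv_head_dim)
    .reshape(-1)
)

print(averaged_scale.reshape(-1))  # shape (6,)  -> fused into v_proj's output channels
print(repeated_scale)              # shape (12,) -> o_proj's new pre_quant_scale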