
Commit a5a6e39

resmooth
Signed-off-by: weimingc <[email protected]>
1 parent 6dd1b87 commit a5a6e39

2 files changed: 193 additions, 42 deletions

modelopt/torch/export/quant_utils.py

Lines changed: 68 additions & 41 deletions
@@ -478,6 +478,8 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames
 
     if input_quantizer is not None and hasattr(input_quantizer, "_pre_quant_scale"):
         return QUANTIZATION_NVFP4_AWQ
+    if getattr(layer, "fused_with_prequant", False):
+        return QUANTIZATION_NVFP4_AWQ
     assert input_quantizer is not None, (
         f"input_quantizer is None for {quantizer_attr_names}"
     )
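For context, a minimal sketch of why the new branch is needed (the _FakeLayer / _FakeQuantizer stand-ins below are hypothetical, not classes from the library): after pattern_fuse_prequant folds a layer's pre_quant_scale into an upstream weight, the input quantizer no longer carries _pre_quant_scale, so detection falls back to the fused_with_prequant flag set during fusion.

# Hypothetical stand-ins to illustrate the detection order; not the library's classes.
class _FakeQuantizer:
    pass


class _FakeLayer:
    def __init__(self, fused: bool):
        self.input_quantizer = _FakeQuantizer()
        if fused:
            # pattern_fuse_prequant removes _pre_quant_scale and sets this flag instead
            self.fused_with_prequant = True
        else:
            self.input_quantizer._pre_quant_scale = 1.0


def _looks_like_awq(layer) -> bool:
    # Mirrors the updated branch: a scale on the input quantizer OR the fused flag.
    if hasattr(layer.input_quantizer, "_pre_quant_scale"):
        return True
    return getattr(layer, "fused_with_prequant", False)


assert _looks_like_awq(_FakeLayer(fused=False))
assert _looks_like_awq(_FakeLayer(fused=True))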
@@ -937,7 +939,7 @@ def all_items_same(item_list):
 
 
 # TODO: make this more general instead of rule based
-def pattern_fuse_prequant(model: torch.nn.Module):
+def pattern_fuse_prequant(model: torch.nn.Module, fuse_mismatch_dim=False):
     """Fuse pre_quant_scale to the linear weights.
 
     For example, we can fuse the pre_quant_scale of o_proj to the output_dimension of v_proj, such that
@@ -951,10 +953,29 @@ def pattern_fuse_prequant(model: torch.nn.Module):
     the pre_quant_scale is averaged across the repeated head groups and then the
     o_proj's pre_quant_scale is updated to maintain mathematical equivalence.
 
+    Args:
+        model: The model to fuse pre_quant_scale into.
+        fuse_mismatch_dim: If True, fuse the pre_quant_scale even if the dimensions of the
+            pre_quant_scale and the linear weight do not match. This is useful for GQA/MQA
+            models but may lead to an accuracy drop.
+
     Note:
         This is an experimental feature, and it might mess up the quantization errors
         of fused linear modules.
     """
+    # For MoE models, first resmooth the w1 and w3 in the experts to get the averaged pre_quant_scale
+    for _, module in model.named_modules():
+        if (
+            hasattr(module, "experts")
+            and "Qwen3MoeSparseMoeBlock".lower() in type(module).__name__.lower()
+        ):
+            linear_list = []
+            linear_list.extend([getattr(expert, "up_proj") for expert in module.experts])
+            linear_list.extend([getattr(expert, "gate_proj") for expert in module.experts])
+            preprocess_linear_fusion(linear_list, resmooth_only=True)
+
+    # import pdb; pdb.set_trace()
+    # Fuse pre_quant_scale to the linear weights
     for _, module in model.named_modules():
         for module_map in PQS_FUSE_MODULE_MAPPING:
             target_module_list = module_map[0]
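The resmoothing pass added above gives every expert's gate_proj and up_proj a single shared pre_quant_scale before any fusion happens. Below is a minimal numerical sketch of the idea with toy tensors; it re-implements the concept and is not the library's preprocess_linear_fusion: pick one common scale (here the mean of the calibrated scales) and fold old_scale / new_scale into each weight, so every linear produces the same output on the commonly smoothed input.

import torch

torch.manual_seed(0)
x = torch.randn(8)                                # shared input to the expert linears
weights = [torch.randn(16, 8) for _ in range(2)]  # toy "gate_proj" / "up_proj" weights
scales = [torch.rand(8) + 0.5 for _ in range(2)]  # per-linear AWQ pre_quant_scales

reference = [w @ (s * x) for w, s in zip(weights, scales)]

# Resmooth: one averaged pre_quant_scale for all, ratio folded into the weights.
shared_scale = torch.stack(scales).mean(dim=0)
resmoothed = [w * (s / shared_scale) for w, s in zip(weights, scales)]

for w, ref in zip(resmoothed, reference):
    assert torch.allclose(w @ (shared_scale * x), ref, atol=1e-5)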
@@ -967,52 +988,58 @@ def pattern_fuse_prequant(model: torch.nn.Module):
                 ):
                     pre_quant_scale = linear_pqs_from.input_quantizer._pre_quant_scale
 
-                    # for GQA/MQA models, we apply averaging to the pre_quant_scale
-                    if pre_quant_scale.numel() != linear_fuse_into.weight.shape[0]:
-                        if "attention" not in type(module).__name__.lower():
-                            continue
-                        else:
-                            config = module.config
-                            num_kv_heads = config.num_key_value_heads
-                            kv_head_dim = linear_fuse_into.weight.shape[0] // num_kv_heads
-                            n_rep = pre_quant_scale.numel() // num_kv_heads // kv_head_dim
-
-                            # Reshape: (num_kv_heads, n_rep, kv_head_dim)
-                            averaged_scale = pre_quant_scale.view(
-                                num_kv_heads, n_rep, kv_head_dim
-                            ).mean(dim=1)
-
-                            # To update o_proj, we need to repeat back to the original shape
-                            repeated_scale = (
-                                averaged_scale.unsqueeze(1)
-                                .expand(num_kv_heads, n_rep, kv_head_dim)
-                                .reshape(-1)
+                    # for GQA/MQA models, we apply averaging to the pre_quant_scale for shared head groups
+                    if pre_quant_scale.numel() != linear_fuse_into.weight.shape[-2]:
+                        if (
+                            not fuse_mismatch_dim
+                            or "attention" not in type(module).__name__.lower()
+                        ):
+                            warn(
+                                f"Skipping pattern fuse prequant for {type(module).__name__}: "
+                                f"pqs dim {pre_quant_scale.numel()} != out_ch dim {linear_fuse_into.weight.shape[-2]}"
                             )
+                            continue
+                        config = module.config
+                        num_kv_heads = config.num_key_value_heads
+                        kv_head_dim = linear_fuse_into.weight.shape[0] // num_kv_heads
+                        n_rep = pre_quant_scale.numel() // num_kv_heads // kv_head_dim
+
+                        # Reshape: (num_kv_heads, n_rep, kv_head_dim)
+                        averaged_scale = pre_quant_scale.view(
+                            num_kv_heads, n_rep, kv_head_dim
+                        ).mean(dim=1)
+
+                        # To update o_proj, we need to repeat back to the original shape
+                        repeated_scale = (
+                            averaged_scale.unsqueeze(1)
+                            .expand(num_kv_heads, n_rep, kv_head_dim)
+                            .reshape(-1)
+                        )
 
-                            def _update_pre_quant_scale(module, new_pre_quant_scale):
-                                old_pre_quant_scale = module.input_quantizer._pre_quant_scale
-                                module.weight = nn.Parameter(
-                                    module.weight
-                                    * old_pre_quant_scale.to(
-                                        dtype=module.weight.dtype, device=module.weight.device
-                                    )
-                                    / new_pre_quant_scale.to(
-                                        dtype=module.weight.dtype, device=module.weight.device
-                                    )
+                        def _update_pre_quant_scale(module, new_pre_quant_scale):
+                            old_pre_quant_scale = module.input_quantizer._pre_quant_scale
+                            module.weight = nn.Parameter(
+                                module.weight
+                                * old_pre_quant_scale.to(
+                                    dtype=module.weight.dtype, device=module.weight.device
+                                )
+                                / new_pre_quant_scale.to(
+                                    dtype=module.weight.dtype, device=module.weight.device
                                 )
-                                module.input_quantizer.pre_quant_scale = new_pre_quant_scale
+                            )
+                            module.input_quantizer.pre_quant_scale = new_pre_quant_scale
 
-                                # Redo weights collection
-                                module.weight_quantizer.reset_amax()
-                                enable_stats_collection(module.weight_quantizer)
-                                module.weight_quantizer(module.weight)
-                                finish_stats_collection(module.weight_quantizer)
+                            # Redo weights collection
+                            module.weight_quantizer.reset_amax()
+                            enable_stats_collection(module.weight_quantizer)
+                            module.weight_quantizer(module.weight)
+                            finish_stats_collection(module.weight_quantizer)
 
-                            # Update o_proj's pre_quant_scale
-                            _update_pre_quant_scale(linear_pqs_from, repeated_scale)
+                        # Update o_proj's pre_quant_scale
+                        _update_pre_quant_scale(linear_pqs_from, repeated_scale)
 
-                            # Use averaged scale (flattened) for v_proj fusion
-                            pre_quant_scale = averaged_scale.reshape(-1)
+                        # Use averaged scale (flattened) for v_proj fusion
+                        pre_quant_scale = averaged_scale.reshape(-1)
 
                     # Fuse the pre_quant_scale to v_proj weight
                     linear_fuse_into.weight = torch.nn.Parameter(
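The GQA/MQA branch rewritten above rests on a small identity, sketched below with toy sizes (hypothetical tensors, not code from the repo): averaging o_proj's pre_quant_scale within each group of repeated kv heads, re-expanding it for o_proj, and folding old_scale / new_scale into o_proj's weight leaves o_proj's output unchanged, while the compact averaged scale is what then gets fused into v_proj's output channels.

import torch

torch.manual_seed(0)
num_kv_heads, n_rep, kv_head_dim = 2, 3, 4
hidden = num_kv_heads * n_rep * kv_head_dim        # o_proj input dim (24)

o_weight = torch.randn(8, hidden)                  # toy o_proj weight: hidden -> 8
scale = torch.rand(hidden) + 0.5                   # original per-channel pre_quant_scale
x = torch.randn(hidden)                            # attention output feeding o_proj

# Average within each (kv_head, channel) group, then repeat back for o_proj.
averaged_scale = scale.view(num_kv_heads, n_rep, kv_head_dim).mean(dim=1)
repeated_scale = (
    averaged_scale.unsqueeze(1).expand(num_kv_heads, n_rep, kv_head_dim).reshape(-1)
)

# _update_pre_quant_scale folds old/new into the weight, so o_proj stays equivalent.
o_weight_updated = o_weight * (scale / repeated_scale)

before = o_weight @ (scale * x)
after = o_weight_updated @ (repeated_scale * x)
assert torch.allclose(before, after, atol=1e-5)

# averaged_scale.reshape(-1) matches v_proj's output channels and is what gets
# multiplied into v_proj's weight rows.
assert averaged_scale.reshape(-1).numel() == num_kv_heads * kv_head_dim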

tests/gpu/torch/export/test_quant_utils.py

Lines changed: 125 additions & 1 deletion
@@ -74,7 +74,7 @@ def test_pattern_fuse_prequant(quant_config, attention_kv_heads_pair):
     ]
 
     # Apply fusion
-    pattern_fuse_prequant(model)
+    pattern_fuse_prequant(model, fuse_mismatch_dim=True)
 
     # Check if pre_quant_scale and fused_with_prequant flag are removed correctly
     for target_module_name in traget_module_name_list:
@@ -97,3 +97,127 @@ def test_pattern_fuse_prequant(quant_config, attention_kv_heads_pair):
     assert torch.allclose(
         output_before_fuse.logits, output_after_fuse.logits, rtol=1e-1, atol=5e-1
     ), "Output should be the same before and after fusion"
+
+
+# TODO: add test for Qwen3MoeSparseMoeBlock MLP fusion
+
+
+@pytest.mark.parametrize(
+    "quant_config",
+    [
+        mtq.INT4_AWQ_CFG,
+        mtq.NVFP4_AWQ_LITE_CFG,
+    ],
+)
+def test_pattern_fuse_prequant_moe(quant_config):
+    """Test pattern_fuse_prequant on the Qwen3 MoE sparse MLP."""
+    pytest.importorskip("transformers", minversion="4.46.0")
+    from transformers import Qwen3MoeConfig, Qwen3MoeForCausalLM
+
+    # Create a tiny Qwen3 MoE model for testing
+    config = Qwen3MoeConfig(
+        hidden_size=128,
+        intermediate_size=256,
+        moe_intermediate_size=256,
+        num_hidden_layers=2,
+        num_attention_heads=4,
+        num_key_value_heads=4,
+        num_experts=4,
+        num_experts_per_tok=2,
+        max_position_embeddings=128,
+        vocab_size=256,
+        shared_expert_intermediate_size=256,
+    )
+    model = Qwen3MoeForCausalLM(config).to("cuda")
+
+    # Quantize the model
+    dummy_input = torch.randint(0, 256, (1, 16), device="cuda")
+    mtq.quantize(model, quant_config, lambda m: m(dummy_input))
+
+    # Collect MoE expert modules to verify (down_proj should be fused)
+    moe_down_proj_modules = []
+    moe_gate_proj_modules = []
+    moe_up_proj_modules = []
+    for name, module in model.named_modules():
+        if "mlp" in name and "experts" in name:
+            if "gate_proj" in name and not any(x in name for x in ["weight", "quantizer"]):
+                moe_gate_proj_modules.append((name, module))
+            elif "down_proj" in name and not any(x in name for x in ["weight", "quantizer"]):
+                moe_down_proj_modules.append((name, module))
+            elif "up_proj" in name and not any(x in name for x in ["weight", "quantizer"]):
+                moe_up_proj_modules.append((name, module))
+
+    # Verify experts have pre_quant_scale before fusion
+    for name, module in moe_gate_proj_modules:
+        if hasattr(module, "input_quantizer"):
+            assert hasattr(module.input_quantizer, "_pre_quant_scale"), (
+                f"{name}: gate_proj should have pre_quant_scale before fusion"
+            )
+
+    for name, module in moe_up_proj_modules:
+        if hasattr(module, "input_quantizer"):
+            assert hasattr(module.input_quantizer, "_pre_quant_scale"), (
+                f"{name}: up_proj should have pre_quant_scale before fusion"
+            )
+
+    for name, module in moe_down_proj_modules:
+        if hasattr(module, "input_quantizer"):
+            assert hasattr(module.input_quantizer, "_pre_quant_scale"), (
+                f"{name}: down_proj should have pre_quant_scale before fusion"
+            )
+
+    # Run forward pass before fusion
+    model.eval()
+    with torch.no_grad():
+        output_before_fuse = model(dummy_input)
+
+    # Apply fusion (fuse_mismatch_dim is only needed for GQA/MQA attention, not for MLP)
+    pattern_fuse_prequant(model)
+
+    # Check if down_proj's pre_quant_scale was removed and fused into up_proj
+    for name, module in moe_down_proj_modules:
+        if hasattr(module, "input_quantizer"):
+            # Verify pre_quant_scale was removed from down_proj
+            assert not hasattr(module.input_quantizer, "_pre_quant_scale"), (
+                f"{name}: down_proj pre_quant_scale should be removed after fusion"
+            )
+            # Verify the fused_with_prequant flag was set
+            assert hasattr(module, "fused_with_prequant") and module.fused_with_prequant, (
+                f"{name}: down_proj should have fused_with_prequant flag set"
+            )
+
+    # Verify that gate_proj and up_proj still have pre_quant_scale and are resmoothed
+    for name, module in model.named_modules():
+        if "Qwen3MoeSparseMoeBlock".lower() in type(module).__name__.lower():
+            first_gate_scale = getattr(
+                getattr(module, "experts")[0], "gate_proj"
+            ).input_quantizer._pre_quant_scale
+            first_up_scale = getattr(
+                getattr(module, "experts")[0], "up_proj"
+            ).input_quantizer._pre_quant_scale
+
+            # gate_proj and up_proj should have the same scale after resmoothing
+            assert torch.allclose(first_gate_scale, first_up_scale), (
+                "gate_proj and up_proj should have the same pre_quant_scale after resmoothing"
+            )
+
+            # All experts should have the same gate_proj and up_proj scales
+            for i, expert in enumerate(getattr(module, "experts")):
+                gate_scale = getattr(expert, "gate_proj").input_quantizer._pre_quant_scale
+                up_scale = getattr(expert, "up_proj").input_quantizer._pre_quant_scale
+
+                assert torch.allclose(gate_scale, first_gate_scale), (
+                    f"Expert {i} gate_proj scale should match expert 0"
+                )
+                assert torch.allclose(up_scale, first_up_scale), (
+                    f"Expert {i} up_proj scale should match expert 0"
+                )
+
+    # Verify output is close to the original output
+    with torch.no_grad():
+        output_after_fuse = model(dummy_input)
+
+    # There will be some difference due to quantization errors after pre_quant_scale fusion
+    assert torch.allclose(
+        output_before_fuse.logits, output_after_fuse.logits, rtol=1e-1, atol=5e-1
+    ), "Output should be similar before and after Qwen3 MoE fusion"
