From cf3a10781d90f516e2508b0ec6904599740f076f Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Mon, 3 Nov 2025 23:47:32 +0000 Subject: [PATCH] mamba moe export fixes Signed-off-by: Jennifer Chen --- .../torch/export/plugins/mcore_nemotron.py | 28 +++++++++++++ .../torch/export/plugins/megatron_importer.py | 20 +++++++-- .../torch/export/unified_export_megatron.py | 41 +++++++++++++++---- .../torch/quantization/plugins/huggingface.py | 23 +++++++---- .../torch/quantization/plugins/megatron.py | 34 ++++++++++++++- 5 files changed, 127 insertions(+), 19 deletions(-) diff --git a/modelopt/torch/export/plugins/mcore_nemotron.py b/modelopt/torch/export/plugins/mcore_nemotron.py index 752826bbc..5fdb8ba1b 100644 --- a/modelopt/torch/export/plugins/mcore_nemotron.py +++ b/modelopt/torch/export/plugins/mcore_nemotron.py @@ -17,8 +17,10 @@ """Custom mapping from Nemotron Hugging Face models to Megatron Core models.""" from .mcore_custom import ( + COL_ETP, COL_TP, REPLICATE, + ROW_ETP, ROW_TP, CustomModuleMapping, NameRemapping, @@ -63,6 +65,22 @@ "pre_mlp_layernorm": NameRemapping("backbone.layers.{}.norm.", REPLICATE), "linear_fc1": NameRemapping("backbone.layers.{}.mixer.up_proj.", COL_TP), "linear_fc2": NameRemapping("backbone.layers.{}.mixer.down_proj.", ROW_TP), + # MoE + "router": NameRemapping( + "backbone.layers.{}.mixer.gate.", {"mapping": {"expert_bias": "e_score_correction_bias"}} + ), + "local_experts.linear_fc1": NameRemapping( + "backbone.layers.{}.mixer.experts.{}.up_proj.", COL_ETP + ), + "local_experts.linear_fc2": NameRemapping( + "backbone.layers.{}.mixer.experts.{}.down_proj.", ROW_ETP + ), + "shared_experts.linear_fc1": NameRemapping( + "backbone.layers.{}.mixer.shared_experts.up_proj.", COL_TP + ), + "shared_experts.linear_fc2": NameRemapping( + "backbone.layers.{}.mixer.shared_experts.down_proj.", ROW_TP + ), } @@ -87,4 +105,14 @@ "pre_mlp_layernorm": NameRemapping("backbone.layers.{}.norm."), "linear_fc1": NameRemapping("backbone.layers.{}.mixer.up_proj."), "linear_fc2": NameRemapping("backbone.layers.{}.mixer.down_proj."), + # MoE + "router": NameRemapping( + "backbone.layers.{}.mixer.gate.", {"mapping": {"expert_bias": "e_score_correction_bias"}} + ), + "local_experts.linear_fc1": NameRemapping("backbone.layers.{}.mixer.experts.{}.up_proj."), + "local_experts.linear_fc2": NameRemapping("backbone.layers.{}.mixer.experts.{}.down_proj."), + "shared_experts.linear_fc1": NameRemapping("backbone.layers.{}.mixer.shared_experts.up_proj."), + "shared_experts.linear_fc2": NameRemapping( + "backbone.layers.{}.mixer.shared_experts.down_proj." 
+ ), } diff --git a/modelopt/torch/export/plugins/megatron_importer.py b/modelopt/torch/export/plugins/megatron_importer.py index 4c805dc01..0af79eb36 100644 --- a/modelopt/torch/export/plugins/megatron_importer.py +++ b/modelopt/torch/export/plugins/megatron_importer.py @@ -77,11 +77,18 @@ def __init__( dequantize: bool = True, trust_remote_code: bool = True, verbose: bool = False, + moe_router_dtype: torch.dtype | None = None, ): """Create a GPTModel importer instance.""" self._hf_config = transformers.AutoConfig.from_pretrained( pretrained_model_name_or_path, trust_remote_code=trust_remote_code ) + self.moe_router_dtype = None + if moe_router_dtype == "fp32": + self.moe_router_dtype = torch.float32 + elif moe_router_dtype == "fp64": + self.moe_router_dtype = torch.float64 + pretrained_model_path = Path(pretrained_model_name_or_path) if not pretrained_model_path.is_dir(): if workspace_dir is None: @@ -118,7 +125,9 @@ def _custom_mapping_to_lambda(mapping): func = method_map[mapping.func_name] prefix = mapping.target_name_or_prefix func_kwargs = mapping.func_kwargs - return lambda m, *args: func(m, prefix.format(*args), **func_kwargs) + return lambda m, *args, **kwargs: func( + m, prefix.format(*args), **{**func_kwargs, **kwargs} + ) for arch, mappings in all_mcore_hf_import_mapping.items(): all_rules[arch] = { @@ -140,7 +149,10 @@ def _name_remapping( prefix, mapping={}, parallel_config: ParallelConfig | None = None, + dtype: torch.dtype | None = None, ): + if dtype is None: + dtype = self.dtype if isinstance(module, torch.Tensor): tensor = self._get_safetensor(prefix, parallel_config=parallel_config) module.data.copy_(tensor) @@ -193,7 +205,7 @@ def _name_remapping( tensor = self._get_safetensor( prefix + source_key, parallel_config=parallel_config ) - state_dict[key] = tensor.to(dtype=self.dtype).to(device=val.device) + state_dict[key] = tensor.to(dtype=dtype).to(device=val.device) module.load_state_dict(state_dict) @@ -523,7 +535,9 @@ def _import_state_dict(self): if not isinstance(layer.mlp, IdentityOp): if "MoE" in str(type(layer.mlp)): layer_pbar.set_description("Importing MoE") - self.rules["router"](layer.mlp.router, layer_id) + self.rules["router"]( + layer.mlp.router, layer_id, dtype=self.moe_router_dtype + ) if ( hasattr(layer.mlp, "shared_experts") and layer.mlp.shared_experts is not None diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index 586745a1b..70a80aeec 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -109,7 +109,7 @@ def get_kv_cache_scaling_factor(kv_module: nn.Module) -> torch.Tensor: def get_quantized_state( module: torch.nn.Module, - dtype: torch.dtype = torch.float16, + dtype: torch.dtype = torch.bfloat16, ) -> tuple[dict[str, torch.Tensor], str, int]: """Return a state_dict, quantization format, and block_size of the module. 
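Note on the rule plumbing above: the import rules are still built once from the static CustomModuleMapping tables, but the router now needs a per-call dtype override, so the generated lambdas merge call-site kwargs over the mapping's static func_kwargs, with the call-site values winning. A minimal, self-contained sketch of that merge order, using illustrative stand-in names rather than the real importer methods:

# Illustrative sketch of the kwargs-merge pattern used when building the rules;
# make_rule/name_remapping are stand-ins, not the patched GPTModelImporter methods.
def make_rule(func, prefix, func_kwargs):
    # Call-site kwargs are merged last, so they override the mapping's static kwargs.
    return lambda m, *args, **kwargs: func(m, prefix.format(*args), **{**func_kwargs, **kwargs})


def name_remapping(module, prefix, mapping={}, dtype=None):
    return (prefix, mapping, dtype)


rule = make_rule(
    name_remapping,
    "backbone.layers.{}.mixer.gate.",
    {"mapping": {"expert_bias": "e_score_correction_bias"}},
)
print(rule("router", 3))                   # dtype stays None -> importer falls back to self.dtype
print(rule("router", 3, dtype="float32"))  # explicit dtype (e.g. moe_router_dtype) takes precedence

This is why the router rule can now be invoked as rules["router"](layer.mlp.router, layer_id, dtype=self.moe_router_dtype) while every other rule keeps its original call shape.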
@@ -186,6 +186,7 @@ def __init__( export_extra_modules: bool = False, dtype=torch.bfloat16, trust_remote_code: bool = True, + moe_router_dtype: torch.dtype | None = None, ): """Create a GPTModel exporter instance.""" if not isinstance(model, (GPTModel, MambaModel, LLaVAModel)): @@ -196,6 +197,12 @@ def __init__( self._hf_config = transformers.AutoConfig.from_pretrained( pretrained_model_name_or_path, trust_remote_code=trust_remote_code ) + self.moe_router_dtype = None + if moe_router_dtype == "fp32": + self.moe_router_dtype = torch.float32 + elif moe_router_dtype == "fp64": + self.moe_router_dtype = torch.float64 + # If multimodal, extra the text_config self._hf_text_config = getattr(self._hf_config, "text_config", self._hf_config) @@ -489,7 +496,9 @@ def _custom_mapping_to_lambda(mapping): func = method_map[mapping.func_name] prefix = mapping.target_name_or_prefix func_kwargs = mapping.func_kwargs - return lambda m, *args: func(m, prefix.format(*args), **func_kwargs) + return lambda m, *args, **kwargs: func( + m, prefix.format(*args), **{**func_kwargs, **kwargs} + ) for arch, mappings in all_mcore_hf_export_mapping.items(): all_rules[arch] = { @@ -519,12 +528,16 @@ def _name_remapping( prefix: str, skip_output_scale: bool = True, mapping={}, + dtype: torch.dtype | None = None, ): + if dtype is None: + dtype = self.dtype + if isinstance(module, torch.Tensor): self._state_dict[prefix] = module return - name_to_value, qformat, block_size = get_quantized_state(module, self.dtype) + name_to_value, qformat, block_size = get_quantized_state(module, dtype) weight = name_to_value.pop("weight") weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat) @@ -1098,7 +1111,9 @@ def _get_state_dict(self): if not isinstance(layer.mlp, IdentityOp): if "MoE" in str(type(layer.mlp)): - self.rules["router"](layer.mlp.router, layer_id) + self.rules["router"]( + layer.mlp.router, layer_id, dtype=self.moe_router_dtype + ) if ( hasattr(layer.mlp, "shared_experts") and layer.mlp.shared_experts is not None @@ -1136,8 +1151,9 @@ def export_mcore_gpt_to_hf( model: torch.nn.Module, pretrained_model_name_or_path: str | os.PathLike | None = None, export_extra_modules: bool = False, - dtype: torch.dtype = torch.float16, + dtype: torch.dtype = torch.bfloat16, export_dir: Path | str = tempfile.gettempdir(), + moe_router_dtype: torch.dtype | None = None, ): """Export Megatron Core GPTModel to unified checkpoint and save to export_dir. @@ -1153,7 +1169,11 @@ def export_mcore_gpt_to_hf( export_dir: The target export path. """ exporter = GPTModelExporter( - model, pretrained_model_name_or_path, export_extra_modules=export_extra_modules, dtype=dtype + model, + pretrained_model_name_or_path, + export_extra_modules=export_extra_modules, + dtype=dtype, + moe_router_dtype=moe_router_dtype, ) exporter.save_pretrained(export_dir, pretrained_model_name_or_path) @@ -1162,7 +1182,8 @@ def import_mcore_gpt_from_hf( model: torch.nn.Module, pretrained_model_path: str, workspace_dir: str | None = None, - dtype: torch.dtype = torch.float16, + dtype: torch.dtype = torch.bfloat16, + moe_router_dtype: torch.dtype | None = None, ): """Import GPTModel state_dict from supported HuggingFace pretrained model path. @@ -1173,6 +1194,10 @@ def import_mcore_gpt_from_hf( dtype: The weights data type to import. 
""" importer = GPTModelImporter( - model, pretrained_model_path, workspace_dir=workspace_dir, dtype=dtype + model, + pretrained_model_path, + workspace_dir=workspace_dir, + dtype=dtype, + moe_router_dtype=moe_router_dtype, ) importer._import_state_dict() diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 061e71dba..aefd2cc6f 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -345,7 +345,16 @@ def backward(ctx, grad_output): _transposed_quantize = _TransposedQuantization.apply -class _QuantMoeSparseMoe(QuantModule): +class _QuantSparseMoe(QuantModule): + """Module to support special handling of token dispatching during calibration. + + During calibration, we forward all tokens to all experts so that all experts see sufficient tokens to calibrate. + However, even in calibration mode, the actual top_k routing is used to calculate the actual outputs this instance + returns. + + If calibration is not enabled, this module behaves as a normal MoELayer. + """ + def _setup(self): pass @@ -480,7 +489,7 @@ def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor: return self.w2_linear[expert_idx](x1) -class _QuantDbrxFFN(_QuantMoeSparseMoe): +class _QuantDbrxFFN(_QuantSparseMoe): @property def num_experts(self): return self.router.moe_num_experts @@ -498,7 +507,7 @@ def top_k(self, value): from transformers.models.llama4.modeling_llama4 import Llama4TextExperts, Llama4TextMoe if Llama4TextMoe not in QuantModuleRegistry: - QuantModuleRegistry.register({Llama4TextMoe: "hf.Llama4TextMoe"})(_QuantMoeSparseMoe) + QuantModuleRegistry.register({Llama4TextMoe: "hf.Llama4TextMoe"})(_QuantSparseMoe) if Llama4TextExperts not in QuantModuleRegistry: QuantModuleRegistry.register({Llama4TextExperts: "hf.Llama4TextExperts"})( @@ -526,7 +535,7 @@ def top_k(self, value): if MixtralSparseMoeBlock not in QuantModuleRegistry: QuantModuleRegistry.register({MixtralSparseMoeBlock: "hf.MixtralSparseMoeBlock"})( - _QuantMoeSparseMoe + _QuantSparseMoe ) except ImportError: pass @@ -544,7 +553,7 @@ def top_k(self, value): if Qwen3MoeSparseMoeBlock not in QuantModuleRegistry: QuantModuleRegistry.register({Qwen3MoeSparseMoeBlock: "hf.Qwen3MoeSparseMoeBlock"})( - _QuantMoeSparseMoe + _QuantSparseMoe ) except ImportError: pass @@ -554,7 +563,7 @@ def top_k(self, value): if Qwen2MoeSparseMoeBlock not in QuantModuleRegistry: QuantModuleRegistry.register({Qwen2MoeSparseMoeBlock: "hf.Qwen2MoeSparseMoeBlock"})( - _QuantMoeSparseMoe + _QuantSparseMoe ) except ImportError: pass @@ -564,7 +573,7 @@ def top_k(self, value): if Qwen3NextSparseMoeBlock not in QuantModuleRegistry: QuantModuleRegistry.register({Qwen3NextSparseMoeBlock: "hf.Qwen3NextSparseMoeBlock"})( - _QuantMoeSparseMoe + _QuantSparseMoe ) except ImportError: pass diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index fd6b0660d..91cb2f7a5 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -23,6 +23,7 @@ import megatron.core.tensor_parallel.layers as megatron_parallel import megatron.core.transformer.mlp as megatron_mlp import megatron.core.transformer.moe.experts as megatron_moe +import megatron.core.transformer.moe.moe_layer as megatron_moe_layer import torch from megatron.core.parallel_state import get_data_parallel_group from megatron.core.tensor_parallel.mappings import 
gather_from_sequence_parallel_region @@ -36,7 +37,7 @@ ) from modelopt.torch.utils.distributed import ParallelState -from ..nn import QuantModuleRegistry, TensorQuantizer +from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer from ..nn.modules.quant_linear import RealQuantLinear from ..qtensor import QTensorWrapper from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear @@ -247,6 +248,14 @@ def _setup(self): data_parallel_group, mcore_parallel.get_tensor_model_parallel_group(), ) + + if getattr(self, "gradient_accumulation_fusion", False): + warnings.warn( + "gradient_accumulation_fusion is not supported with ModelOpt quantization. " + "Setting gradient_accumulation_fusion to False." + ) + self.gradient_accumulation_fusion = False + super()._setup() def _process_quantizer_amax(self, k, v, quantizer_state_dict): @@ -580,3 +589,26 @@ def _setup(self): # initialize parallel state for submodules linear_fc1 and linear_fc2 self.linear_fc1.parallel_state = self.parallel_state self.linear_fc2.parallel_state = self.parallel_state + + +@QuantModuleRegistry.register({megatron_moe_layer.MoELayer: "megatron_moe_MoELayer"}) +class _QuantMoELayer(QuantModule): + """Module to support special handling of token dispatching during calibration. + + During calibration, we forward all tokens to all experts so that all experts see sufficient tokens to calibrate. + However, even in calibration mode, the actual top_k routing is used to calculate the actual outputs this instance + returns. + + If calibration is not enabled, this module behaves as a normal MoELayer. + """ + + def _setup(self): + pass + + def forward(self, hidden_states): + if any(getattr(m, "_if_calib", False) for m in self.experts.modules()): + original_top_k = self.router.topk + self.router.topk = self.router.num_experts + super().forward(hidden_states) + self.router.topk = original_top_k + return super().forward(hidden_states)
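
Both new MoE wrappers (_QuantSparseMoe for Hugging Face models and _QuantMoELayer for Megatron Core) rely on the same calibration trick: temporarily widen the router's top_k to num_experts so every expert sees calibration tokens, discard that output, then return the result of a normal top_k pass so the layer's numerics are unchanged. A toy sketch of that control flow, with assumed stand-in classes rather than the real Megatron/transformers modules:

# Illustrative sketch (assumed toy names, not the Megatron classes) of the
# calibration-time routing override: one all-experts pass feeds the calibrators,
# the returned output still comes from the real top_k routing.
import torch


class ToyRouter:
    def __init__(self, num_experts: int, topk: int):
        self.num_experts = num_experts
        self.topk = topk


class ToyMoELayer:
    def __init__(self, num_experts: int = 4, topk: int = 2, calibrating: bool = False):
        self.router = ToyRouter(num_experts, topk)
        self.calibrating = calibrating  # stands in for the experts' _if_calib flags

    def _dispatch(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Placeholder for the real expert dispatch; only the routing width matters here.
        return hidden_states * self.router.topk

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.calibrating:
            original_topk = self.router.topk
            self.router.topk = self.router.num_experts
            self._dispatch(hidden_states)      # all experts see tokens; output discarded
            self.router.topk = original_topk
        return self._dispatch(hidden_states)   # actual output uses the real top_k


layer = ToyMoELayer(calibrating=True)
print(layer.forward(torch.ones(2, 3)))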
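
For reference, a hedged usage sketch of the new moe_router_dtype plumb-through. The model argument and paths below are placeholders, and the import path simply mirrors the file location in this patch; note that both GPTModelImporter and GPTModelExporter compare moe_router_dtype against the strings "fp32" and "fp64", so a string is what actually selects torch.float32/torch.float64 here.

# Usage sketch only: `megatron_model` is an already-built Megatron Core model and
# the paths are placeholders; call signatures follow the patch above.
import torch

from modelopt.torch.export.unified_export_megatron import (
    export_mcore_gpt_to_hf,
    import_mcore_gpt_from_hf,
)


def roundtrip_with_fp32_router(megatron_model, hf_name_or_path, export_dir):
    """Import HF weights into an existing Megatron Core model, then re-export."""
    import_mcore_gpt_from_hf(
        megatron_model,
        hf_name_or_path,
        dtype=torch.bfloat16,
        moe_router_dtype="fp32",  # router weight / expert_bias loaded in float32
    )
    export_mcore_gpt_to_hf(
        megatron_model,
        hf_name_or_path,
        dtype=torch.bfloat16,
        export_dir=export_dir,
        moe_router_dtype="fp32",  # keep router tensors in float32 in the exported checkpoint
    )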