28 changes: 28 additions & 0 deletions modelopt/torch/export/plugins/mcore_nemotron.py
@@ -17,8 +17,10 @@
"""Custom mapping from Nemotron Hugging Face models to Megatron Core models."""

from .mcore_custom import (
COL_ETP,
COL_TP,
REPLICATE,
ROW_ETP,
ROW_TP,
CustomModuleMapping,
NameRemapping,
@@ -63,6 +65,22 @@
"pre_mlp_layernorm": NameRemapping("backbone.layers.{}.norm.", REPLICATE),
"linear_fc1": NameRemapping("backbone.layers.{}.mixer.up_proj.", COL_TP),
"linear_fc2": NameRemapping("backbone.layers.{}.mixer.down_proj.", ROW_TP),
# MoE
"router": NameRemapping(
"backbone.layers.{}.mixer.gate.", {"mapping": {"expert_bias": "e_score_correction_bias"}}
),
"local_experts.linear_fc1": NameRemapping(
"backbone.layers.{}.mixer.experts.{}.up_proj.", COL_ETP
),
"local_experts.linear_fc2": NameRemapping(
"backbone.layers.{}.mixer.experts.{}.down_proj.", ROW_ETP
),
"shared_experts.linear_fc1": NameRemapping(
"backbone.layers.{}.mixer.shared_experts.up_proj.", COL_TP
),
"shared_experts.linear_fc2": NameRemapping(
"backbone.layers.{}.mixer.shared_experts.down_proj.", ROW_TP
),
}


@@ -87,4 +105,14 @@
"pre_mlp_layernorm": NameRemapping("backbone.layers.{}.norm."),
"linear_fc1": NameRemapping("backbone.layers.{}.mixer.up_proj."),
"linear_fc2": NameRemapping("backbone.layers.{}.mixer.down_proj."),
# MoE
"router": NameRemapping(
"backbone.layers.{}.mixer.gate.", {"mapping": {"expert_bias": "e_score_correction_bias"}}
),
"local_experts.linear_fc1": NameRemapping("backbone.layers.{}.mixer.experts.{}.up_proj."),
"local_experts.linear_fc2": NameRemapping("backbone.layers.{}.mixer.experts.{}.down_proj."),
"shared_experts.linear_fc1": NameRemapping("backbone.layers.{}.mixer.shared_experts.up_proj."),
"shared_experts.linear_fc2": NameRemapping(
"backbone.layers.{}.mixer.shared_experts.down_proj."
),
}
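
As a rough illustration of how these new MoE entries are consumed (the helper below is hypothetical; the real resolution happens in `mcore_custom.py` and the importer/exporter rule lambdas), the `{}` placeholders in each HF prefix are filled with the layer index and, for `local_experts.*` rules, the expert index, while the optional `mapping` kwarg renames Megatron parameter keys such as `expert_bias` to their HF counterparts:

```python
# Hypothetical sketch of how a NameRemapping prefix template is resolved.
def resolve_prefix(template: str, *indices: int) -> str:
    """Fill the {} placeholders with the layer (and expert) indices."""
    return template.format(*indices)

# Per-expert rule: layer 3, expert 7.
assert (
    resolve_prefix("backbone.layers.{}.mixer.experts.{}.up_proj.", 3, 7)
    == "backbone.layers.3.mixer.experts.7.up_proj."
)

# Router rule: the Megatron parameter `expert_bias` is read from the HF tensor
# "backbone.layers.3.mixer.gate.e_score_correction_bias".
key_mapping = {"expert_bias": "e_score_correction_bias"}
hf_key = resolve_prefix("backbone.layers.{}.mixer.gate.", 3) + key_mapping["expert_bias"]
assert hf_key == "backbone.layers.3.mixer.gate.e_score_correction_bias"
```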
16 changes: 13 additions & 3 deletions modelopt/torch/export/plugins/megatron_importer.py
@@ -77,11 +77,18 @@ def __init__(
dequantize: bool = True,
trust_remote_code: bool = True,
verbose: bool = False,
moe_router_dtype: str | None = None,
):
"""Create a GPTModel importer instance."""
self._hf_config = transformers.AutoConfig.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code
)
self.moe_router_dtype = None
if moe_router_dtype == "fp32":
self.moe_router_dtype = torch.float32
elif moe_router_dtype == "fp64":
self.moe_router_dtype = torch.float64

pretrained_model_path = Path(pretrained_model_name_or_path)
if not pretrained_model_path.is_dir():
if workspace_dir is None:
@@ -118,7 +125,7 @@ def _custom_mapping_to_lambda(mapping):
func = method_map[mapping.func_name]
prefix = mapping.target_name_or_prefix
func_kwargs = mapping.func_kwargs
return lambda m, *args: func(m, prefix.format(*args), **func_kwargs)
return lambda m, *args, **kwargs: func(m, prefix.format(*args), **{**func_kwargs, **kwargs})

for arch, mappings in all_mcore_hf_import_mapping.items():
all_rules[arch] = {
@@ -140,7 +147,10 @@ def _name_remapping(
prefix,
mapping={},
parallel_config: ParallelConfig | None = None,
dtype: torch.dtype | None = None,
):
if dtype is None:
dtype = self.dtype
if isinstance(module, torch.Tensor):
tensor = self._get_safetensor(prefix, parallel_config=parallel_config)
module.data.copy_(tensor)
@@ -193,7 +203,7 @@ def _name_remapping(
tensor = self._get_safetensor(
prefix + source_key, parallel_config=parallel_config
)
state_dict[key] = tensor.to(dtype=self.dtype).to(device=val.device)
state_dict[key] = tensor.to(dtype=dtype).to(device=val.device)

module.load_state_dict(state_dict)

@@ -523,7 +533,7 @@ def _import_state_dict(self):
if not isinstance(layer.mlp, IdentityOp):
if "MoE" in str(type(layer.mlp)):
layer_pbar.set_description("Importing MoE")
self.rules["router"](layer.mlp.router, layer_id)
self.rules["router"](layer.mlp.router, layer_id, dtype=self.moe_router_dtype)
if (
hasattr(layer.mlp, "shared_experts")
and layer.mlp.shared_experts is not None
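
Two small mechanisms carry the new `moe_router_dtype` option through the importer: the string-to-`torch.dtype` translation in `__init__`, and the rule lambdas now merging per-call kwargs over the mapping's baked-in kwargs, so `self.rules["router"](..., dtype=self.moe_router_dtype)` can override the dtype for just that rule. A standalone sketch (all names below are illustrative, not the module's API):

```python
import torch

# Illustrative re-implementation of the two pieces above, outside the class.
def parse_moe_router_dtype(value: str | None) -> torch.dtype | None:
    """Map the user-facing string ("fp32"/"fp64") to a torch dtype, else None."""
    return {"fp32": torch.float32, "fp64": torch.float64}.get(value)

def make_rule(func, prefix_template: str, **baked_kwargs):
    """Build a rule; per-call kwargs (e.g. dtype=...) win over baked-in kwargs."""
    return lambda module, *args, **kwargs: func(
        module, prefix_template.format(*args), **{**baked_kwargs, **kwargs}
    )

def fake_name_remapping(module, prefix, mapping=None, dtype=None):
    print(f"load {prefix}* as {dtype}")

rule = make_rule(fake_name_remapping, "backbone.layers.{}.mixer.gate.")
rule("router-module", 0, dtype=parse_moe_router_dtype("fp32"))
# -> load backbone.layers.0.mixer.gate.* as torch.float32
```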
31 changes: 22 additions & 9 deletions modelopt/torch/export/unified_export_megatron.py
@@ -109,7 +109,7 @@ def get_kv_cache_scaling_factor(kv_module: nn.Module) -> torch.Tensor:

def get_quantized_state(
module: torch.nn.Module,
dtype: torch.dtype = torch.float16,
dtype: torch.dtype = torch.bfloat16,
) -> tuple[dict[str, torch.Tensor], str, int]:
"""Return a state_dict, quantization format, and block_size of the module.

@@ -186,6 +187,7 @@ def __init__(
export_extra_modules: bool = False,
dtype=torch.bfloat16,
trust_remote_code: bool = True,
moe_router_dtype: str | None = None,
):
"""Create a GPTModel exporter instance."""
if not isinstance(model, (GPTModel, MambaModel, LLaVAModel)):
@@ -196,6 +197,12 @@
self._hf_config = transformers.AutoConfig.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code
)
self.moe_router_dtype = None
if moe_router_dtype == "fp32":
self.moe_router_dtype = torch.float32
elif moe_router_dtype == "fp64":
self.moe_router_dtype = torch.float64

# If multimodal, extract the text_config
self._hf_text_config = getattr(self._hf_config, "text_config", self._hf_config)

@@ -489,7 +496,7 @@ def _custom_mapping_to_lambda(mapping):
func = method_map[mapping.func_name]
prefix = mapping.target_name_or_prefix
func_kwargs = mapping.func_kwargs
return lambda m, *args: func(m, prefix.format(*args), **func_kwargs)
return lambda m, *args, **kwargs: func(m, prefix.format(*args), **{**func_kwargs, **kwargs})

for arch, mappings in all_mcore_hf_export_mapping.items():
all_rules[arch] = {
@@ -519,12 +526,16 @@ def _name_remapping(
prefix: str,
skip_output_scale: bool = True,
mapping={},
dtype: torch.dtype | None = None,
):
if dtype is None:
dtype = self.dtype

if isinstance(module, torch.Tensor):
self._state_dict[prefix] = module
return

name_to_value, qformat, block_size = get_quantized_state(module, self.dtype)
name_to_value, qformat, block_size = get_quantized_state(module, dtype)

weight = name_to_value.pop("weight")
weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat)
@@ -1098,7 +1109,7 @@ def _get_state_dict(self):

if not isinstance(layer.mlp, IdentityOp):
if "MoE" in str(type(layer.mlp)):
self.rules["router"](layer.mlp.router, layer_id)
self.rules["router"](layer.mlp.router, layer_id, dtype=self.moe_router_dtype)
if (
hasattr(layer.mlp, "shared_experts")
and layer.mlp.shared_experts is not None
@@ -1136,8 +1147,9 @@ def export_mcore_gpt_to_hf(
model: torch.nn.Module,
pretrained_model_name_or_path: str | os.PathLike | None = None,
export_extra_modules: bool = False,
dtype: torch.dtype = torch.float16,
dtype: torch.dtype = torch.bfloat16,
[Contributor Author] change default dtypes to bf16, lmk if this might break anything

export_dir: Path | str = tempfile.gettempdir(),
moe_router_dtype: str | None = None,
):
"""Export Megatron Core GPTModel to unified checkpoint and save to export_dir.

@@ -1153,7 +1165,7 @@
export_dir: The target export path.
moe_router_dtype: Optional dtype override for the MoE router weights ("fp32" or "fp64").
"""
exporter = GPTModelExporter(
model, pretrained_model_name_or_path, export_extra_modules=export_extra_modules, dtype=dtype
model,
pretrained_model_name_or_path,
export_extra_modules=export_extra_modules,
dtype=dtype,
moe_router_dtype=moe_router_dtype,
)
exporter.save_pretrained(export_dir, pretrained_model_name_or_path)

Expand All @@ -1162,7 +1174,8 @@ def import_mcore_gpt_from_hf(
model: torch.nn.Module,
pretrained_model_path: str,
workspace_dir: str | None = None,
dtype: torch.dtype = torch.float16,
dtype: torch.dtype = torch.bfloat16,
moe_router_dtype: str | None = None,
):
"""Import GPTModel state_dict from supported HuggingFace pretrained model path.

Expand All @@ -1173,6 +1186,6 @@ def import_mcore_gpt_from_hf(
dtype: The weights data type to import.
moe_router_dtype: Optional dtype override for the MoE router weights ("fp32" or "fp64").
"""
importer = GPTModelImporter(
model, pretrained_model_path, workspace_dir=workspace_dir, dtype=dtype
model,
pretrained_model_path,
workspace_dir=workspace_dir,
dtype=dtype,
moe_router_dtype=moe_router_dtype,
)
importer._import_state_dict()
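
A possible end-to-end call pattern for the new option (the checkpoint paths are placeholders, and `model` is assumed to be an existing Megatron Core GPTModel instance; if these helpers are re-exported elsewhere, the import path may differ):

```python
from modelopt.torch.export.unified_export_megatron import (
    export_mcore_gpt_to_hf,
    import_mcore_gpt_from_hf,
)

# Load HF weights into the Megatron model, keeping MoE router weights in fp32
# while everything else uses the (new) bf16 default.
import_mcore_gpt_from_hf(model, "/path/to/hf_checkpoint", moe_router_dtype="fp32")

# ... quantize / fine-tune ...

# Export back to a unified HF checkpoint with the same router precision.
export_mcore_gpt_to_hf(
    model,
    "/path/to/hf_checkpoint",
    export_dir="/path/to/export",
    moe_router_dtype="fp32",
)
```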
23 changes: 16 additions & 7 deletions modelopt/torch/quantization/plugins/huggingface.py
@@ -345,7 +345,16 @@ def backward(ctx, grad_output):
_transposed_quantize = _TransposedQuantization.apply


class _QuantMoeSparseMoe(QuantModule):
class _QuantSparseMoe(QuantModule):
"""Module to support special handling of token dispatching during calibration.

During calibration, all tokens are forwarded to all experts so that every expert sees enough tokens to calibrate.
However, even in calibration mode, the output returned by this instance is computed with the actual top_k routing.

If calibration is not enabled, this module behaves like the original sparse-MoE block.
"""

def _setup(self):
pass

Expand Down Expand Up @@ -480,7 +489,7 @@ def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
return self.w2_linear[expert_idx](x1)


class _QuantDbrxFFN(_QuantMoeSparseMoe):
class _QuantDbrxFFN(_QuantSparseMoe):
@property
def num_experts(self):
return self.router.moe_num_experts
@@ -498,7 +507,7 @@ def top_k(self, value):
from transformers.models.llama4.modeling_llama4 import Llama4TextExperts, Llama4TextMoe

if Llama4TextMoe not in QuantModuleRegistry:
QuantModuleRegistry.register({Llama4TextMoe: "hf.Llama4TextMoe"})(_QuantMoeSparseMoe)
QuantModuleRegistry.register({Llama4TextMoe: "hf.Llama4TextMoe"})(_QuantSparseMoe)

if Llama4TextExperts not in QuantModuleRegistry:
QuantModuleRegistry.register({Llama4TextExperts: "hf.Llama4TextExperts"})(
@@ -526,7 +535,7 @@

if MixtralSparseMoeBlock not in QuantModuleRegistry:
QuantModuleRegistry.register({MixtralSparseMoeBlock: "hf.MixtralSparseMoeBlock"})(
_QuantMoeSparseMoe
_QuantSparseMoe
)
except ImportError:
pass
@@ -544,7 +553,7 @@ def top_k(self, value):

if Qwen3MoeSparseMoeBlock not in QuantModuleRegistry:
QuantModuleRegistry.register({Qwen3MoeSparseMoeBlock: "hf.Qwen3MoeSparseMoeBlock"})(
_QuantMoeSparseMoe
_QuantSparseMoe
)
except ImportError:
pass
@@ -554,7 +563,7 @@ def top_k(self, value):

if Qwen2MoeSparseMoeBlock not in QuantModuleRegistry:
QuantModuleRegistry.register({Qwen2MoeSparseMoeBlock: "hf.Qwen2MoeSparseMoeBlock"})(
_QuantMoeSparseMoe
_QuantSparseMoe
)
except ImportError:
pass
@@ -564,7 +573,7 @@ def top_k(self, value):

if Qwen3NextSparseMoeBlock not in QuantModuleRegistry:
QuantModuleRegistry.register({Qwen3NextSparseMoeBlock: "hf.Qwen3NextSparseMoeBlock"})(
_QuantMoeSparseMoe
_QuantSparseMoe
)
except ImportError:
pass
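
The behavior `_QuantSparseMoe` documents, sketched against a generic HF-style sparse-MoE block (attribute names such as `top_k` and the `_if_calib` flag check are assumptions about the wrapped module; this is not the registered implementation):

```python
import torch

class CalibAllExpertsMoE(torch.nn.Module):
    """Toy stand-in: during calibration, run one extra forward with
    top_k == num_experts so every expert sees tokens, then return the output
    of a normal top_k forward."""

    def __init__(self, moe_block: torch.nn.Module, num_experts: int):
        super().__init__()
        self.moe = moe_block          # wrapped sparse-MoE block, assumed to expose .top_k
        self.num_experts = num_experts

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        calibrating = any(getattr(m, "_if_calib", False) for m in self.moe.modules())
        if calibrating:
            original_top_k = self.moe.top_k
            self.moe.top_k = self.num_experts   # route every token to every expert
            self.moe(hidden_states)             # calibration-only pass; output discarded
            self.moe.top_k = original_top_k
        return self.moe(hidden_states)          # returned output uses the real top_k
```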
34 changes: 33 additions & 1 deletion modelopt/torch/quantization/plugins/megatron.py
@@ -23,6 +23,7 @@
import megatron.core.tensor_parallel.layers as megatron_parallel
import megatron.core.transformer.mlp as megatron_mlp
import megatron.core.transformer.moe.experts as megatron_moe
import megatron.core.transformer.moe.moe_layer as megatron_moe_layer
import torch
from megatron.core.parallel_state import get_data_parallel_group
from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
@@ -36,7 +37,7 @@
)
from modelopt.torch.utils.distributed import ParallelState

from ..nn import QuantModuleRegistry, TensorQuantizer
from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer
from ..nn.modules.quant_linear import RealQuantLinear
from ..qtensor import QTensorWrapper
from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear
@@ -247,6 +248,14 @@ def _setup(self):
data_parallel_group,
mcore_parallel.get_tensor_model_parallel_group(),
)

if getattr(self, "gradient_accumulation_fusion", False):
warnings.warn(
"gradient_accumulation_fusion is not supported with ModelOpt quantization. "
"Setting gradient_accumulation_fusion to False."
)
self.gradient_accumulation_fusion = False

super()._setup()

def _process_quantizer_amax(self, k, v, quantizer_state_dict):
Expand Down Expand Up @@ -580,3 +589,26 @@ def _setup(self):
# initialize parallel state for submodules linear_fc1 and linear_fc2
self.linear_fc1.parallel_state = self.parallel_state
self.linear_fc2.parallel_state = self.parallel_state


@QuantModuleRegistry.register({megatron_moe_layer.MoELayer: "megatron_moe_MoELayer"})
class _QuantMoELayer(QuantModule):
"""Module to support special handling of token dispatching during calibration.

During calibration, all tokens are forwarded to all experts so that every expert sees enough tokens to calibrate.
However, even in calibration mode, the output returned by this instance is computed with the actual top_k routing.

If calibration is not enabled, this module behaves as a normal MoELayer.
"""

def _setup(self):
pass

def forward(self, hidden_states):
if any(getattr(m, "_if_calib", False) for m in self.experts.modules()):
original_top_k = self.router.topk
self.router.topk = self.router.num_experts
super().forward(hidden_states)
self.router.topk = original_top_k
return super().forward(hidden_states)
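
With `MoELayer` now registered, a standard ModelOpt PTQ flow should pick it up automatically; a rough sketch (the quantization config name and data loader are placeholders):

```python
import modelopt.torch.quantization as mtq

# Placeholder calibration loop; calib_dataloader yields model inputs.
def forward_loop(model):
    for batch in calib_dataloader:
        model(**batch)

# mtq.quantize converts registered modules (including MoELayer -> _QuantMoELayer),
# so the calibration forwards above temporarily route tokens to all experts.
model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)
```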