
Conversation

@jenchen13 (Contributor) commented Oct 17, 2025

What does this PR do?

Type of change: New feature

Overview: Support Mamba-MOE export for Nemotron H

Usage

# Add a code snippet demonstrating how to use this

Testing

Will test MLM import/export using the MLM scripts:

  • Test import from HF.
  • Test export to HF and verify the state dicts match (see the sketch below).
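
A rough sketch of how the export round trip could be verified, assuming placeholder checkpoint paths (the real verification is done with the MLM scripts mentioned above):

# Hypothetical sketch: compare the original HF checkpoint against the one
# re-exported through MLM. The paths below are placeholders, not real checkpoints.
from transformers import AutoModelForCausalLM

original = AutoModelForCausalLM.from_pretrained("<original-nemotron-h-hf-checkpoint>", trust_remote_code=True)
reexported = AutoModelForCausalLM.from_pretrained("<re-exported-hf-checkpoint>", trust_remote_code=True)

orig_sd, new_sd = original.state_dict(), reexported.state_dict()
assert orig_sd.keys() == new_sd.keys(), "key sets differ after the import/export round trip"
for key in orig_sd:
    # Router, local-expert, and shared-expert weights should keep their shapes.
    assert orig_sd[key].shape == new_sd[key].shape, f"shape mismatch for {key}"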

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • New Features
    • Improved support for Mixture-of-Experts (MoE) models in import/export workflows.
    • Added handling for expert routing and both local and shared expert projection paths so MoE behavior is preserved during model transfer.
    • Enhanced MLP import/export metadata so router and expert components are correctly detected and exported.


copy-pr-bot bot commented Oct 17, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.


coderabbitai bot commented Oct 17, 2025

Walkthrough

Added four imports (COL_ETP, ROW_ETP, QKVMerging, QKVSlicing) to the plugin's public surface and extended Nemotron import/export mapping dictionaries to include Mixture-of-Experts (MoE) keys: router, local_experts.linear_fc1, local_experts.linear_fc2, shared_experts.linear_fc1, shared_experts.linear_fc2.

Changes

Cohort / File(s) and Summary:

  • Public imports (modelopt/torch/export/plugins/mcore_nemotron.py): Added imports from modelopt/torch/export/plugins/mcore_custom: COL_ETP, ROW_ETP, QKVMerging, QKVSlicing.
  • MoE mapping updates (modelopt/torch/export/plugins/mcore_nemotron.py): Extended the nemotron_h_causal_lm_import and nemotron_h_causal_lm_export dictionaries with MoE entries: "router", "local_experts.linear_fc1", "local_experts.linear_fc2", "shared_experts.linear_fc1", "shared_experts.linear_fc2".
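
Roughly what the new mapping entries could look like (a sketch only, based on the summary above; the HF-side target path strings and the exact NameRemapping call shape are assumptions, not the actual diff):

# Sketch of the MoE additions to the Nemotron H export mapping. The key names come
# from the walkthrough; the HF target paths below are placeholder guesses.
nemotron_h_causal_lm_export.update(
    {
        "router": NameRemapping("backbone.layers.{}.mixer.router."),
        # COL_ETP / ROW_ETP mark the expert tensor-parallel split direction; how they
        # are passed to NameRemapping here is an assumption for illustration.
        "local_experts.linear_fc1": NameRemapping("backbone.layers.{}.mixer.experts.{}.up_proj.", COL_ETP),
        "local_experts.linear_fc2": NameRemapping("backbone.layers.{}.mixer.experts.{}.down_proj.", ROW_ETP),
        "shared_experts.linear_fc1": NameRemapping("backbone.layers.{}.mixer.shared_experts.up_proj."),
        "shared_experts.linear_fc2": NameRemapping("backbone.layers.{}.mixer.shared_experts.down_proj."),
    }
)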

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant Caller as Caller
    participant Plugin as mcore_nemotron.py
    participant McoreCustom as mcore_custom
    participant Mappings as Import/Export Dicts

    Note over Plugin,McoreCustom: expose additional symbols and extend mappings

    Caller->>Plugin: request nemotron mappings & exports
    Plugin->>McoreCustom: import COL_ETP, ROW_ETP, QKVMerging, QKVSlicing
    Plugin->>Mappings: build base nemotron import/export dicts
    Mappings->>Mappings: insert "router"
    Mappings->>Mappings: insert "local_experts.linear_fc1"/"linear_fc2"
    Mappings->>Mappings: insert "shared_experts.linear_fc1"/"linear_fc2"
    Plugin->>Caller: return updated mappings and exported symbols

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Poem

🐇 I nibble at mappings, row and col,

Routers hum and experts scroll.
New imports hop into place,
Up and down the tensors race,
A cheerful rabbit bounds with grace.

Pre-merge checks

✅ Passed checks (3 passed)
  • Title Check (✅ Passed): The pull request title "Support MOE Export for Nemotron H" is fully aligned with the main change in the changeset. The PR's primary objective is to add support for Mamba-MoE export for Nemotron H, and the changes add MoE-related keys to the export mapping and import additional components in mcore_nemotron.py. The title is concise, specific, and uses clear terminology (MOE for Mixture of Experts) that a team member scanning the history would immediately recognize as this feature addition.
  • Docstring Coverage (✅ Passed): No functions found in the changes; docstring coverage check skipped.
  • Description Check (✅ Passed): Check skipped because CodeRabbit's high-level summary is enabled.


codecov bot commented Oct 17, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.40%. Comparing base (37c4974) to head (41357c8).
⚠️ Report is 8 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #447      +/-   ##
==========================================
- Coverage   73.40%   73.40%   -0.01%     
==========================================
  Files         180      180              
  Lines       18077    18127      +50     
==========================================
+ Hits        13270    13306      +36     
- Misses       4807     4821      +14     

@jenchen13 marked this pull request as ready for review on October 20, 2025 18:45
@jenchen13 requested a review from a team as a code owner on October 20, 2025 18:45

@coderabbitai bot left a comment

Actionable comments posted: 3

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6ef9954 and 752c70a.

📒 Files selected for processing (1)
  • modelopt/torch/export/plugins/mcore_nemotron.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/export/plugins/mcore_nemotron.py (1)
modelopt/torch/export/plugins/mcore_custom.py (1)
  • NameRemapping (82-91)
🪛 GitHub Actions: Code Quality
modelopt/torch/export/plugins/mcore_nemotron.py

[error] 1-1: pre-commit checks failed. Ruff check reported issues and formatting changes were applied by hooks; the pre-commit run exited with code 1.

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (8)
  • GitHub Check: partial-install (torch)
  • GitHub Check: partial-install (onnx)
  • GitHub Check: multi-transformers (min)
  • GitHub Check: multi-torch (26)
  • GitHub Check: multi-py (11)
  • GitHub Check: multi-torch (27)
  • GitHub Check: windows
  • GitHub Check: multi-py (10)
🔇 Additional comments (3)
modelopt/torch/export/plugins/mcore_nemotron.py (3)

20-22: LGTM! New imports are necessary for MoE support.

The COL_ETP and ROW_ETP imports are correctly added and used in the local_experts mappings below.


101-106: LGTM! MoE export mappings are consistent.

The MoE mappings for export are well-structured and consistent with the existing export dictionary patterns, using the correct "backbone.layers" prefix throughout.


1-1: All pre-commit and ruff checks now pass—no action required.

The trailing whitespace on line 76 and other ruff formatting issues have been resolved. Verification confirms zero remaining errors and no trailing whitespace in the file.

Signed-off-by: Jennifer Chen <[email protected]>
@jenchen13 force-pushed the jennifchen/nmh-moe-export branch from 02be52e to 15a8351 on October 20, 2025 19:04
@ChenhanYu requested a review from yueshen2016 on October 20, 2025 19:57
@jenchen13 force-pushed the jennifchen/nmh-moe-export branch from c2014a5 to d48514d on October 21, 2025 14:23
Signed-off-by: Jennifer Chen <[email protected]>
@jenchen13 force-pushed the jennifchen/nmh-moe-export branch from a86c7af to 614f4df on October 21, 2025 14:30
@jenchen13 requested a review from a team as a code owner on October 27, 2025 21:31
Signed-off-by: jenchen13 <[email protected]>
Comment on lines 37 to 608
)
from modelopt.torch.utils.distributed import ParallelState

from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer
from ..nn.modules.quant_linear import RealQuantLinear
from ..qtensor import QTensorWrapper
from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear

try:
    from megatron.core.extensions.transformer_engine import (
        TEColumnParallelGroupedLinear,
        TERowParallelGroupedLinear,
    )

    from .transformer_engine import _QuantTEGroupedLinear

    HAS_TE = True
except ImportError:
    HAS_TE = False

logger = logging.getLogger(__name__)

__all__ = []


def real_quant_module_get_extra_state(self) -> dict:
    """Populating real_quantizer_state and q_tensor_state."""
    extra_state = {}

    if isinstance(self, RealQuantLinear) and isinstance(self.weight, QTensorWrapper):
        real_quantizer_state = self.weight_quantizer.get_modelopt_state()
        q_tensor_state = self.weight.get_state()
    elif isinstance(self, RealQuantLinear):
        real_quantizer_state = self.weight_quantizer.get_modelopt_state()
        q_tensor_state = {}
    else:
        real_quantizer_state = None
        q_tensor_state = None

    extra_state["modelopt_real_quantizer_state"] = real_quantizer_state
    extra_state["modelopt_q_tensor_state"] = q_tensor_state

    return extra_state


def quant_module_get_extra_state(self) -> dict:
    """Populating the extra_state when state_dict() is called.

    quantizer_state, real_quantizer_state, and q_tensor_state are usually stored
    within the modelopt_state metadata, where the keys are the full module names. The issue
    is that a NeMo-MCore model's full module names can change
    if pipeline-parallelism (PP) and expert-parallelism (EP)
    are changing. Alternatively, we store quantizer_state in
    QuantModule's extra_state with QuantModule.get_extra_state(),
    which avoids the need to store the full module name.
    """
    extra_state = {}

    is_enabled = self.weight_quantizer.is_enabled if hasattr(self, "weight_quantizer") else False

    if not is_enabled:
        return extra_state

    quantizer_state = {}
    for name, module in self.named_modules():
        if isinstance(module, TensorQuantizer):
            quantizer_state[name] = module.get_modelopt_state()

    extra_state["modelopt_quantizer_state"] = quantizer_state

    # Handle real_quantizer_state and q_tensor_state
    extra_state.update(real_quant_module_get_extra_state(self))

    return extra_state


def real_quant_module_set_extra_state(self, state: Any):
    """Restore q_tensor_state when load_state_dict() is called.

    We skip restoring real_quantizer_state (if it exists), since it is the same as
    the weight_quantizer fake quantizer_state.

    q_tensor_state is restored if meta device initialization is used. During
    meta-device initialization, real_quantize is not called.
    QTensorWrapper should replace the original weight parameter. Due to TP, we also need
    to adjust q_tensor_data_shape and its metadata shape attribute to use the local weight shape.

    When not using meta device initialization, real_quantize is called during compress mode
    restore, where the QTensor will be recomputed based on the local weights. Hence we don't
    need to restore q_tensor_state.

    Note:
        The entire restore process can happen on meta device and be materialized later
        with to_empty(). However, to_empty() will reassign the parameter and the
        QTensorWrapper will be removed. We patch RealQuantLinear._apply to preserve
        QTensorWrapper when to_empty() is applied.
    """
    q_tensor_state = state.get("modelopt_q_tensor_state", None)

    if q_tensor_state:
        q_tensor_metadata = q_tensor_state["metadata"]
        q_tensor_metadata["shape"] = self.weight.shape
        q_tensor_data_dtype = q_tensor_state["quantized_data.dtype"]
        q_tensor_shape = self.weight.shape

        # If q_tensor_data_dtype is uint8, then it is a compressed format packing 2 elements.
        if q_tensor_data_dtype == torch.uint8:
            q_tensor_shape = list(q_tensor_shape)
            q_tensor_shape[-1] = q_tensor_shape[-1] // 2
            q_tensor_shape = torch.Size(q_tensor_shape)

        self._parameters["weight"] = QTensorWrapper(
            qtensor=torch.empty(
                q_tensor_shape,  # Use the local shape directly (TP-aware)
                dtype=q_tensor_data_dtype,
                device=self.weight.device,
            ),
            metadata=q_tensor_metadata,
        )


def quant_module_set_extra_state(self, state: Any):
    """Restore quantizer_state when load_state_dict() is called.

    With quantizer_state stored in extra_state (NeMo-MCore `torch-dist`),
    set_extra_state() is used to perform the functionality of
    conversion.restore_quantizer_state().

    load_state_dict() is called twice during NeMo-MCore resume.
    The first time, the state_dict only contains the extra_state.
    set_extra_state() is triggered at the end of load_state_dict(),
    where QuantModule.modelopt_post_restore() will reinitialize
    amax and scalars to the correct shape.

    The 2nd load_state_dict() loads all states, including amax and
    scalars. We disable QuantModule.modelopt_post_restore() to avoid
    reinitialization since set_extra_state() is called at the end.

    We first restore all fake quantizer_state. Each QuantModule can have
    weight_quantizer, input_quantizer, and output_quantizer.
    Once all quantizer_state are resumed, modelopt_post_restore() is called
    to adjust the shape of all buffers (amax, pre_quant_scale, _scale, ...) since
    the local shape can differ from the shape in the state due to a change
    in tensor parallelism (TP).
    """
    if state is None or not self.allow_post_restore:
        return

    quantizer_state = state.get("modelopt_quantizer_state", None)

    if quantizer_state is not None:
        for name, module in self.named_modules():
            if isinstance(module, TensorQuantizer):
                module.set_from_modelopt_state(quantizer_state[name], properties_only=False)
        self.modelopt_post_restore()

    # Handle real_quantizer_state and q_tensor_state
    real_quant_module_set_extra_state(self, state)

    self.allow_post_restore = False


def megatron_replace_quant_module_hook(model: torch.nn.Module):
    """Configure Megatron-Core model quantization support.

    This callback is called before the QuantModule replacement to reuse the current
    custom callback infra. However, it is meant to target each QuantModule.
    Since the callback is called when megatron is installed, we do a type check on
    MegatronModule first. For each MegatronModule,

    1. We change TransformerConfig to enable heterogenous distributed checkpointing.
    2. We enable all sub-QuantModules to store quantizer_state as extra_state by
       type-matching against the QuantModuleRegistry.
    """

    def _register_extra_state_callbacks(model: torch.nn.Module):
        for name, module in model.named_modules():
            if type(module) in QuantModuleRegistry:
                # This module will be replaced as a QuantModule
                register_modelopt_extra_state_callbacks(
                    module,
                    quant_module_get_extra_state,
                    quant_module_set_extra_state,
                )

    for name, module in model.named_modules():
        if isinstance(module, MegatronModule):
            if "vision_model" not in name:
                # We only enable hetereogenous_dist_checkpoint for the language model; the vision model is not quantized
                module.config.hetereogenous_dist_checkpoint = True
            _register_extra_state_callbacks(module)


CUSTOM_MODEL_PLUGINS.add(megatron_replace_quant_module_hook)


class _MegatronParallelLinear(_ParallelLinear):
    _functionals_to_replace = [
        (megatron_parallel, "linear_with_grad_accumulation_and_async_allreduce"),
        (megatron_parallel, "linear_with_frozen_weight"),
    ]

    def _setup(self):
        if not hasattr(self, "parallel_state") or self.parallel_state is None:
            data_parallel_group = None
            try:
                data_parallel_group = get_data_parallel_group(with_context_parallel=True)
            except AssertionError:
                logger.warning(
                    "Context parallel group is not initialized, using data parallel group"
                )
                data_parallel_group = get_data_parallel_group()
            self.parallel_state = ParallelState(
                data_parallel_group,
                mcore_parallel.get_tensor_model_parallel_group(),
            )

        if getattr(self, "gradient_accumulation_fusion", False):
            warnings.warn(
                "gradient_accumulation_fusion is not supported with ModelOpt quantization. "
                "Setting gradient_accumulation_fusion to False."
            )
            self.gradient_accumulation_fusion = False

        super()._setup()

    def _process_quantizer_amax(self, k, v, quantizer_state_dict):
        if v.ndim == 4:
            quantizer_state_dict[k] = v.squeeze(1).squeeze(-1)
        else:
            quantizer_state_dict[k] = (
                v.view(self.weight.shape[0], -1) if v.numel() > 1 else v.view(-1)
            )

    def _process_activation_quantizer_pre_quant_scale(self, k, v, quantizer_state_dict):
        quantizer_state_dict[k] = v

    def _get_shard_axis_dict(self, state_dict):
        raise NotImplementedError

    def _parameter_to_keep_in_quantizer_state_dict(self, key):
        """Determine whether a parameter should be kept in the quantizer_state_dict.

        Used to include additional quantization parameters (e.g., _scale for real quant)
        beyond the default amax and pre_quant_scale tensors.

        Note: When adding parameters here, update _get_shard_axis_dict accordingly for sharding.
        """
        return False

    def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
        # [WAR]: although we disable output_layer quantization by default, it will
        # still be picked up by mtq.quantize since it is a ColumnParallelLinear. We need
        # to further ensure that its sharded state_dict has no scalars or amax since
        # 1) NeMo-MCore's vocabulary padding may change but we didn't support this feature
        # 2) When embedding and output_layer are sharing weights, PP>1 will have
        #    output_layer.input_quantizer._amax but TP-only does not. This leads to a
        #    state_dict mismatch.
        if prefix.endswith("output_layer."):
            # assert not any("_quantizer" in k for k in self.state_dict()), "quantized output_layer"
            return super().sharded_state_dict(prefix, sharded_offsets)

        quantizer_state_dict = {}
        for k, v in self.state_dict(prefix="", keep_vars=True).items():
            if "_quantizer" in k and "_amax" in k:
                self._process_quantizer_amax(k, v, quantizer_state_dict)
            elif k == "input_quantizer._pre_quant_scale":
                self._process_activation_quantizer_pre_quant_scale(k, v, quantizer_state_dict)
            elif self._parameter_to_keep_in_quantizer_state_dict(k):
                quantizer_state_dict[k] = v
            elif "quantizer" in k:
                warnings.warn(
                    f"Quantizer state {k} is not supported for sharded_state_dict. "
                    "Please use regular state_dict."
                )
        sharded_axis_dict = self._get_shard_axis_dict(quantizer_state_dict)
        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets)
        sharded_state_dict.update(
            **make_sharded_tensors_for_checkpoint(
                quantizer_state_dict, prefix, sharded_axis_dict, sharded_offsets
            )
        )
        return sharded_state_dict

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        for k in list(state_dict.keys()):
            if not any(qt + "_quantizer" in k for qt in ["weight", "input", "output"]):
                continue
            name = k.split(prefix)[-1] if prefix else k
            state_dict[k] = state_dict[k].view_as(self.state_dict()[name])
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)


@QuantModuleRegistry.register(
    {megatron_parallel.ColumnParallelLinear: "megatron_ColumnParallelLinear"}
)
class _MegatronColumnParallelLinear(_MegatronParallelLinear):
    _is_column_parallel = True

    def _get_shard_axis_dict(self, state_dict):
        """Getting the sharded axis for amax and pre_quant_scale.

        By default, ColumnParallelLinear shards the output dimension (dim=0). However,
        depending on the quantization algorithm, not all amax or pre_quant_scale need
        to be sharded.

        We check the quantizer.axis to decide whether an amax needs to be sharded.
        Except for dynamic block quantization (NVFP4, axis: None) or per-tensor (FP8,
        axis: None), the rest of the algorithms all need to be sharded.

        Prequant scaling is applied per-input-channel; hence no sharding is required.
        """
        shard_axis_dict = {}
        for k in state_dict:
            if "weight_quantizer." in k:
                weight_quantizer_axis = self.get_submodule(k.rsplit(".", 1)[0]).axis
                if weight_quantizer_axis is not None:
                    shard_axis_dict[k] = 0
        return shard_axis_dict


@QuantModuleRegistry.register({megatron_parallel.RowParallelLinear: "megatron_RowParallelLinear"})
class _MegatronRowParallelLinear(_MegatronParallelLinear):
    _is_row_parallel = True

    def _get_shard_axis_dict(self, state_dict):
        """Getting the sharded axis for amax and pre_quant_scale.

        By default, RowParallelLinear shards the input dimension (dim=1). However,
        depending on the quantization algorithm, not all amax or pre_quant_scale need
        to be sharded.

        We check the quantizer.axis to decide whether an amax needs to be sharded.
        Only static block quantization needs to be sharded and its axis is either (0,) or (0, 2).
        The first case is used in AWQ; the latter case is used in blocked 2D quantization.
        Dynamic block quantization (NVFP4, axis: None), per-tensor (FP8, axis: None)
        and per-channel (INT8_SQ or FP8_PER_CHANNEL, axis: 1) do not require input sharding.

        Prequant scaling is applied per-input-channel; hence it is always sharded.
        """
        shard_axis_dict = {}
        for k in state_dict:
            if "weight_quantizer." in k:
                weight_quantizer_axis = None
                if isinstance(self.weight_quantizer, TensorQuantizer):
                    weight_quantizer_axis = self.weight_quantizer.axis
                elif "weight_quantizer.0." in k:
                    weight_quantizer_axis = self.weight_quantizer[0].axis
                elif "weight_quantizer.1." in k:
                    weight_quantizer_axis = self.weight_quantizer[1].axis
                if isinstance(weight_quantizer_axis, tuple):
                    shard_axis_dict[k] = 1
            if k == "input_quantizer._pre_quant_scale":
                shard_axis_dict[k] = 0
        return shard_axis_dict


@QuantModuleRegistry.register({megatron_mlp.MLP: "megatron_MegatronMLP"})
class _QuantMegatronMLP(_MegatronMLP):
    """Module to support special handling of `linear_fc1` in `sharded_state_dict()` of MCore `MLP`."""

    _modelopt_state_keys = [
        r"weight_quantizer\.(\d+\.)*_amax$",
        r"weight_quantizer\.(\d+\.)*_scale$",
    ]


class _RealQuantMegatronParallelLinear(RealQuantLinear):
    allow_real_quant_gemm = True
    _scale_tensor_shard_axis = None

    def _parameter_to_keep_in_quantizer_state_dict(self, key):
        return any(k in key for k in self.list_of_scale_tensors)

    def _get_shard_axis_dict(self, state_dict):
        shard_axis_dict = super()._get_shard_axis_dict(state_dict)
        for k in state_dict:
            if (
                any(k.endswith(suffix) for suffix in self.list_of_scale_tensors)
                and state_dict[k].dim() > 1
            ):
                assert self._scale_tensor_shard_axis is not None, (
                    "scale_tensor_shard_axis is not set, please set it in the subclass"
                )
                shard_axis_dict[k] = self._scale_tensor_shard_axis
        return shard_axis_dict

    def modelopt_post_restore(self, prefix: str = ""):
        """Post restore to correctly configure the realquant scales.

        ModelOpt restores the TensorQuantizer states such as `_amax` and `_pre_quant_scale` to their
        shape before saving. However, this is not enough for MCore/distributed frameworks since the tensor
        parallelism could change between saving and restoring. If the tensor parallelism changes, the shape
        of the quantizer states also changes. So we need to re-calculate the quantizer states.

        Note:
            During real quantization, weight_quantizer._fake_quant is set to False, which triggers the real
            quant forward path and leads to an error. We enable the weight_quantizer fake_quant forward path
            while recomputing the correct shape.
        """
        self.weight_quantizer._fake_quant = True
        super().modelopt_post_restore(prefix=prefix)
        self.weight_quantizer._fake_quant = False

        if hasattr(self.weight_quantizer, "_scale"):
            # Recompute all real quantization buffer shapes
            self.weight_quantizer._real_quantize(self.weight)

    def _forward_impl(self, input, *args, **kwargs):
        """Use the real quant gemm if available.

        Here the forward is patched such that the real quant gemm can be called if available. Both conditions
        below must be satisfied (a static and a dynamic check based on the input args) to use the kernel.
        Otherwise, we fall back.

        Note:
            RealQuantLinear.forward() does the same check inside and will fall back to the super
            class forward(). This is not desired since _forward_impl introduces many more args and kwargs
            while the original forward only takes 1 positional argument. We must avoid the fallback path
            in RealQuantLinear.forward().
        """
        if (
            self._should_run_real_quant_gemm
            and input.numel() > 1
            and self.has_real_quant_gemm_impl(input, *args, **kwargs)
        ):
            allreduce_dgrad = kwargs.get("allreduce_dgrad", False)
            tp_group = kwargs.get("tp_group")
            sequence_parallel = kwargs.get("sequence_parallel", False)

            tp_group = get_tensor_model_parallel_group_if_none(tp_group)

            if sequence_parallel:
                input = gather_from_sequence_parallel_region(
                    input, tensor_parallel_output_grad=True, group=tp_group
                )
            else:
                input = input

            return RealQuantLinear.forward(
                self,
                input,
                allreduce_dgrad=allreduce_dgrad,
                tp_group=tp_group,
            )
        else:
            return super()._forward_impl(input, *args, **kwargs)


class _RealQuantMegatronColumnParallelLinear(
    _RealQuantMegatronParallelLinear, _MegatronColumnParallelLinear
):
    _scale_tensor_shard_axis = 0

    def forward(self, input, *args, **kwargs):
        return _MegatronColumnParallelLinear.forward(self, input, *args, **kwargs)


class _RealQuantMegatronRowParallelLinear(
    _RealQuantMegatronParallelLinear, _MegatronRowParallelLinear
):
    _scale_tensor_shard_axis = 1

    def forward(self, input, *args, **kwargs):
        return _MegatronRowParallelLinear.forward(self, input, *args, **kwargs)


@QuantModuleRegistry.register({megatron_moe.SequentialMLP: "megatron_moe_SequentialMLP"})
class _MegatronSequentialMLP(_MegatronMLP):
    def _setup(self):
        if not hasattr(self, "parallel_state") or self.parallel_state is None:
            self.parallel_state = ParallelState(
                mcore_parallel.get_expert_data_parallel_group(),
                tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(),
                expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(),
            )

        # Initialize parallel state for submodules local_experts.*.linear_fc1 and local_experts.*.linear_fc2
        for expert in self.local_experts:
            expert.linear_fc1.parallel_state = self.parallel_state
            expert.linear_fc2.parallel_state = self.parallel_state

    def sync_moe_local_experts_amax(self):
        """Sync amax across local experts in a SequentialMLP.

        amax values across EP and ETP (for RowParallel) are synchronized as part of
        model_calib.max_calibrate(). This function is called to synchronize the amax values
        across local experts so that all local experts share the same amax.
        """
        torch.distributed.barrier()
        # Collect amax from all local experts
        amax_dict = {}
        for expert in self.local_experts:
            for name, module in expert.named_modules():
                if isinstance(module, TensorQuantizer) and module.amax is not None:
                    stored_amax = amax_dict.get(name)
                    amax_tensor = module.amax.detach().clone()
                    amax_dict[name] = (
                        amax_tensor
                        if stored_amax is None
                        else torch.maximum(stored_amax, amax_tensor)
                    )

        # Apply synchronized amax values back to all local experts
        for expert in self.local_experts:
            for name, module in expert.named_modules():
                if isinstance(module, TensorQuantizer) and module.amax is not None:
                    module.amax = amax_dict[name].detach().clone().to(module.amax.device)


if HAS_TE:
    # Quantized subclasses to support TEGroupedMLP quantization
    class _QuantMegatronTEGroupedLinear(_QuantTEGroupedLinear, _MegatronParallelLinear):
        def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
            # _sharded_state_dict_grouped adds _extra_state{gemm_idx} for gemm_idx:[1, num_gemms] in
            # sharded_state_dict, which is the same as _extra_state. The _extra_state{gemm_idx} is used
            # for the TE Fp8 checkpoint; we need to remove the _extra_state{gemm_idx} for
            # gemm_idx:[1, num_gemms] for modelopt checkpoint restore.
            filtered_state_dict = {
                k: v
                for k, v in state_dict.items()
                if not any(k.endswith(f"_extra_state{num}") for num in range(1, self.num_gemms))
            }
            return super()._load_from_state_dict(filtered_state_dict, prefix, *args, **kwargs)

        def _process_quantizer_amax(self, k, v, quantizer_state_dict):
            assert v.numel() == 1, "TEGroupedLinear only supports per-tensor quantization"
            quantizer_state_dict[k] = v.view(-1)

    @QuantModuleRegistry.register(
        {TEColumnParallelGroupedLinear: "megatron_TEColumnParallelGroupedLinear"}
    )
    class _MegatronTEGroupedColumnParallelLinear(
        _QuantMegatronTEGroupedLinear, _MegatronColumnParallelLinear
    ):
        pass

    @QuantModuleRegistry.register(
        {TERowParallelGroupedLinear: "megatron_TERowParallelGroupedLinear"}
    )
    class _MegatronTEGroupedRowParallelLinear(
        _QuantMegatronTEGroupedLinear, _MegatronRowParallelLinear
    ):
        pass

    @QuantModuleRegistry.register({megatron_moe.TEGroupedMLP: "megatron_moe_TEGroupedMLP"})
    class _MegatronTEGroupedMLP(_MegatronMLP):
        def _setup(self):
            if not hasattr(self, "parallel_state") or self.parallel_state is None:
                self.parallel_state = ParallelState(
                    mcore_parallel.get_expert_data_parallel_group(),
                    tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(),
                    expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(),
                )
            # Initialize parallel state for submodules linear_fc1 and linear_fc2
            self.linear_fc1.parallel_state = self.parallel_state
            self.linear_fc2.parallel_state = self.parallel_state


@QuantModuleRegistry.register({megatron_moe_layer.MoELayer: "megatron_moe_MoELayer"})
class _QuantMoELayer(QuantModule):
    """Module to support special handling of token dispatching during calibration.

    During calibration, we forward all tokens to all experts so that all experts see sufficient tokens
    to calibrate. However, even in calibration mode, the actual top_k routing is used to calculate the
    actual outputs this instance returns.

    If calibration is not enabled, this module behaves as a normal MoELayer.
    """

    def _setup(self):
        pass

    def forward(self, hidden_states):

@jenchen13 (Contributor Author) commented:

@realAsma since both the megatron and HF Quant MOE implementation have the same forward() function, is there anyway for them to share code through inheritance? instead of copying the same code in both places
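
One possible shape for that sharing, sketched with stand-in base classes and hypothetical names (not the actual ModelOpt classes):

# Hypothetical sketch of the shared-forward idea: hoist the calibration-aware
# forward() into a mixin and let the Megatron and HF quantized MoE layers inherit it.
import torch.nn as nn


class _MoECalibForwardMixin:
    """Holds the calibration-aware forward() once so both variants can reuse it."""

    def forward(self, hidden_states):
        if getattr(self, "_calib_all_experts", False):
            # During calibration, all tokens would be dispatched to all experts (details elided).
            pass
        return super().forward(hidden_states)  # falls through to the framework-specific layer


class _MegatronMoEBase(nn.Module):  # stand-in for the Megatron MoELayer QuantModule
    def forward(self, hidden_states):
        return hidden_states


class _HFMoEBase(nn.Module):  # stand-in for the HF MoE QuantModule
    def forward(self, hidden_states):
        return hidden_states


class _QuantMoELayerMegatron(_MoECalibForwardMixin, _MegatronMoEBase):
    pass


class _QuantMoELayerHF(_MoECalibForwardMixin, _HFMoEBase):
    pass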
