[ckpt] refactor: Consolidate fused expert mappings and fix MTP inference (#2685)
Conversation
Introduce `FusedExpertMapping` and `FusedGatedExpertMapping` in `param_mapping.py` to handle many-to-one / one-to-many expert weight conversions generically. This eliminates the duplicated `maybe_modify_converted_hf_weight` overrides and `hf_weights_cache` from the GPT-OSS, GLM-4.5, GLM-4.5V, and Qwen3-VL bridges (-502 / +307 lines).

Also fixes two pre-existing bugs:
- GLM-4.5 MTP mappings used the stale `transformer_layer` prefix instead of `mtp_model_layer`, causing missing-mapping warnings
- `hf_to_megatron_generate_text.py` set `mtp_num_layers=None`, which crashed MTP-enabled models; replaced with `m.mtp_process=False`

Signed-off-by: Yu Yao <yaoyu.094@gmail.com>
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
Made-with: Cursor
/ok to test ff3705b
📝 Walkthrough

This PR introduces fused expert mapping classes and grouped export accumulation logic for optimized MoE weight conversion, replaces legacy per-expert mapping implementations across multiple model bridges (Qwen, GPT-OSS, GLM) with the new fused variants, and updates MTP inference handling by conditionally disabling the `mtp_process` attribute.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client
    participant ModelBridge as MegatronModelBridge
    participant Accum as Grouped<br/>Accumulator
    participant GrpBuf as GroupedBuffers<br/>(Tensor Cache)
    participant Output as Merged<br/>Tensor Dict

    Client->>ModelBridge: stream_weights_hf_to_megatron<br/>(with grouped_export mapping)
    ModelBridge->>ModelBridge: Detect is_grouped_export=True
    ModelBridge->>GrpBuf: Initialize grouped_buffers[group_key]
    loop For Each Expert in Group
        ModelBridge->>ModelBridge: Load HF weight slice
        ModelBridge->>Accum: _accumulate_grouped_export<br/>(expert_idx, weight)
        Accum->>GrpBuf: Store per-expert weight<br/>at global expert index
    end
    Accum->>Accum: All experts collected?
    alt Yes - Group Complete
        Accum->>Accum: Stack/merge expert<br/>tensors into single tensor
        Accum->>Accum: Optionally transpose<br/>to match shape
        Accum->>Output: Return merged dict
        Output->>Client: Yield merged tensor
    else No - Still Accumulating
        Accum->>Client: Yield None (continue)
    end
```
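The accumulation flow in the diagram above can be sketched in plain Python. This is an illustrative stand-in, not the PR's actual implementation: the `GroupedBuffers` class, `accumulate_grouped_export` method, and list-of-floats "tensors" are all hypothetical simplifications of the real `_accumulate_grouped_export` logic on `MegatronModelBridge`.

```python
# Illustrative sketch of grouped-export accumulation: cache per-expert weights
# under a group key, and emit the stacked group only once every expert has
# arrived. Names and types here are hypothetical stand-ins for the real bridge.
from typing import Dict, List, Optional


class GroupedBuffers:
    """Caches per-expert weights until a fused group is complete."""

    def __init__(self, num_experts: int) -> None:
        self.num_experts = num_experts
        self.buffers: Dict[str, Dict[int, List[float]]] = {}

    def accumulate_grouped_export(
        self, group_key: str, expert_idx: int, weight: List[float]
    ) -> Optional[List[List[float]]]:
        """Store one expert's weight; return the stacked group when complete."""
        group = self.buffers.setdefault(group_key, {})
        group[expert_idx] = weight
        if len(group) < self.num_experts:
            return None  # still accumulating, caller yields nothing yet
        # All experts collected: stack in global expert order, free the cache.
        merged = [group[i] for i in range(self.num_experts)]
        del self.buffers[group_key]
        return merged


buf = GroupedBuffers(num_experts=2)
assert buf.accumulate_grouped_export("layer0.experts", 1, [3.0, 4.0]) is None
merged = buf.accumulate_grouped_export("layer0.experts", 0, [1.0, 2.0])
```

Experts may arrive in any order (here expert 1 before expert 0); the final stack is always in global expert order, and the buffer is freed once the group is emitted.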
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 2 failed (2 warnings)
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/megatron/bridge/models/glm/glm45_bridge.py (1)
Lines 218-300: ⚠️ Potential issue | 🟠 Major: Add dual-prefix support for MTP layer mappings to handle both Megatron-Core naming conventions.
The MTP mappings currently hard-code only `mtp_model_layer` in the explicit QKV/MLP/expert mappings (lines 250, 256, 262, 267, 277, 284, 295, 300) and in the generated `AutoMapping` entries at line 218. Megatron-Core may expose the MTP submodule as `transformer_layer` instead, which will leave MTP weights unmapped for those checkpoints. Follow the pattern in `mimo_bridge.py` by iterating over both prefixes to ensure compatibility across different Megatron-Core versions.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/megatron/bridge/models/glm/glm45_bridge.py` around lines 218 - 300, The MTP mappings only use the "mtp_model_layer" prefix causing missed mappings when Megatron exposes the submodule as "transformer_layer"; update the mapping construction to loop over both prefixes (e.g., prefixes = ["mtp_model_layer", "transformer_layer"]) and add mappings for each prefix so every place that currently constructs megatron_param with "mtp_model_layer" (including the AutoMapping entries and the specialized mappings: QKVMapping, GatedMLPMapping, GLMExpertGateUpProjMapping, GLMExpertDownProjMapping, and the existing AutoMapping for experts) is duplicated/created for the alternate "transformer_layer" prefix; follow the pattern used in mimo_bridge.py to generate entries for both prefixes and append them to mapping_list.
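The dual-prefix pattern the reviewer recommends can be sketched as follows. Everything here is a hypothetical stand-in: the `build_mtp_mappings` helper, the `megatron_param` path format, and the plain `("QKV", ...)` tuples are illustrations, not the bridge's real `QKVMapping`/`AutoMapping` objects.

```python
# Hypothetical sketch of the dual-prefix fix: emit one mapping entry per
# Megatron-Core naming convention so MTP weights resolve under either name.
def build_mtp_mappings(layer_idx: int) -> list:
    mapping_list = []
    # Cover both names Megatron-Core may use for the MTP submodule.
    for prefix in ("mtp_model_layer", "transformer_layer"):
        base = f"mtp.layers.{layer_idx}.{prefix}"  # illustrative path format
        mapping_list.append(("QKV", f"{base}.self_attention.linear_qkv.weight"))
        mapping_list.append(("GatedMLP", f"{base}.mlp.linear_fc1.weight"))
    return mapping_list
```

The point is that every specialized mapping currently constructed once with `mtp_model_layer` gets duplicated inside the prefix loop, so whichever name the installed Megatron-Core version exposes, a matching entry exists.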
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 32a098f8-3a4b-4da8-b6f7-927e1570c4c4
📒 Files selected for processing (10)
- examples/conversion/hf_to_megatron_generate_text.py
- src/megatron/bridge/models/conversion/__init__.py
- src/megatron/bridge/models/conversion/model_bridge.py
- src/megatron/bridge/models/conversion/param_mapping.py
- src/megatron/bridge/models/glm/glm45_bridge.py
- src/megatron/bridge/models/glm/glm_moe_mappings.py
- src/megatron/bridge/models/glm_vl/glm_45v_bridge.py
- src/megatron/bridge/models/gpt_oss/gpt_oss_bridge.py
- src/megatron/bridge/models/qwen_vl/qwen35_vl_bridge.py
- src/megatron/bridge/models/qwen_vl/qwen3_vl_bridge.py
💤 Files with no reviewable changes (1)
- src/megatron/bridge/models/glm_vl/glm_45v_bridge.py
```python
if hasattr(m, "mtp_process"):
    m.mtp_process = False
```
🧩 Analysis chain
🏁 Script executed:

```shell
#!/bin/bash
# Verify where MTP is actually controlled in this repository
echo "=== Searching for mtp_process usage ==="
rg -n -C2 '\bmtp_process\b' --type=py
echo -e "\n=== Searching for mtp_num_layers usage ==="
rg -n -C2 '\bmtp_num_layers\b' --type=py
```

Repository: NVIDIA-NeMo/Megatron-Bridge
Length of output: 31680
🏁 Script executed:

```shell
cat -n examples/conversion/hf_to_megatron_generate_text.py | head -200 | tail -60
```

Repository: NVIDIA-NeMo/Megatron-Bridge
Length of output: 2666
🏁 Script executed:

```shell
# Check what grad_scale_func is and where it's used
echo "=== Searching for grad_scale_func ==="
rg -n 'grad_scale_func' --type=py | head -20
```

Repository: NVIDIA-NeMo/Megatron-Bridge
Length of output: 1656
🏁 Script executed:

```shell
# Check if hf_to_megatron_generate_text.py handles any Qwen VL specific logic
echo "=== Context around MTP disable in generate_text.py ==="
rg -B10 -A5 'mtp_process' examples/conversion/hf_to_megatron_generate_text.py
echo -e "\n=== Compare with generate_vlm.py approach ==="
rg -B5 -A5 'mtp_num_layers = None' examples/conversion/hf_to_megatron_generate_vlm.py
echo -e "\n=== Check file imports to see what models are supported ==="
head -50 examples/conversion/hf_to_megatron_generate_text.py
```

Repository: NVIDIA-NeMo/Megatron-Bridge
Length of output: 3162
🏁 Script executed:

```shell
# Check if Qwen VL is the only model that uses mtp_process
echo "=== Check what models/classes have mtp_process attribute ==="
rg -B5 'self.mtp_process\s*=' --type=py | head -40
echo -e "\n=== Verify if all models have config.mtp_num_layers ==="
rg 'config.mtp_num_layers' examples/conversion --type=py
```

Repository: NVIDIA-NeMo/Megatron-Bridge
Length of output: 957
🏁 Script executed:

```shell
# Final check: verify grad_scale_func is also needed for inference
echo "=== Check inference paths that set grad_scale_func to None ==="
rg -B3 -A3 'grad_scale_func = None' examples/conversion src/megatron/bridge/inference --type=py
```

Repository: NVIDIA-NeMo/Megatron-Bridge
Length of output: 1934
Set MTP config to None for inference; `mtp_process` alone is insufficient.

Lines 171–172 only toggle a model-specific attribute. For models where `config.mtp_num_layers` is set, MTP remains enabled at the config level, which training/inference logic checks to determine whether to use MTP. Additionally, `grad_scale_func` must be set to None to prevent NCCL collective hangs during inference.

Apply the following fix to match the pattern used in `hf_to_megatron_generate_vlm.py` and `compare_hf_and_megatron/compare.py`:
🔧 Proposed fix

```diff
 for m in model:
+    if hasattr(m, "config"):
+        m.config.mtp_num_layers = None
+        m.config.grad_scale_func = None
     if hasattr(m, "mtp_process"):
         m.mtp_process = False
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
if hasattr(m, "config"):
    m.config.mtp_num_layers = None
    m.config.grad_scale_func = None
if hasattr(m, "mtp_process"):
    m.mtp_process = False
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/conversion/hf_to_megatron_generate_text.py` around lines 171 - 172,
The current change only flips the model instance flag m.mtp_process, but you
must also disable MTP at the config level and clear mixed-precision scaling to
avoid NCCL hangs: when you see the block that checks hasattr(m, "mtp_process")
and sets m.mtp_process = False, also set m.config.mtp_num_layers = None (or 0 if
config expects an int) and set m.grad_scale_func = None, using attribute
existence checks before assignment to avoid attribute errors; update the same
function/section that handles m.mtp_process so all three changes are applied
together.
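The guarded-patch pattern the comment describes can be exercised in isolation. This is a minimal sketch: `disable_mtp_for_inference` is a hypothetical helper name, and `SimpleNamespace` stands in for a real Megatron model chunk; only the attribute names mirror the review suggestion.

```python
# Minimal sketch of the recommended MTP-disable pattern: guard every
# assignment with hasattr so models lacking an attribute are left untouched.
from types import SimpleNamespace


def disable_mtp_for_inference(model_chunks) -> None:
    for m in model_chunks:
        if hasattr(m, "config"):
            # Disable MTP at the config level, not just on the instance flag.
            m.config.mtp_num_layers = None
            # Clear mixed-precision scaling to avoid NCCL collective hangs.
            m.config.grad_scale_func = None
        if hasattr(m, "mtp_process"):
            m.mtp_process = False


# Dummy model: has all three attributes, so all three get cleared.
m = SimpleNamespace(
    config=SimpleNamespace(mtp_num_layers=2, grad_scale_func=object()),
    mtp_process=True,
)
disable_mtp_for_inference([m])

# Dummy model without MTP attributes: the guards make this a no-op.
plain = SimpleNamespace()
disable_mtp_for_inference([plain])
```

The `hasattr` guards are the point of the pattern: the same loop can run over heterogeneous model chunks (with or without MTP support) without raising `AttributeError`.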
```python
from megatron.bridge.models.conversion.param_mapping import (  # noqa: F401
    FusedExpertMapping as GLMExpertDownProjMapping,
)
```
Re-export `GLMExpertGateUpProjMapping` to prevent import-time breakage.

This module now exposes only `GLMExpertDownProjMapping`, but GLM bridges still import and instantiate `GLMExpertGateUpProjMapping`, which will fail at import time.
🔧 Proposed fix

```diff
 from megatron.bridge.models.conversion.param_mapping import (  # noqa: F401
     FusedExpertMapping as GLMExpertDownProjMapping,
+    FusedGatedExpertMapping as GLMExpertGateUpProjMapping,
 )
+
+__all__ = ["GLMExpertDownProjMapping", "GLMExpertGateUpProjMapping"]
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/megatron/bridge/models/glm/glm_moe_mappings.py` around lines 21 - 23,
Module currently only re-exports GLMExpertDownProjMapping causing import-time
failure where GLMExpertGateUpProjMapping is expected; add a matching re-export
for the gate mapping by importing the appropriate symbol from
megatron.bridge.models.conversion.param_mapping and aliasing it to
GLMExpertGateUpProjMapping (mirror the existing pattern used for
GLMExpertDownProjMapping), so downstream code that imports and instantiates
GLMExpertGateUpProjMapping will succeed.
```diff
         if isinstance(hf_param, str):
-            if hf_param in self.hf_weights_cache:
-                return self.hf_weights_cache[hf_param]
             if hf_param in hf_state_dict:
                 hf_weights = hf_state_dict[hf_param]
-                if ".mlp.experts." in hf_param and len(hf_weights.shape) == 3:
+                if ".mlp.experts." in hf_param and hf_weights.ndim == 3:
                     hf_weights = hf_weights.transpose(-1, -2)
-                self.hf_weights_cache[hf_param] = hf_weights
-            else:
-                blocks_key = hf_param + "_blocks"
-                scales_key = hf_param + "_scales"
-                if blocks_key in hf_state_dict and scales_key in hf_state_dict:
-                    hf_weights = _dequantize_mxfp4(hf_state_dict[blocks_key], hf_state_dict[scales_key])
-                    self.hf_weights_cache[hf_param] = hf_weights
-                else:
-                    raise KeyError(
-                        f"Cannot locate weights for '{hf_param}'. Missing both de-quantized tensor and "
-                        f"quantized representation (blocks='{blocks_key}', scales='{scales_key}')."
-                    )
-        else:
-            hf_weights = {k: hf_state_dict[v] for k, v in hf_param.items()}
-        return hf_weights
-
-    def maybe_modify_converted_hf_weight(
-        self,
-        task: WeightConversionTask,
-        converted_weights_dict: Dict[str, torch.Tensor],
-        hf_state_dict: Mapping[str, torch.Tensor],
-    ) -> Dict[str, torch.Tensor]:
-        num_experts = self.hf_config.num_local_experts
-        ep_size = parallel_state.get_expert_model_parallel_world_size()
-        experts_per_rank = num_experts // ep_size
-
-        try:
-            local_expert_number = extract_expert_number_from_param(task.param_name) % experts_per_rank
-        except ValueError:
-            # not an expert weight
-            return converted_weights_dict
-
-        assert len(converted_weights_dict) == 1, (
-            f"There should be only one key in the converted_weights_dict, got keys: {converted_weights_dict.keys()}"
-        )
-        for key, value in converted_weights_dict.items():
-            if key not in self.hf_weights_cache:
-                self.hf_weights_cache[key] = {}
-
-            # we end up with ep_size many weights to add to the cache
-            # unpack the weights and re-index
-            if ep_size == 1:
-                self.hf_weights_cache[key][local_expert_number] = value
-            else:
-                assert value.shape[0] == ep_size
-                for i, exp_val in enumerate(value):
-                    global_expert_number = local_expert_number + (i * experts_per_rank)
-                    self.hf_weights_cache[key][global_expert_number] = exp_val
-            if len(self.hf_weights_cache[key]) == num_experts:
-                logging.debug(f"All experts are loaded for {key}")
-                # all experts are loaded
-                merged_hf_weights = torch.cat(
-                    [self.hf_weights_cache[key][i].unsqueeze(0) for i in range(num_experts)], dim=0
-                )
-                del self.hf_weights_cache[key]
-                return {key: merged_hf_weights}
-            else:
-                # not all experts are loaded yet, return empty dict
-                logging.debug(f"{len(self.hf_weights_cache[key])}/{num_experts} experts are loaded for {key}")
-                return {}
+                return hf_weights
+            blocks_key = hf_param + "_blocks"
+            scales_key = hf_param + "_scales"
+            if blocks_key in hf_state_dict and scales_key in hf_state_dict:
+                return _dequantize_mxfp4(hf_state_dict[blocks_key], hf_state_dict[scales_key])
```
Transpose dequantized expert tensors before returning them.

The direct-tensor branch normalizes 3D `.mlp.experts.` weights, but the `_blocks`/`_scales` branch returns `_dequantize_mxfp4(...)` unchanged. On quantized GPT-OSS checkpoints that leaves each expert in HF layout, so `GPTOSSMLPDownProjMapping` and `GPTOSSMLPGateUpProjMapping` receive the wrong axis order on import.
🔧 Proposed fix

```diff
 blocks_key = hf_param + "_blocks"
 scales_key = hf_param + "_scales"
 if blocks_key in hf_state_dict and scales_key in hf_state_dict:
-    return _dequantize_mxfp4(hf_state_dict[blocks_key], hf_state_dict[scales_key])
+    hf_weights = _dequantize_mxfp4(hf_state_dict[blocks_key], hf_state_dict[scales_key])
+    if ".mlp.experts." in hf_param and hf_weights.ndim == 3:
+        hf_weights = hf_weights.transpose(-1, -2)
+    return hf_weights
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/megatron/bridge/models/gpt_oss/gpt_oss_bridge.py` around lines 121 - 130,
The quantized-path returning _dequantize_mxfp4(blocks, scales) doesn't mirror
the direct-tensor branch's transpose for 3D expert weights, causing expert
tensors to keep HF layout; update the branch handling blocks_key/scales_key so
that after calling _dequantize_mxfp4 you detect if hf_param contains
".mlp.experts." and the returned tensor has ndim == 3, then transpose the last
two axes (i.e., swap -1 and -2) before returning; locate this logic around the
hf_param string branch that references hf_state_dict, _dequantize_mxfp4, and the
".mlp.experts." selector to apply the fix.
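What the proposed transpose actually does can be shown with a torch-free sketch. Both helper names below (`transpose_last_two`, `load_expert_weight`) are hypothetical stand-ins; nested Python lists replace `torch.Tensor`, and `transpose_last_two` mimics `tensor.transpose(-1, -2)` on a 3D tensor.

```python
# Pure-Python stand-in for the expert-layout fix: a 3D expert tensor shaped
# [num_experts, in_features, out_features] has its last two axes swapped.
def transpose_last_two(t3d):
    """Swap the last two axes of a 3D nested list (like transpose(-1, -2))."""
    return [
        [[expert[i][j] for i in range(len(expert))] for j in range(len(expert[0]))]
        for expert in t3d
    ]


def load_expert_weight(hf_param: str, tensor):
    """Mirror the proposed fix: only 3D expert tensors are transposed."""
    # Crude ndim check for nested lists: 3D iff elements of rows are lists.
    ndim = 3 if tensor and isinstance(tensor[0][0], list) else 2
    if ".mlp.experts." in hf_param and ndim == 3:
        tensor = transpose_last_two(tensor)
    return tensor
```

Applying the same check after dequantization (as the fix does) makes the quantized path produce the same layout as the direct-tensor path, so downstream mappings see a single consistent axis order.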
Summary

- Introduce `FusedExpertMapping` and `FusedGatedExpertMapping` in `param_mapping.py` to handle many-to-one / one-to-many expert weight conversions generically via the `is_grouped_export`/`group_key` protocol
- Remove duplicated `maybe_modify_converted_hf_weight` overrides and `hf_weights_cache` from GPT-OSS, GLM-4.5, GLM-4.5V, and Qwen3-VL bridges (net -195 lines)
- Add `_accumulate_grouped_export` to `MegatronModelBridge` and `_hf_import_cache` for grouped import, centralizing the expert merge/split logic
- Replace `transformer_layer` with `mtp_model_layer` in GLM-4.5 MTP mappings and propagate `mtp_num_layers` from HF config
- `hf_to_megatron_generate_text.py`: replace `mtp_num_layers=None` (crashes MTP-enabled models) with `m.mtp_process=False`

Test plan
Made with Cursor