[ckpt] refactor: Consolidate fused expert mappings and fix MTP inference #2685

Open
yaoyu-33 wants to merge 1 commit into main from yuya/refactor-fused-expert-mappings

Conversation


@yaoyu-33 yaoyu-33 commented Mar 6, 2026

Summary

  • Introduce FusedExpertMapping and FusedGatedExpertMapping in param_mapping.py to handle many-to-one / one-to-many expert weight conversions generically via is_grouped_export / group_key protocol
  • Eliminate duplicated maybe_modify_converted_hf_weight overrides and hf_weights_cache from GPT-OSS, GLM-4.5, GLM-4.5V, and Qwen3-VL bridges (net -195 lines)
  • Add _accumulate_grouped_export to MegatronModelBridge and _hf_import_cache for grouped import, centralizing the expert merge/split logic
  • Fix GLM-4.5 MTP mappings: replace stale transformer_layer with mtp_model_layer and propagate mtp_num_layers from HF config
  • Fix hf_to_megatron_generate_text.py: replace mtp_num_layers=None (crashes MTP-enabled models) with m.mtp_process=False
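
The is_grouped_export / group_key protocol named above can be sketched roughly as follows; the class body, field names, and key-derivation rule are illustrative guesses, not the actual param_mapping.py implementation:

```python
from dataclasses import dataclass


@dataclass
class FusedExpertMappingSketch:
    """Toy stand-in for a fused expert mapping (hypothetical fields)."""

    megatron_param: str  # e.g. "decoder.layers.*.mlp.experts.linear_fc2.weight*"
    hf_param: str        # e.g. "model.layers.*.mlp.experts.down_proj"

    @property
    def is_grouped_export(self) -> bool:
        # Many per-expert Megatron weights export into one fused HF tensor.
        return True

    def group_key(self, param_name: str) -> str:
        # All experts of a layer share one accumulation bucket, so strip
        # the trailing per-expert index from the parameter name.
        return param_name.rsplit(".experts.", 1)[0]


mapping = FusedExpertMappingSketch(
    "decoder.layers.*.mlp.experts.linear_fc2.weight*",
    "model.layers.*.mlp.experts.down_proj",
)
assert mapping.group_key("decoder.layers.3.mlp.experts.7.weight") == "decoder.layers.3.mlp"
```

A bridge that exposes such a mapping lets the model bridge detect grouped exports generically instead of each bridge re-implementing its own weight cache.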

Test plan

  • Pre-commit hooks pass (ruff lint + format)
  • GPT-OSS e2e conversion + inference
  • GLM-4.5 e2e conversion + inference
  • GLM-4.5V e2e conversion + inference
  • Qwen3-VL MoE e2e conversion + inference
  • Qwen3.5-VL MoE e2e conversion + inference
  • Unit tests pass

Made with Cursor

Summary by CodeRabbit

  • New Features

    • Introduced fused expert weight mappings for optimized Mixture of Experts model conversion
    • Expanded support for new model layer configurations
  • Improvements

    • Streamlined weight conversion workflows with enhanced expert weight alignment
    • Improved automatic weight shape adjustment for expert model layers
    • Simplified conversion logic across multiple supported architectures
    • Better handling of complex weight transformations and per-expert distributions

Introduce FusedExpertMapping and FusedGatedExpertMapping in
param_mapping.py to handle many-to-one / one-to-many expert weight
conversions generically. This eliminates duplicated
maybe_modify_converted_hf_weight overrides and hf_weights_cache from
GPT-OSS, GLM-4.5, GLM-4.5V, and Qwen3-VL bridges (-502 / +307 lines).

Also fixes two pre-existing bugs:
- GLM-4.5 MTP mappings used stale 'transformer_layer' instead of
  'mtp_model_layer', causing missing-mapping warnings
- hf_to_megatron_generate_text.py set mtp_num_layers=None which crashed
  MTP-enabled models; replaced with m.mtp_process=False

Signed-off-by: Yu Yao <yaoyu.094@gmail.com>
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
Made-with: Cursor

copy-pr-bot bot commented Mar 6, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.


yaoyu-33 commented Mar 6, 2026

/ok to test ff3705b


coderabbitai bot commented Mar 6, 2026

📝 Walkthrough

This PR introduces fused expert mapping classes and grouped export accumulation logic for optimized MoE weight conversion, replaces legacy per-expert mapping implementations across multiple model bridges (Qwen, GPT-OSS, GLM) with the new fused variants, and updates MTP inference handling by conditionally disabling the mtp_process attribute.

Changes

Cohort / File(s) Summary
Core Fused Mapping Infrastructure
src/megatron/bridge/models/conversion/param_mapping.py, src/megatron/bridge/models/conversion/__init__.py
Introduces FusedExpertMapping and FusedGatedExpertMapping classes for handling fused expert weights with grouped export semantics, adds _align_expert_weight_to_shape helper for shape alignment, implements _LooseGatedMLPMapping with grouped-export validation skipping, and exports new mappings via all.
Model Bridge Grouped Export
src/megatron/bridge/models/conversion/model_bridge.py
Adds _accumulate_grouped_export method to handle per-expert weight accumulation and merging, integrates grouped export detection and routing into load_weights_hf_to_megatron and stream_weights flows, caches HF weights for reuse, and manages grouped_buffers for tensor stacking.
MTP Inference Example
examples/conversion/hf_to_megatron_generate_text.py
Replaces unconditional mtp_num_layers nulling with conditional mtp_process disabling via getattr check.
GLM Bridge Refactoring
src/megatron/bridge/models/glm/glm45_bridge.py, src/megatron/bridge/models/glm/glm_moe_mappings.py
Adds mtp_num_layers configuration support in GLM45Bridge, updates mtp_model_layer pattern mappings, removes legacy maybe_modify_converted_hf_weight method, consolidates GLM MoE mappings by aliasing GLMExpertDownProjMapping to FusedExpertMapping and removing internal implementations.
GLM-VL Bridge Simplification
src/megatron/bridge/models/glm_vl/glm_45v_bridge.py
Removes maybe_modify_converted_hf_weight method and unused imports, eliminating per-expert weight assembly and merging logic.
GPT-OSS Bridge Modernization
src/megatron/bridge/models/gpt_oss/gpt_oss_bridge.py
Removes weight caching and per-task conversion logic, adds is_grouped_export and group_key to GPTOSSMLPDownProjMapping and GPTOSSMLPGateUpProjMapping, simplifies maybe_modify_loaded_hf_weight to handle transposition and MXFP4 dequantization directly.
Qwen-VL Bridge Updates
src/megatron/bridge/models/qwen_vl/qwen3_vl_bridge.py, src/megatron/bridge/models/qwen_vl/qwen35_vl_bridge.py
Replaces ExpertMLPGateUpProjMapping and ExpertMLPDownProjMapping with FusedGatedExpertMapping and FusedExpertMapping respectively, removes legacy init methods and internal weight alignment helpers, updates imports to use new fused variants.
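
The grouped-export accumulation described for MegatronModelBridge above — buffer per-expert weights under a group key, emit the merged tensor once the group is complete — can be sketched in plain Python (lists stand in for tensors; the function name echoes _accumulate_grouped_export, but the signature is hypothetical):

```python
from typing import Dict, List, Optional


def accumulate_grouped_export(
    buffers: Dict[str, Dict[int, List[float]]],
    group_key: str,
    expert_idx: int,
    weight: List[float],
    num_experts: int,
) -> Optional[List[List[float]]]:
    # Buffer this expert's weight under its group; emit nothing until
    # every expert in the group has arrived.
    bucket = buffers.setdefault(group_key, {})
    bucket[expert_idx] = weight
    if len(bucket) < num_experts:
        return None
    # Group complete: stack experts in global index order, free the buffer.
    merged = [bucket[i] for i in range(num_experts)]
    del buffers[group_key]
    return merged


buffers: Dict[str, Dict[int, List[float]]] = {}
assert accumulate_grouped_export(buffers, "layer0.fc2", 0, [1.0, 2.0], 2) is None
merged = accumulate_grouped_export(buffers, "layer0.fc2", 1, [3.0, 4.0], 2)
assert merged == [[1.0, 2.0], [3.0, 4.0]]
assert not buffers  # buffer is freed once the group is emitted
```

The real method operates on torch tensors and also handles transposition to match the Megatron shape, but the buffer-until-complete control flow is the same idea.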

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client
    participant ModelBridge as MegatronModelBridge
    participant Accum as Grouped<br/>Accumulator
    participant GrpBuf as GroupedBuffers<br/>(Tensor Cache)
    participant Output as Merged<br/>Tensor Dict

    Client->>ModelBridge: stream_weights_hf_to_megatron<br/>(with grouped_export mapping)
    ModelBridge->>ModelBridge: Detect is_grouped_export=True
    ModelBridge->>GrpBuf: Initialize grouped_buffers[group_key]

    loop For Each Expert in Group
        ModelBridge->>ModelBridge: Load HF weight slice
        ModelBridge->>Accum: _accumulate_grouped_export<br/>(expert_idx, weight)
        Accum->>GrpBuf: Store per-expert weight<br/>at global expert index
    end

    Accum->>Accum: All experts collected?
    alt Yes - Group Complete
        Accum->>Accum: Stack/merge expert<br/>tensors into single tensor
        Accum->>Accum: Optionally transpose<br/>to match shape
        Accum->>Output: Return merged dict
        Output->>Client: Yield merged tensor
    else No - Still Accumulating
        Accum->>Client: Yield None (continue)
    end
```

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

  • PR #2589: Modifies MoE expert parameter mappings in HF bridge mapping registries for models like GPT-OSS, directly related to expert mapping refactoring scope.
  • PR #2358: Introduces fused/grouped-expert mapping classes and switching bridges to use them, shares the core grouped-export mapping infrastructure introduced in this PR.
  • PR #2336: Modifies MTP disabling logic in the same inference example file, addresses related MTP handling changes.

Suggested labels

Run CICD

Suggested reviewers

  • cuichenx
  • ko3n1g
  • liding-nv
🚥 Pre-merge checks (✅ 2 passed, ❌ 2 failed)

❌ Failed checks (2 warnings)

  • Docstring Coverage ⚠️ Warning — Docstring coverage is 48.72%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Test Results For Major Changes ⚠️ Warning — The PR contains substantial API changes and bug fixes but lacks documented test results; the commit message states tests are pending rather than completed. Resolution: execute the complete test suite for all affected model variants and document results before merging; address unresolved review comments regarding re-exports and MTP configuration.

✅ Passed checks (2 passed)

  • Description Check ✅ Passed — Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed — The title clearly summarizes the two main objectives: consolidating fused expert mappings and fixing MTP inference issues.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/megatron/bridge/models/glm/glm45_bridge.py (1)

218-300: ⚠️ Potential issue | 🟠 Major

Add dual-prefix support for MTP layer mappings to handle both Megatron-Core naming conventions.

The MTP mappings currently hard-code only mtp_model_layer in the explicit QKV/MLP/expert mappings (lines 250, 256, 262, 267, 277, 284, 295, 300) and in the generated AutoMapping entries at line 218. Megatron-Core may expose the MTP submodule as transformer_layer instead, which will leave MTP weights unmapped for those checkpoints. Follow the pattern in mimo_bridge.py by iterating over both prefixes to ensure compatibility across different Megatron-Core versions.
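
A minimal sketch of the dual-prefix pattern the reviewer suggests, assuming a mapping constructor that takes a Megatron parameter pattern (the templates and the make_mapping factory here are placeholders, not the glm45_bridge.py API):

```python
# Hypothetical MTP parameter templates; the real patterns live in glm45_bridge.py.
MTP_PARAM_TEMPLATES = [
    "mtp.layers.*.{prefix}.self_attention.linear_qkv.weight",
    "mtp.layers.*.{prefix}.mlp.linear_fc1.weight",
]


def build_mtp_mappings(make_mapping):
    """Emit each MTP mapping under both Megatron-Core submodule prefixes."""
    mapping_list = []
    for prefix in ("mtp_model_layer", "transformer_layer"):
        for template in MTP_PARAM_TEMPLATES:
            mapping_list.append(make_mapping(template.format(prefix=prefix)))
    return mapping_list


# With an identity factory, we just get the expanded parameter names back.
names = build_mtp_mappings(lambda pattern: pattern)
assert len(names) == 4
assert "mtp.layers.*.mtp_model_layer.self_attention.linear_qkv.weight" in names
assert "mtp.layers.*.transformer_layer.mlp.linear_fc1.weight" in names
```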

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/models/glm/glm45_bridge.py` around lines 218 - 300, The
MTP mappings only use the "mtp_model_layer" prefix causing missed mappings when
Megatron exposes the submodule as "transformer_layer"; update the mapping
construction to loop over both prefixes (e.g., prefixes = ["mtp_model_layer",
"transformer_layer"]) and add mappings for each prefix so every place that
currently constructs megatron_param with "mtp_model_layer" (including the
AutoMapping entries and the specialized mappings: QKVMapping, GatedMLPMapping,
GLMExpertGateUpProjMapping, GLMExpertDownProjMapping, and the existing
AutoMapping for experts) is duplicated/created for the alternate
"transformer_layer" prefix; follow the pattern used in mimo_bridge.py to
generate entries for both prefixes and append them to mapping_list.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 32a098f8-3a4b-4da8-b6f7-927e1570c4c4

📥 Commits

Reviewing files that changed from the base of the PR and between 3e9d2e3 and ff3705b.

📒 Files selected for processing (10)
  • examples/conversion/hf_to_megatron_generate_text.py
  • src/megatron/bridge/models/conversion/__init__.py
  • src/megatron/bridge/models/conversion/model_bridge.py
  • src/megatron/bridge/models/conversion/param_mapping.py
  • src/megatron/bridge/models/glm/glm45_bridge.py
  • src/megatron/bridge/models/glm/glm_moe_mappings.py
  • src/megatron/bridge/models/glm_vl/glm_45v_bridge.py
  • src/megatron/bridge/models/gpt_oss/gpt_oss_bridge.py
  • src/megatron/bridge/models/qwen_vl/qwen35_vl_bridge.py
  • src/megatron/bridge/models/qwen_vl/qwen3_vl_bridge.py
💤 Files with no reviewable changes (1)
  • src/megatron/bridge/models/glm_vl/glm_45v_bridge.py

Comment on lines +171 to +172

```python
if hasattr(m, "mtp_process"):
    m.mtp_process = False
```

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

```shell
#!/bin/bash
# Verify where MTP is actually controlled in this repository
echo "=== Searching for mtp_process usage ==="
rg -n -C2 '\bmtp_process\b' --type=py

echo -e "\n=== Searching for mtp_num_layers usage ==="
rg -n -C2 '\bmtp_num_layers\b' --type=py
```

Repository: NVIDIA-NeMo/Megatron-Bridge · Length of output: 31680

🏁 Script executed:

```shell
cat -n examples/conversion/hf_to_megatron_generate_text.py | head -200 | tail -60
```

Repository: NVIDIA-NeMo/Megatron-Bridge · Length of output: 2666

🏁 Script executed:

```shell
# Check what grad_scale_func is and where it's used
echo "=== Searching for grad_scale_func ==="
rg -n 'grad_scale_func' --type=py | head -20
```

Repository: NVIDIA-NeMo/Megatron-Bridge · Length of output: 1656

🏁 Script executed:

```shell
# Check if hf_to_megatron_generate_text.py handles any Qwen VL specific logic
echo "=== Context around MTP disable in generate_text.py ==="
rg -B10 -A5 'mtp_process' examples/conversion/hf_to_megatron_generate_text.py

echo -e "\n=== Compare with generate_vlm.py approach ==="
rg -B5 -A5 'mtp_num_layers = None' examples/conversion/hf_to_megatron_generate_vlm.py

echo -e "\n=== Check file imports to see what models are supported ==="
head -50 examples/conversion/hf_to_megatron_generate_text.py
```

Repository: NVIDIA-NeMo/Megatron-Bridge · Length of output: 3162

🏁 Script executed:

```shell
# Check if Qwen VL is the only model that uses mtp_process
echo "=== Check what models/classes have mtp_process attribute ==="
rg -B5 'self.mtp_process\s*=' --type=py | head -40

echo -e "\n=== Verify if all models have config.mtp_num_layers ==="
rg 'config.mtp_num_layers' examples/conversion --type=py
```

Repository: NVIDIA-NeMo/Megatron-Bridge · Length of output: 957

🏁 Script executed:

```shell
# Final check: verify grad_scale_func is also needed for inference
echo "=== Check inference paths that set grad_scale_func to None ==="
rg -B3 -A3 'grad_scale_func = None' examples/conversion src/megatron/bridge/inference --type=py
```

Repository: NVIDIA-NeMo/Megatron-Bridge · Length of output: 1934


Set MTP config to None for inference; mtp_process alone is insufficient.

Lines 171–172 only toggle a model-specific attribute. For models where config.mtp_num_layers is set, MTP remains enabled at the config level, which training/inference logic checks to determine whether to use MTP. Additionally, grad_scale_func must be set to None to prevent NCCL collective hangs during inference.

Apply the following fix to match the pattern used in hf_to_megatron_generate_vlm.py and compare_hf_and_megatron/compare.py:

🔧 Proposed fix

```diff
 for m in model:
+    if hasattr(m, "config"):
+        m.config.mtp_num_layers = None
+        m.config.grad_scale_func = None
     if hasattr(m, "mtp_process"):
         m.mtp_process = False
```
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change:

```diff
-if hasattr(m, "mtp_process"):
-    m.mtp_process = False
+if hasattr(m, "config"):
+    m.config.mtp_num_layers = None
+    m.config.grad_scale_func = None
+if hasattr(m, "mtp_process"):
+    m.mtp_process = False
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/conversion/hf_to_megatron_generate_text.py` around lines 171 - 172,
The current change only flips the model instance flag m.mtp_process, but you
must also disable MTP at the config level and clear mixed-precision scaling to
avoid NCCL hangs: when you see the block that checks hasattr(m, "mtp_process")
and sets m.mtp_process = False, also set m.config.mtp_num_layers = None (or 0 if
config expects an int) and set m.grad_scale_func = None, using attribute
existence checks before assignment to avoid attribute errors; update the same
function/section that handles m.mtp_process so all three changes are applied
together.

Comment on lines +21 to +23
from megatron.bridge.models.conversion.param_mapping import ( # noqa: F401
FusedExpertMapping as GLMExpertDownProjMapping,
)

⚠️ Potential issue | 🔴 Critical

Re-export GLMExpertGateUpProjMapping to prevent import-time breakage.

This module now exposes only GLMExpertDownProjMapping, but GLM bridges still import and instantiate GLMExpertGateUpProjMapping, which will fail at import time.

🔧 Proposed fix

```diff
 from megatron.bridge.models.conversion.param_mapping import (  # noqa: F401
     FusedExpertMapping as GLMExpertDownProjMapping,
+    FusedGatedExpertMapping as GLMExpertGateUpProjMapping,
 )
+
+__all__ = ["GLMExpertDownProjMapping", "GLMExpertGateUpProjMapping"]
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/models/glm/glm_moe_mappings.py` around lines 21 - 23,
Module currently only re-exports GLMExpertDownProjMapping causing import-time
failure where GLMExpertGateUpProjMapping is expected; add a matching re-export
for the gate mapping by importing the appropriate symbol from
megatron.bridge.models.conversion.param_mapping and aliasing it to
GLMExpertGateUpProjMapping (mirror the existing pattern used for
GLMExpertDownProjMapping), so downstream code that imports and instantiates
GLMExpertGateUpProjMapping will succeed.

Comment on lines 121 to +130

```diff
         if isinstance(hf_param, str):
-            if hf_param in self.hf_weights_cache:
-                return self.hf_weights_cache[hf_param]
             if hf_param in hf_state_dict:
                 hf_weights = hf_state_dict[hf_param]
-                if ".mlp.experts." in hf_param and len(hf_weights.shape) == 3:
+                if ".mlp.experts." in hf_param and hf_weights.ndim == 3:
                     hf_weights = hf_weights.transpose(-1, -2)
-                self.hf_weights_cache[hf_param] = hf_weights
-            else:
-                blocks_key = hf_param + "_blocks"
-                scales_key = hf_param + "_scales"
-                if blocks_key in hf_state_dict and scales_key in hf_state_dict:
-                    hf_weights = _dequantize_mxfp4(hf_state_dict[blocks_key], hf_state_dict[scales_key])
-                    self.hf_weights_cache[hf_param] = hf_weights
-                else:
-                    raise KeyError(
-                        f"Cannot locate weights for '{hf_param}'. Missing both de-quantized tensor and "
-                        f"quantized representation (blocks='{blocks_key}', scales='{scales_key}')."
-                    )
-        else:
-            hf_weights = {k: hf_state_dict[v] for k, v in hf_param.items()}
-        return hf_weights
-
-    def maybe_modify_converted_hf_weight(
-        self,
-        task: WeightConversionTask,
-        converted_weights_dict: Dict[str, torch.Tensor],
-        hf_state_dict: Mapping[str, torch.Tensor],
-    ) -> Dict[str, torch.Tensor]:
-        num_experts = self.hf_config.num_local_experts
-        ep_size = parallel_state.get_expert_model_parallel_world_size()
-        experts_per_rank = num_experts // ep_size
-
-        try:
-            local_expert_number = extract_expert_number_from_param(task.param_name) % experts_per_rank
-        except ValueError:
-            # not an expert weight
-            return converted_weights_dict
-
-        assert len(converted_weights_dict) == 1, (
-            f"There should be only one key in the converted_weights_dict, got keys: {converted_weights_dict.keys()}"
-        )
-        for key, value in converted_weights_dict.items():
-            if key not in self.hf_weights_cache:
-                self.hf_weights_cache[key] = {}
-
-            # we end up with ep_size many weights to add to the cache
-            # unpack the weights and re-index
-            if ep_size == 1:
-                self.hf_weights_cache[key][local_expert_number] = value
-            else:
-                assert value.shape[0] == ep_size
-                for i, exp_val in enumerate(value):
-                    global_expert_number = local_expert_number + (i * experts_per_rank)
-                    self.hf_weights_cache[key][global_expert_number] = exp_val
-            if len(self.hf_weights_cache[key]) == num_experts:
-                logging.debug(f"All experts are loaded for {key}")
-                # all experts are loaded
-                merged_hf_weights = torch.cat(
-                    [self.hf_weights_cache[key][i].unsqueeze(0) for i in range(num_experts)], dim=0
-                )
-                del self.hf_weights_cache[key]
-                return {key: merged_hf_weights}
-            else:
-                # not all experts are loaded yet, return empty dict
-                logging.debug(f"{len(self.hf_weights_cache[key])}/{num_experts} experts are loaded for {key}")
-                return {}
+                return hf_weights
+            blocks_key = hf_param + "_blocks"
+            scales_key = hf_param + "_scales"
+            if blocks_key in hf_state_dict and scales_key in hf_state_dict:
+                return _dequantize_mxfp4(hf_state_dict[blocks_key], hf_state_dict[scales_key])
```

⚠️ Potential issue | 🔴 Critical

Transpose dequantized expert tensors before returning them.

The direct-tensor branch normalizes 3D .mlp.experts. weights, but the _blocks/_scales branch returns _dequantize_mxfp4(...) unchanged. On quantized GPT-OSS checkpoints that leaves each expert in HF layout, so GPTOSSMLPDownProjMapping and GPTOSSMLPGateUpProjMapping receive the wrong axis order on import.

🔧 Proposed fix

```diff
             blocks_key = hf_param + "_blocks"
             scales_key = hf_param + "_scales"
             if blocks_key in hf_state_dict and scales_key in hf_state_dict:
-                return _dequantize_mxfp4(hf_state_dict[blocks_key], hf_state_dict[scales_key])
+                hf_weights = _dequantize_mxfp4(hf_state_dict[blocks_key], hf_state_dict[scales_key])
+                if ".mlp.experts." in hf_param and hf_weights.ndim == 3:
+                    hf_weights = hf_weights.transpose(-1, -2)
+                return hf_weights
```
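
On a toy 3D "tensor" the layout fix amounts to swapping the last two axes per expert, converting an HF-style [experts, in, out] block into a [experts, out, in] layout (shapes and layout names here are illustrative; the real code calls torch's transpose(-1, -2)):

```python
def transpose_last_two(t):
    """Swap the last two axes of a 3D nested-list 'tensor' (a toy stand-in
    for torch.Tensor.transpose(-1, -2))."""
    return [[[expert[i][e] for i in range(len(expert))]
             for e in range(len(expert[0]))]
            for expert in t]


# One "expert" with shape (in=2, out=3); fused tensor shape (1, 2, 3).
hf_layout = [[[1, 2, 3], [4, 5, 6]]]
megatron_layout = transpose_last_two(hf_layout)  # shape (1, 3, 2)
assert megatron_layout == [[[1, 4], [2, 5], [3, 6]]]
```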
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/models/gpt_oss/gpt_oss_bridge.py` around lines 121 - 130,
The quantized-path returning _dequantize_mxfp4(blocks, scales) doesn't mirror
the direct-tensor branch's transpose for 3D expert weights, causing expert
tensors to keep HF layout; update the branch handling blocks_key/scales_key so
that after calling _dequantize_mxfp4 you detect if hf_param contains
".mlp.experts." and the returned tensor has ndim == 3, then transpose the last
two axes (i.e., swap -1 and -2) before returning; locate this logic around the
hf_param string branch that references hf_state_dict, _dequantize_mxfp4, and the
".mlp.experts." selector to apply the fix.
