Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions example/0.load_model_and_generate_single_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
from megatron.core import parallel_state as mpu
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from transformers import AutoTokenizer
from transformers import AutoTokenizer, AutoConfig

from mbridge import AutoBridge

Expand All @@ -27,9 +27,17 @@ def init_distributed():

def load_model(hf_model_path, trust_remote_code=False):
"""Load model"""
bridge = AutoBridge.from_pretrained(

# use AutoConfig to change hf config
config = AutoConfig.from_pretrained(
hf_model_path, trust_remote_code=trust_remote_code
)

if hasattr(config, "num_nextn_predict_layers"):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we set this — is this debug code?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ArronHZG This should not be needed once the patches here are applied #62 (comment)

config.num_nextn_predict_layers = 0

bridge = AutoBridge.from_config(config)

Comment on lines 7 to +40
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These changes should not be needed

model = bridge.get_model()
bridge.load_weights(model, hf_model_path)
return model
Expand Down
9 changes: 8 additions & 1 deletion mbridge/core/llm_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,18 @@ def _get_gptmodel_args(self) -> dict:
Returns:
dict: A dictionary of arguments for GPTModel initialization
"""
if hasattr(self.hf_config, 'rope_parameters') and 'rope_theta' in self.hf_config.rope_parameters:
# for transformers >= 5.0.0
rotary_base = self.hf_config.rope_parameters['rope_theta']
else:
# for transformers ~= 4.57.3
rotary_base = self.hf_config.rope_theta

return dict(
vocab_size=self.hf_config.vocab_size,
max_sequence_length=self.hf_config.max_position_embeddings,
position_embedding_type="rope",
rotary_base=self.hf_config.rope_theta,
rotary_base=rotary_base,
)

def _get_transformer_layer_spec(self, vp_stage: Optional[int] = None):
Expand Down
2 changes: 1 addition & 1 deletion mbridge/core/safetensor_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
class SafeTensorIO:
def __init__(self, hf_dir: str):
index_file = os.path.join(hf_dir, "model.safetensors.index.json")
config = AutoConfig.from_pretrained(hf_dir)
config = AutoConfig.from_pretrained(hf_dir, trust_remote_code=True)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can't set trust_remote_code by default


self.index = {}
self.origin_index = {}
Expand Down
2 changes: 1 addition & 1 deletion mbridge/models/deepseek_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def _build_config(self):
mtp_args = {}
if "num_nextn_predict_layers" in hf_config:
mtp_args["mtp_num_layers"] = hf_config.num_nextn_predict_layers
mtp_args["mtp_loss_scaling_factor"] = 0.1
mtp_args["mtp_loss_scaling_factor"] = self.extra_args.get("mtp_loss_scaling_factor", 0.1)

base_config = {
"attention_backend": AttnBackend.fused,
Expand Down
47 changes: 40 additions & 7 deletions mbridge/models/mimo.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
from typing import Callable, Optional


from megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec
import torch
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec, get_gpt_decoder_block_spec

from ..core import register_model
from .qwen2 import Qwen2Bridge
Expand All @@ -27,7 +28,7 @@ def _build_config(self):
mtp_args = {}
if "num_nextn_predict_layers" in hf_config:
mtp_args["mtp_num_layers"] = hf_config.num_nextn_predict_layers
mtp_args["mtp_loss_scaling_factor"] = 0.1
mtp_args["mtp_loss_scaling_factor"] = self.extra_args.get("mtp_loss_scaling_factor", 0.1)

return self._build_base_config(
add_qkv_bias=True,
Expand All @@ -41,10 +42,11 @@ def _get_gptmodel_args(self) -> dict:

# Add MTP block spec if MTP layers are present
if self.config.mtp_num_layers is not None:
transformer_layer_spec = self.config
mtp_block_spec = get_gpt_mtp_block_spec(
self.config, transformer_layer_spec, use_transformer_engine=True
)
transformer_layer_spec = get_gpt_decoder_block_spec(config=self.config,
use_transformer_engine=True)
mtp_block_spec = get_gpt_mtp_block_spec(self.config,
transformer_layer_spec.layer_specs[-1],
True)
ret["mtp_block_spec"] = mtp_block_spec

return ret
Expand Down Expand Up @@ -109,3 +111,34 @@ def _convert_mtp_param(self, name: str) -> list[str]:
else:
raise NotImplementedError(f"Unsupported MTP parameter name: {name}")
return convert_names


# Megatron's default MTP input concatenation order is [target_token_embed, hidden_state],
# which applies to DeepSeek and GLM models.
# The MiMo-7B-RL model's MTP input format is [hidden_state, target_token_embed] —
# the completely reversed order — so the first and second halves of the eh_proj
# parameter values must be swapped here.
#
# mimo mtp sglang code:
# https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/mimo_mtp.py#L61-L69
# mimo mtp vllm code:
# https://github.com/XiaomiMiMo/vllm/blob/feat_mimo_mtp_stable_073/vllm/model_executor/models/mimo_mtp.py#L84-L85
# megatron code:
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/multi_token_prediction.py#L636

def _weight_to_mcore_format(self, mcore_weights_name: str, hf_weights: list[torch.Tensor]) -> torch.Tensor:
    """Convert HF weights to mcore layout, swapping eh_proj halves along dim 1.

    MiMo stores eh_proj input halves in the reverse of Megatron's expected
    order, so the two halves are exchanged before handing off to Megatron-Core.
    """
    converted = super()._weight_to_mcore_format(mcore_weights_name, hf_weights)
    if not mcore_weights_name.endswith("eh_proj.weight"):
        return converted
    # Split into two halves along dim 1 and re-concatenate in reversed order.
    halves = converted.chunk(2, dim=1)
    return torch.cat(halves[::-1], dim=1)

def _weight_to_hf_format(
    self, mcore_weights_name: str, mcore_weights: torch.Tensor
) -> tuple[list[str], list[torch.Tensor]]:
    """Convert mcore weights to HF layout, undoing the eh_proj half swap.

    Mirrors ``_weight_to_mcore_format``: the two halves of eh_proj along
    dim 1 are exchanged back before exporting to HuggingFace format.
    """
    if mcore_weights_name.endswith("eh_proj.weight"):
        # Reverse the half-swap applied on import.
        halves = mcore_weights.chunk(2, dim=1)
        mcore_weights = torch.cat(halves[::-1], dim=1)
    return super()._weight_to_hf_format(mcore_weights_name, mcore_weights)