diff --git a/example/0.load_model_and_generate_single_gpu.py b/example/0.load_model_and_generate_single_gpu.py
index ad9263c..3c0590c 100644
--- a/example/0.load_model_and_generate_single_gpu.py
+++ b/example/0.load_model_and_generate_single_gpu.py
@@ -4,7 +4,7 @@
 import torch
 from megatron.core import parallel_state as mpu
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, AutoConfig
 
 from mbridge import AutoBridge
 
@@ -27,9 +27,18 @@ def init_distributed():
 
 def load_model(hf_model_path, trust_remote_code=False):
     """Load model"""
-    bridge = AutoBridge.from_pretrained(
+
+    # load the HF config with AutoConfig so it can be adjusted before building the bridge
+    config = AutoConfig.from_pretrained(
         hf_model_path, trust_remote_code=trust_remote_code
     )
+
+    # disable multi-token prediction (MTP) layers; this example only needs the base model
+    if hasattr(config, "num_nextn_predict_layers"):
+        config.num_nextn_predict_layers = 0
+
+    bridge = AutoBridge.from_config(config)
+
     model = bridge.get_model()
     bridge.load_weights(model, hf_model_path)
     return model
diff --git a/mbridge/core/llm_bridge.py b/mbridge/core/llm_bridge.py
index a3667fe..fd42917 100644
--- a/mbridge/core/llm_bridge.py
+++ b/mbridge/core/llm_bridge.py
@@ -117,11 +117,18 @@ def _get_gptmodel_args(self) -> dict:
         Returns:
             dict: A dictionary of arguments for GPTModel initialization
         """
+        if hasattr(self.hf_config, "rope_parameters") and "rope_theta" in self.hf_config.rope_parameters:
+            # transformers >= 5.0.0 nests rope_theta inside rope_parameters
+            rotary_base = self.hf_config.rope_parameters["rope_theta"]
+        else:
+            # transformers ~= 4.57.3 exposes rope_theta directly on the config
+            rotary_base = self.hf_config.rope_theta
+
         return dict(
             vocab_size=self.hf_config.vocab_size,
             max_sequence_length=self.hf_config.max_position_embeddings,
             position_embedding_type="rope",
-            rotary_base=self.hf_config.rope_theta,
+            rotary_base=rotary_base,
         )
 
     def _get_transformer_layer_spec(self, vp_stage: Optional[int] = None):
diff --git a/mbridge/core/safetensor_io.py b/mbridge/core/safetensor_io.py
index fa1a106..d20cd4a 100644
--- a/mbridge/core/safetensor_io.py
+++ b/mbridge/core/safetensor_io.py
@@ -14,7 +14,7 @@ class SafeTensorIO:
     def __init__(self, hf_dir: str):
         index_file = os.path.join(hf_dir, "model.safetensors.index.json")
-        config = AutoConfig.from_pretrained(hf_dir)
+        config = AutoConfig.from_pretrained(hf_dir, trust_remote_code=True)
         self.index = {}
         self.origin_index = {}
diff --git a/mbridge/models/deepseek_v3.py b/mbridge/models/deepseek_v3.py
index b0e423e..6668151 100644
--- a/mbridge/models/deepseek_v3.py
+++ b/mbridge/models/deepseek_v3.py
@@ -123,7 +123,7 @@ def _build_config(self):
         mtp_args = {}
         if "num_nextn_predict_layers" in hf_config:
             mtp_args["mtp_num_layers"] = hf_config.num_nextn_predict_layers
-            mtp_args["mtp_loss_scaling_factor"] = 0.1
+            mtp_args["mtp_loss_scaling_factor"] = self.extra_args.get("mtp_loss_scaling_factor", 0.1)
 
         base_config = {
             "attention_backend": AttnBackend.fused,
diff --git a/mbridge/models/mimo.py b/mbridge/models/mimo.py
index e5e43c7..873b1b4 100644
--- a/mbridge/models/mimo.py
+++ b/mbridge/models/mimo.py
@@ -1,7 +1,8 @@
 # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
-
-from megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec
+from typing import Callable, Optional
+import torch
+from megatron.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec, get_gpt_decoder_block_spec
 
 from ..core import register_model
 from .qwen2 import Qwen2Bridge
 
@@ -27,7 +28,7 @@ def _build_config(self):
         mtp_args = {}
         if "num_nextn_predict_layers" in hf_config:
             mtp_args["mtp_num_layers"] = hf_config.num_nextn_predict_layers
-            mtp_args["mtp_loss_scaling_factor"] = 0.1
+            mtp_args["mtp_loss_scaling_factor"] = self.extra_args.get("mtp_loss_scaling_factor", 0.1)
 
         return self._build_base_config(
             add_qkv_bias=True,
@@ -41,10 +42,11 @@ def _get_gptmodel_args(self) -> dict:
 
         # Add MTP block spec if MTP layers are present
         if self.config.mtp_num_layers is not None:
-            transformer_layer_spec = self.config
-            mtp_block_spec = get_gpt_mtp_block_spec(
-                self.config, transformer_layer_spec, use_transformer_engine=True
-            )
+            transformer_layer_spec = get_gpt_decoder_block_spec(config=self.config,
+                                                                use_transformer_engine=True)
+            mtp_block_spec = get_gpt_mtp_block_spec(self.config,
+                                                    transformer_layer_spec.layer_specs[-1],
+                                                    use_transformer_engine=True)
             ret["mtp_block_spec"] = mtp_block_spec
 
         return ret
@@ -109,3 +111,34 @@ def _convert_mtp_param(self, name: str) -> list[str]:
         else:
             raise NotImplementedError(f"Unsupported MTP parameter name: {name}")
         return convert_names
+
+
+    # Megatron's default MTP input concatenation order is [target_token_embed, hidden_state],
+    # which matches the DeepSeek and GLM models.
+    # Mimo-7B-RL instead concatenates as [hidden_state, target_token_embed],
+    # i.e. the two halves arrive in the reverse order, so the first and second halves
+    # of the eh_proj weight must be swapped along the input dimension when converting.
+    #
+    # Mimo MTP in sglang:
+    # https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/mimo_mtp.py#L61-L69
+    # Mimo MTP in vllm:
+    # https://github.com/XiaomiMiMo/vllm/blob/feat_mimo_mtp_stable_073/vllm/model_executor/models/mimo_mtp.py#L84-L85
+    # Megatron reference:
+    # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/multi_token_prediction.py#L636
+
+    def _weight_to_mcore_format(self, mcore_weights_name: str, hf_weights: list[torch.Tensor]) -> torch.Tensor:
+        """Swap halves of eh_proj weights before handing off to Megatron-Core."""
+        weight = super()._weight_to_mcore_format(mcore_weights_name, hf_weights)
+        if mcore_weights_name.endswith("eh_proj.weight"):
+            first_half, second_half = weight.chunk(2, dim=1)
+            weight = torch.cat([second_half, first_half], dim=1)
+        return weight
+
+    def _weight_to_hf_format(
+        self, mcore_weights_name: str, mcore_weights: torch.Tensor
+    ) -> tuple[list[str], list[torch.Tensor]]:
+        """Swap halves back when exporting eh_proj weights to HuggingFace format."""
+        if mcore_weights_name.endswith("eh_proj.weight"):
+            first_half, second_half = mcore_weights.chunk(2, dim=1)
+            mcore_weights = torch.cat([second_half, first_half], dim=1)
+        return super()._weight_to_hf_format(mcore_weights_name, mcore_weights)
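
The rope_theta fallback added to llm_bridge.py can be exercised in isolation. A minimal
sketch, assuming only what the patch states: transformers >= 5.0.0 configs carry a
rope_parameters dict while ~= 4.57.3 configs expose rope_theta as a top-level attribute.
The SimpleNamespace objects are stand-ins for real HF configs, not transformers or
mbridge API:

    from types import SimpleNamespace

    def resolve_rotary_base(hf_config):
        # transformers >= 5.0.0 nests rope_theta inside rope_parameters
        rope_params = getattr(hf_config, "rope_parameters", None)
        if rope_params and "rope_theta" in rope_params:
            return rope_params["rope_theta"]
        # transformers ~= 4.57.3 exposes rope_theta directly
        return hf_config.rope_theta

    new_style = SimpleNamespace(rope_parameters={"rope_theta": 1000000.0})
    old_style = SimpleNamespace(rope_theta=1000000.0)
    assert resolve_rotary_base(new_style) == resolve_rotary_base(old_style) == 1000000.0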
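
Both deepseek_v3.py and mimo.py now read mtp_loss_scaling_factor from self.extra_args
instead of hard-coding 0.1. A hypothetical sketch of that override pattern; ToyBridge
and its fields are illustrative stand-ins, not mbridge API:

    class ToyBridge:
        """Stand-in for a bridge carrying user-supplied extra_args."""

        def __init__(self, extra_args=None):
            self.extra_args = extra_args or {}

        def mtp_args(self, num_nextn_predict_layers):
            return {
                "mtp_num_layers": num_nextn_predict_layers,
                # previously hard-coded to 0.1; now overridable per run
                "mtp_loss_scaling_factor": self.extra_args.get("mtp_loss_scaling_factor", 0.1),
            }

    assert ToyBridge().mtp_args(1)["mtp_loss_scaling_factor"] == 0.1
    assert ToyBridge({"mtp_loss_scaling_factor": 0.3}).mtp_args(1)["mtp_loss_scaling_factor"] == 0.3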
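
The eh_proj half-swap in mimo.py is a change of input layout: for y = W @ cat([a, b])
with W = [W_a | W_b], feeding cat([b, a]) instead requires [W_b | W_a] to produce the
same output. A self-contained check of that identity (shapes are illustrative, not
Mimo's real dimensions):

    import torch

    torch.manual_seed(0)
    hidden = torch.randn(4)   # hidden_state half of the MTP input
    embed = torch.randn(4)    # target-token-embed half
    w_hf = torch.randn(3, 8)  # eh_proj weight in Mimo's HF layout: [W_hidden | W_embed]

    # Mimo-7B-RL applies eh_proj to cat([hidden_state, target_token_embed])
    y_hf = w_hf @ torch.cat([hidden, embed])

    # Megatron feeds cat([target_token_embed, hidden_state]), so swap the
    # weight's input halves, as _weight_to_mcore_format does above
    w_h, w_e = w_hf.chunk(2, dim=1)
    w_mcore = torch.cat([w_e, w_h], dim=1)
    y_mcore = w_mcore @ torch.cat([embed, hidden])

    assert torch.allclose(y_hf, y_mcore)

Because swapping the halves twice is the identity, applying the same swap in
_weight_to_hf_format makes the mcore-to-HF export an exact round trip.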