diff --git a/examples/conversion/hf_to_megatron_generate_text.py b/examples/conversion/hf_to_megatron_generate_text.py index 9d0077f201..a9fc388d8a 100644 --- a/examples/conversion/hf_to_megatron_generate_text.py +++ b/examples/conversion/hf_to_megatron_generate_text.py @@ -167,9 +167,9 @@ def main(args) -> None: model_provider.initialize_model_parallel(seed=0) model = model_provider.provide_distributed_model(wrap_with_ddp=False) - # TEMP FIX for inference failure when mtp_num_layers is not None for m in model: - m.config.mtp_num_layers = None + if hasattr(m, "mtp_process"): + m.mtp_process = False model = [m.cuda() for m in model] for m in model: diff --git a/src/megatron/bridge/models/conversion/__init__.py b/src/megatron/bridge/models/conversion/__init__.py index e5071761e7..7b4b493dec 100644 --- a/src/megatron/bridge/models/conversion/__init__.py +++ b/src/megatron/bridge/models/conversion/__init__.py @@ -19,6 +19,8 @@ from megatron.bridge.models.conversion.param_mapping import ( AutoMapping, ColumnParallelMapping, + FusedExpertMapping, + FusedGatedExpertMapping, GatedMLPMapping, MegatronParamMapping, QKVMapping, @@ -33,6 +35,8 @@ "MegatronMappingRegistry", "MegatronModelBridge", "ColumnParallelMapping", + "FusedExpertMapping", + "FusedGatedExpertMapping", "GatedMLPMapping", "MegatronParamMapping", "QKVMapping", diff --git a/src/megatron/bridge/models/conversion/model_bridge.py b/src/megatron/bridge/models/conversion/model_bridge.py index c5a6155b12..a58b8c4de5 100644 --- a/src/megatron/bridge/models/conversion/model_bridge.py +++ b/src/megatron/bridge/models/conversion/model_bridge.py @@ -371,7 +371,7 @@ def hf_config_to_provider_kwargs(self, hf_config) -> dict: else: value = getattr(hf_config, hf_name, None) has_value = hasattr(hf_config, hf_name) - if has_value: + if has_value and value is not None: provider_kwargs[megatron_name] = value # Extract rotary_base via compat function (handles both legacy rope_theta @@ -714,6 +714,62 @@ def maybe_modify_converted_hf_weight( """ return converted_weights_dict + def _accumulate_grouped_export( + self, + task: "WeightConversionTask", + converted_weights_dict: Dict[str, torch.Tensor], + model_config, + grouped_buffers: Dict[str, Dict[int, torch.Tensor]], + hf_state_dict: Mapping[str, torch.Tensor], + ) -> Optional[Dict[str, torch.Tensor]]: + """Accumulate per-expert weights for grouped export, return merged result when complete. + + For fused-expert MoE models where one HF tensor contains all experts, this method + collects individual expert weights produced by per-expert ``megatron_to_hf`` calls + and returns the stacked result once all experts have been accumulated. + + Returns: + Merged weights dict when the group is complete, ``None`` otherwise. + """ + from megatron.bridge.utils.common_utils import extract_expert_number_from_param + + group_key = task.mapping.group_key + if group_key not in grouped_buffers: + grouped_buffers[group_key] = {} + + ep_size = parallel_state.get_expert_model_parallel_world_size() + num_experts = model_config.num_moe_experts + experts_per_rank = num_experts // ep_size + + try: + local_expert_number = extract_expert_number_from_param(task.param_name) % experts_per_rank + except ValueError: + return None + + for _, value in converted_weights_dict.items(): + if ep_size == 1: + grouped_buffers[group_key][local_expert_number] = value + else: + if value.ndim > 0 and value.shape[0] == ep_size: + for i in range(ep_size): + global_expert_number = local_expert_number + (i * experts_per_rank) + grouped_buffers[group_key][global_expert_number] = value[i] + else: + grouped_buffers[group_key][local_expert_number] = value + + if len(grouped_buffers[group_key]) == num_experts: + merged = torch.stack([grouped_buffers[group_key][i] for i in range(num_experts)], dim=0) + + if group_key in hf_state_dict: + expected_shape = hf_state_dict[group_key].shape + if merged.shape != expected_shape and merged.transpose(-1, -2).shape == expected_shape: + merged = merged.transpose(-1, -2).contiguous() + + del grouped_buffers[group_key] + return {group_key: merged} + + return None + def load_weights_hf_to_megatron( self, hf_pretrained: HFPreTrained, @@ -777,12 +833,20 @@ def load_weights_hf_to_megatron( hf_state_dict: Mapping[str, torch.Tensor] = hf_pretrained.state if hasattr(hf_pretrained, "state") else {} description = f"Loading from {hf_pretrained.model_name_or_path}" + _hf_import_cache: Dict[str, torch.Tensor] = {} for task in self._with_progress_tracking(hf_to_megatron_tasks, description): # None means megatron module not on current rank, skip if this task is not going to happen if task.megatron_module is None: continue - # 1) Fetch source tensor(s) from HF state dict - hf_weights = self.maybe_modify_loaded_hf_weight(task.mapping.hf_param, hf_state_dict) + # 1) Fetch source tensor(s) from HF state dict, with caching for grouped mappings + hf_param_key = str(task.mapping.hf_param) + is_grouped = getattr(task.mapping, "is_grouped_export", False) + if is_grouped and hf_param_key in _hf_import_cache: + hf_weights = _hf_import_cache[hf_param_key] + else: + hf_weights = self.maybe_modify_loaded_hf_weight(task.mapping.hf_param, hf_state_dict) + if is_grouped: + _hf_import_cache[hf_param_key] = hf_weights # 2) Delegate conversion & distribution to the bridge converted_weights = task.mapping.hf_to_megatron(hf_weights, task.megatron_module) @@ -969,14 +1033,33 @@ def stream_weights_megatron_to_hf( hf_state_dict: Mapping[str, torch.Tensor] = hf_pretrained.state if hasattr(hf_pretrained, "state") else {} + # Pre-compute expected expert counts for grouped export mappings + _grouped_task_counts: Dict[str, int] = {} + for task in megatron_to_hf_tasks: + if task is not None and getattr(task.mapping, "is_grouped_export", False): + gk = task.mapping.group_key + _grouped_task_counts[gk] = _grouped_task_counts.get(gk, 0) + 1 + _grouped_buffers: Dict[str, Dict[int, torch.Tensor]] = {} + for task in self._with_progress_tracking(megatron_to_hf_tasks, "Converting to HuggingFace", show_progress): converted_weights_dict = task.mapping.megatron_to_hf(task.param_weight, task.megatron_module) + + # --- Grouped export path: accumulate per-expert weights, yield when complete --- + if getattr(task.mapping, "is_grouped_export", False): + merged_result = self._accumulate_grouped_export( + task, converted_weights_dict, model_config, _grouped_buffers, hf_state_dict + ) + if merged_result is not None: + for hf_name, tensor in merged_result.items(): + yield HFWeightTuple(hf_name, tensor.cpu() if cpu else tensor) + continue + + # --- Standard export path --- converted_weights_dict = self.maybe_modify_converted_hf_weight( task, converted_weights_dict, hf_state_dict, - ) # dict will be none except for one expert; - # All ranks get the full tensor + ) adapter_tasks = None if merge_adapter_weights and "to_wrap.weight" in task.global_param_name: @@ -984,7 +1067,6 @@ def stream_weights_megatron_to_hf( adapter_tasks = adapter_tasks_by_base.get(task_global_base_prefix) if merge_adapter_weights and adapter_tasks: adapter_weights = self.materialize_adapter_weights(adapter_tasks) - # Merge LoRA adapter weights back into the base tensor for HF export converted_weights_dict = self._merge_lora_adapter_weights( megatron_model, converted_weights_dict, @@ -1000,21 +1082,17 @@ def stream_weights_megatron_to_hf( # Handle tied embeddings case # TODO(yuya): fix this hard coded naming if embeddings_are_tied and hf_name == "model.embed_tokens.weight": - # Yield the embedding weight yield HFWeightTuple(hf_name, final_tensor) - # Also yield as lm_head.weight if it's expected if hasattr(hf_pretrained, "state") and hasattr(hf_pretrained.state, "source"): expected_keys = hf_pretrained.state.source.get_all_keys() if "lm_head.weight" in expected_keys: yield HFWeightTuple("lm_head.weight", final_tensor.clone().detach()) elif embeddings_are_tied and hf_name == "lm_head.weight": - # This should not happen when embeddings are tied - assert error raise ValueError( "Encountered lm_head.weight when embeddings are tied. This indicates a mapping error." ) else: - # Regular case - yield the tensor normally yield HFWeightTuple(hf_name, final_tensor) def dtype_from_hf(self, config, default=None): diff --git a/src/megatron/bridge/models/conversion/param_mapping.py b/src/megatron/bridge/models/conversion/param_mapping.py index d0c0d31853..63d8931c92 100644 --- a/src/megatron/bridge/models/conversion/param_mapping.py +++ b/src/megatron/bridge/models/conversion/param_mapping.py @@ -553,7 +553,13 @@ def _count_wildcard_groups(self, pattern: str) -> int: return count def _validate_patterns(self): - """Validate wildcard consistency between patterns.""" + """Validate wildcard consistency between patterns. + + Skipped automatically for grouped-export mappings where the megatron + side intentionally has more wildcards than the HF side. + """ + if getattr(self, "is_grouped_export", False): + return megatron_param_wildcards = self._count_wildcard_groups(self.megatron_param) if isinstance(self.hf_param, str): hf_param_wildcards = self._count_wildcard_groups(self.hf_param) @@ -805,16 +811,17 @@ def hf_to_megatron( if hf_weights is None: raise ValueError("hf_weights should not be None on rank 0") - # For MCore MambaMixer, A_log is initialized in FP32 but cast to BF16 when - # saving ckpts, including the ckpt uploaded to HF. Without this cast, - # self.scatter_to_tp_ranks will try to scatter the HF A_log weights in BF16 to - # the Megatron tensor which is in FP32. This will error. So we cast before the scatter. + # Dtype may differ (e.g. MambaMixer A_log is FP32 in MCore but BF16 + # in HF checkpoints). Cast to match the Megatron parameter so the + # scatter doesn't fail on dtype mismatch. if hf_weights.dtype != target_param.dtype: - logger.warning( - f"WARNING: Dtype mismatch between HuggingFace weights and Megatron module. " - f"HF dtype: {hf_weights.dtype}. Megatron dtype: {target_param.dtype}. " - f"Casting HF weights to Megatron dtype. THIS MAY RESULT IN A LOSS OF PRECISION. " - ) + if not getattr(ColumnParallelMapping, "_dtype_warned", False): + ColumnParallelMapping._dtype_warned = True + logger.warning( + f"Dtype mismatch: HF weights are {hf_weights.dtype} but Megatron " + f"module uses {target_param.dtype}. Casting all mismatched weights " + f"to {target_param.dtype} (further warnings suppressed)." + ) hf_weights = hf_weights.to(target_param.dtype) # For bias (1D), we still split along dim 0 @@ -2212,6 +2219,140 @@ def megatron_to_hf(self, megatron_weights: torch.Tensor, megatron_module: nn.Mod return {key: value} +def _align_expert_weight_to_shape(weight: torch.Tensor, target_shape: torch.Size, name: str) -> torch.Tensor: + """Auto-detect whether a transpose is needed to match the Megatron target shape. + + Handles both transformers <5.0 (transposed) and 5.0+ (standard) expert weight layouts. + """ + if tuple(weight.shape) == tuple(target_shape): + return weight + if weight.ndim == 2 and tuple(weight.t().shape) == tuple(target_shape): + return weight.t().contiguous() + raise ValueError(f"Unexpected {name} shape {tuple(weight.shape)}; expected {tuple(target_shape)}.") + + +class _LooseGatedMLPMapping(GatedMLPMapping): + """GatedMLPMapping that skips wildcard validation for fused expert mappings.""" + + is_grouped_export = True + + +class FusedExpertMapping(AutoMapping): + """Mapping for fused expert weights: 1 HF tensor [num_experts, ...] <-> N Megatron per-expert tensors. + + HF side: Single tensor with shape [num_experts, ...] + Megatron side: Per-expert tensors (one param per expert) + + Import: Extracts single expert from fused HF tensor, auto-aligns shape, + delegates to AutoMapping for TP distribution. + Export: AutoMapping handles TP/EP gathering per expert, then the conversion + loop merges all experts via the ``is_grouped_export`` protocol. + + Replaces per-model expert mapping classes and eliminates the need for + ``maybe_modify_converted_hf_weight`` / ``hf_weights_cache`` on bridges. + """ + + is_grouped_export = True + + def __init__(self, megatron_param: str, hf_param: str, permute_dims: Optional[Tuple[int, ...]] = None): + super().__init__(megatron_param, hf_param, permute_dims) + self.allow_hf_name_mismatch = True + + @property + def group_key(self) -> str: + """Tasks sharing the same group_key are merged during export.""" + return self.hf_param + + def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -> torch.Tensor: + from megatron.bridge.utils.common_utils import extract_expert_number_from_param + + expert_idx = extract_expert_number_from_param(self.megatron_param) + expert_weight = hf_weights[expert_idx] if hf_weights.ndim >= 3 else hf_weights + + normalized_param = self._normalize_expert_param_name(self.megatron_param) + _, target_param = get_module_and_param_from_name(megatron_module, normalized_param) + expert_weight = _align_expert_weight_to_shape(expert_weight, target_param.shape, "expert_weight") + return super().hf_to_megatron(expert_weight, megatron_module) + + +class FusedGatedExpertMapping(AutoMapping): + """Mapping for fused gated expert weights (gate+up projection). + + HF side: Single tensor with shape [num_experts, 2*intermediate, hidden] + Megatron side: Per-expert linear_fc1 tensors (with gate+up interleaved) + + Import: Extracts single expert, splits into gate+up, delegates to + GatedMLPMapping for interleaved TP distribution. + Export: GatedMLPMapping handles TP/EP gathering, gate+up are fused back, + conversion loop merges all experts via the ``is_grouped_export`` protocol. + """ + + is_grouped_export = True + + def __init__(self, megatron_param: str, hf_param: str, permute_dims: Optional[Tuple[int, ...]] = None): + super().__init__(megatron_param, hf_param, permute_dims) + self.allow_hf_name_mismatch = True + self._gated_mapping = _LooseGatedMLPMapping( + megatron_param=self.megatron_param, + gate=f"{self.hf_param}.gate", + up=f"{self.hf_param}.up", + ) + + @property + def group_key(self) -> str: + """Tasks sharing the same group_key are merged during export.""" + return self.hf_param + + def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -> torch.Tensor: + from megatron.bridge.utils.common_utils import extract_expert_number_from_param + + expert_idx = extract_expert_number_from_param(self.megatron_param) + expert_weight = hf_weights[expert_idx] if hf_weights.ndim >= 3 else hf_weights + + normalized_param = self._normalize_expert_param_name(self.megatron_param) + _, target_param = get_module_and_param_from_name(megatron_module, normalized_param) + target_shape = target_param.shape + + if target_shape[0] % 2 != 0: + raise ValueError(f"Expected even fused dim for {self.megatron_param}, got {target_shape}.") + + gate_target_shape = (target_shape[0] // 2, target_shape[1]) + + if expert_weight.ndim == 3 and expert_weight.shape[0] == 2: + gate = _align_expert_weight_to_shape(expert_weight[0], gate_target_shape, "gate") + up = _align_expert_weight_to_shape(expert_weight[1], gate_target_shape, "up") + else: + expert_weight = _align_expert_weight_to_shape(expert_weight, target_shape, "gate_up") + gate, up = torch.chunk(expert_weight, 2, dim=0) + + return self._gated_mapping.hf_to_megatron({"gate": gate, "up": up}, megatron_module) + + def megatron_to_hf( + self, + megatron_weights: Optional[torch.Tensor], + megatron_module: Optional[nn.Module], + ) -> Dict[str, torch.Tensor]: + converted = self._gated_mapping.megatron_to_hf(megatron_weights, megatron_module) + if not converted: + return {} + + fused: Dict[str, torch.Tensor] = {} + for name, tensor in converted.items(): + if not name.endswith(".gate"): + continue + base_name = name[: -len(".gate")] + up_tensor = converted.get(f"{base_name}.up") + if up_tensor is None: + continue + concat_dim = 0 if tensor.ndim == 2 else 1 + fused[base_name] = torch.cat([tensor, up_tensor], dim=concat_dim) + return fused + + def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping": + resolved_megatron_param, resolved_hf_param = self._resolve_names(captures) + return type(self)(resolved_megatron_param, resolved_hf_param, self.permute_dims) + + def merge_qkv_biases(config: TransformerConfig, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: """Merge separate Q, K, V bias vectors into Megatron's interleaved QKV format. diff --git a/src/megatron/bridge/models/glm/glm45_bridge.py b/src/megatron/bridge/models/glm/glm45_bridge.py index cbb1ca9238..077b088e92 100644 --- a/src/megatron/bridge/models/glm/glm45_bridge.py +++ b/src/megatron/bridge/models/glm/glm45_bridge.py @@ -16,7 +16,6 @@ from functools import partial import torch -from megatron.core import parallel_state from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec from megatron.core.models.gpt.gpt_model import GPTModel from transformers import Glm4MoeForCausalLM @@ -34,7 +33,6 @@ ) from megatron.bridge.models.gpt_provider import GPTModelProvider from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM -from megatron.bridge.utils.common_utils import extract_expert_number_from_param try: @@ -92,6 +90,7 @@ def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> GPTModelProvider provider.bias_dropout_fusion = True provider.hidden_dropout = 0.0 provider.autocast_dtype = torch.bfloat16 + provider.mtp_num_layers = getattr(hf_config, "num_nextn_predict_layers", None) provider.mtp_loss_scaling_factor = 0.3 provider.moe_shared_expert_intermediate_size = hf_config.moe_intermediate_size @@ -216,7 +215,7 @@ def mapping_registry(self) -> MegatronMappingRegistry: for mtp_layer in range(num_mtp_layers): for megatron_param, hf_param in layer_specific_mappings.items(): megatron_param = ( - megatron_param.replace(".*", ".*.transformer_layer") + megatron_param.replace(".*", ".*.mtp_model_layer") .replace("decoder", "mtp") .replace(".*", f".{mtp_layer}") ) @@ -248,24 +247,24 @@ def mapping_registry(self) -> MegatronMappingRegistry: mapping_list.extend( [ QKVMapping( - megatron_param=f"mtp.layers.{mtp_layer}.transformer_layer.self_attention.linear_qkv.weight", + megatron_param=f"mtp.layers.{mtp_layer}.mtp_model_layer.self_attention.linear_qkv.weight", q=f"model.layers.{mtp_layer + num_transformer_layers}.self_attn.q_proj.weight", k=f"model.layers.{mtp_layer + num_transformer_layers}.self_attn.k_proj.weight", v=f"model.layers.{mtp_layer + num_transformer_layers}.self_attn.v_proj.weight", ), QKVMapping( - megatron_param=f"mtp.layers.{mtp_layer}.transformer_layer.self_attention.linear_qkv.bias", + megatron_param=f"mtp.layers.{mtp_layer}.mtp_model_layer.self_attention.linear_qkv.bias", q=f"model.layers.{mtp_layer + num_transformer_layers}.self_attn.q_proj.bias", k=f"model.layers.{mtp_layer + num_transformer_layers}.self_attn.k_proj.bias", v=f"model.layers.{mtp_layer + num_transformer_layers}.self_attn.v_proj.bias", ), GatedMLPMapping( - megatron_param=f"mtp.layers.{mtp_layer}.transformer_layer.mlp.linear_fc1.weight", + megatron_param=f"mtp.layers.{mtp_layer}.mtp_model_layer.mlp.linear_fc1.weight", gate=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.linear_fc1.gate.weight", up=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.linear_fc1.up.weight", ), GatedMLPMapping( - megatron_param=f"mtp.layers.{mtp_layer}.transformer_layer.mlp.shared_experts.linear_fc1.weight", + megatron_param=f"mtp.layers.{mtp_layer}.mtp_model_layer.mlp.shared_experts.linear_fc1.weight", gate=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.shared_experts.gate_proj.weight", up=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.shared_experts.up_proj.weight", ), @@ -275,18 +274,14 @@ def mapping_registry(self) -> MegatronMappingRegistry: mapping_list.extend( [ GLMExpertGateUpProjMapping( - megatron_param=( - f"mtp.layers.{mtp_layer}.transformer_layer.mlp.experts.linear_fc1.weight*" - ), + megatron_param=(f"mtp.layers.{mtp_layer}.mtp_model_layer.mlp.experts.linear_fc1.weight*"), hf_param=( f"model.layers.{mtp_layer + num_transformer_layers}.mlp.experts.gate_up_proj" f"{gate_up_suffix}" ), ), GLMExpertDownProjMapping( - megatron_param=( - f"mtp.layers.{mtp_layer}.transformer_layer.mlp.experts.linear_fc2.weight*" - ), + megatron_param=(f"mtp.layers.{mtp_layer}.mtp_model_layer.mlp.experts.linear_fc2.weight*"), hf_param=( f"model.layers.{mtp_layer + num_transformer_layers}.mlp.experts.down_proj{down_suffix}" ), @@ -297,16 +292,12 @@ def mapping_registry(self) -> MegatronMappingRegistry: mapping_list.extend( [ GatedMLPMapping( - megatron_param=( - f"mtp.layers.{mtp_layer}.transformer_layer.mlp.experts.linear_fc1.weight*" - ), + megatron_param=(f"mtp.layers.{mtp_layer}.mtp_model_layer.mlp.experts.linear_fc1.weight*"), gate=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.experts.*.gate_proj.weight", up=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.experts.*.up_proj.weight", ), AutoMapping( - megatron_param=( - f"mtp.layers.{mtp_layer}.transformer_layer.mlp.experts.linear_fc2.weight*" - ), + megatron_param=(f"mtp.layers.{mtp_layer}.mtp_model_layer.mlp.experts.linear_fc2.weight*"), hf_param=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.experts.*.down_proj.weight", ), ] @@ -338,53 +329,3 @@ def _hf_expert_suffix(self, base_name: str) -> str: return ".weight" return "" - - def maybe_modify_converted_hf_weight( - self, - task, - converted_weights_dict: dict[str, torch.Tensor], - hf_state_dict, - ) -> dict[str, torch.Tensor]: - if not isinstance(task.mapping, (GLMExpertGateUpProjMapping, GLMExpertDownProjMapping)): - return converted_weights_dict - - if not converted_weights_dict: - return {} - - num_experts = self._hf_config.n_routed_experts - ep_size = parallel_state.get_expert_model_parallel_world_size() - experts_per_rank = num_experts // ep_size - - try: - local_expert_number = extract_expert_number_from_param(task.param_name) % experts_per_rank - except ValueError: - return converted_weights_dict - - if not hasattr(self, "hf_weights_cache"): - self.hf_weights_cache = {} - - for key, value in converted_weights_dict.items(): - if key not in self.hf_weights_cache: - self.hf_weights_cache[key] = {} - - if ep_size == 1: - self.hf_weights_cache[key][local_expert_number] = value - else: - if value.shape[0] != ep_size: - raise ValueError(f"Expected EP dim {ep_size} for {key}, got {value.shape}.") - for i, exp_val in enumerate(value): - global_expert_number = local_expert_number + (i * experts_per_rank) - self.hf_weights_cache[key][global_expert_number] = exp_val - - if len(self.hf_weights_cache[key]) == num_experts: - merged = torch.stack([self.hf_weights_cache[key][i] for i in range(num_experts)], dim=0) - if key in hf_state_dict: - expected = hf_state_dict[key].shape - if merged.shape != expected and merged.transpose(-1, -2).shape == expected: - merged = merged.transpose(-1, -2).contiguous() - del self.hf_weights_cache[key] - return {key: merged} - - return {} - - return {} diff --git a/src/megatron/bridge/models/glm/glm_moe_mappings.py b/src/megatron/bridge/models/glm/glm_moe_mappings.py index 5346568e6c..4914825285 100644 --- a/src/megatron/bridge/models/glm/glm_moe_mappings.py +++ b/src/megatron/bridge/models/glm/glm_moe_mappings.py @@ -12,108 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""GLM MoE mapping helpers for fused expert weights in transformers 5.0+.""" - -from typing import Dict - -import torch - -from megatron.bridge.models.conversion.param_mapping import AutoMapping, GatedMLPMapping -from megatron.bridge.models.conversion.utils import get_module_and_param_from_name -from megatron.bridge.utils.common_utils import extract_expert_number_from_param - - -def _select_expert_weight(hf_weights: torch.Tensor, expert_idx: int) -> torch.Tensor: - if hf_weights.ndim >= 3: - return hf_weights[expert_idx] - return hf_weights - - -def _align_weight_to_shape(weight: torch.Tensor, target_shape: torch.Size, name: str) -> torch.Tensor: - if tuple(weight.shape) == tuple(target_shape): - return weight - if weight.ndim == 2 and tuple(weight.t().shape) == tuple(target_shape): - return weight.t().contiguous() - raise ValueError(f"Unexpected {name} shape {tuple(weight.shape)}; expected {tuple(target_shape)}.") - - -class _LooseGatedMLPMapping(GatedMLPMapping): - def _validate_patterns(self, *args, **kwargs): - # Allow mismatched wildcard counts for fused expert mappings. - pass - - -class GLMExpertGateUpProjMapping(AutoMapping): - """Mapping for fused expert gate+up projection weights.""" - - def __init__(self, megatron_param: str, hf_param: str, permute_dims=None): - super().__init__(megatron_param, hf_param, permute_dims) - self._gated_mapping = _LooseGatedMLPMapping( - megatron_param=self.megatron_param, - gate=f"{self.hf_param}.gate", - up=f"{self.hf_param}.up", - ) - - def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: torch.nn.Module) -> torch.Tensor: - global_expert_number = extract_expert_number_from_param(self.megatron_param) - expert_weight = _select_expert_weight(hf_weights, global_expert_number) - - normalized_param = self._normalize_expert_param_name(self.megatron_param) - _, target_param = get_module_and_param_from_name(megatron_module, normalized_param) - target_shape = target_param.shape - gate_target_shape = (target_shape[0] // 2, target_shape[1]) - - if target_shape[0] % 2 != 0: - raise ValueError(f"Expected even fused dim for {self.megatron_param}, got {target_shape}.") - - if expert_weight.ndim == 3 and expert_weight.shape[0] == 2: - gate = _align_weight_to_shape(expert_weight[0], gate_target_shape, "gate") - up = _align_weight_to_shape(expert_weight[1], gate_target_shape, "up") - else: - expert_weight = _align_weight_to_shape(expert_weight, target_shape, "gate_up") - gate, up = torch.chunk(expert_weight, 2, dim=0) - - return self._gated_mapping.hf_to_megatron({"gate": gate, "up": up}, megatron_module) - - def megatron_to_hf( - self, megatron_weights: torch.Tensor, megatron_module: torch.nn.Module - ) -> Dict[str, torch.Tensor]: - converted = self._gated_mapping.megatron_to_hf(megatron_weights, megatron_module) - if not converted: - return {} - - fused: Dict[str, torch.Tensor] = {} - for name, tensor in converted.items(): - if not name.endswith(".gate"): - continue - base_name = name[: -len(".gate")] - up_tensor = converted.get(f"{base_name}.up") - if up_tensor is None: - continue - concat_dim = 0 if tensor.ndim == 2 else 1 - fused[base_name] = torch.cat([tensor, up_tensor], dim=concat_dim) - return fused - - def _validate_patterns(self, *args, **kwargs): - # Allow number of wildcards to mismatch in this mapping. - pass - - -class GLMExpertDownProjMapping(AutoMapping): - """Mapping for fused expert down projection weights.""" - - def __init__(self, megatron_param: str, hf_param: str, permute_dims=None): - super().__init__(megatron_param, hf_param, permute_dims) - - def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: torch.nn.Module) -> torch.Tensor: - global_expert_number = extract_expert_number_from_param(self.megatron_param) - expert_weight = _select_expert_weight(hf_weights, global_expert_number) - - normalized_param = self._normalize_expert_param_name(self.megatron_param) - _, target_param = get_module_and_param_from_name(megatron_module, normalized_param) - expert_weight = _align_weight_to_shape(expert_weight, target_param.shape, "down_proj") - return super().hf_to_megatron(expert_weight, megatron_module) - - def _validate_patterns(self, *args, **kwargs): - # Allow number of wildcards to mismatch in this mapping. - pass +"""GLM MoE mapping helpers for fused expert weights. + +These are thin aliases around the shared FusedExpertMapping / FusedGatedExpertMapping +classes in param_mapping.py. Kept for backwards compatibility with existing imports. +""" + +from megatron.bridge.models.conversion.param_mapping import ( + FusedExpertMapping as GLMExpertDownProjMapping, # noqa: F401 +) +from megatron.bridge.models.conversion.param_mapping import ( + FusedGatedExpertMapping as GLMExpertGateUpProjMapping, # noqa: F401 +) diff --git a/src/megatron/bridge/models/glm_vl/glm_45v_bridge.py b/src/megatron/bridge/models/glm_vl/glm_45v_bridge.py index 2357444d98..5f9c0ab8cf 100644 --- a/src/megatron/bridge/models/glm_vl/glm_45v_bridge.py +++ b/src/megatron/bridge/models/glm_vl/glm_45v_bridge.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict - import torch -from megatron.core import parallel_state from transformers import Glm4vMoeForConditionalGeneration from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry @@ -34,7 +31,6 @@ from megatron.bridge.models.glm_vl.glm_45v_provider import GLM45VModelProvider from megatron.bridge.models.glm_vl.modeling_glm_45v import GLM45VModel from megatron.bridge.models.hf_pretrained.vlm import PreTrainedVLM -from megatron.bridge.utils.common_utils import extract_expert_number_from_param @MegatronModelBridge.register_bridge(source=Glm4vMoeForConditionalGeneration, target=GLM45VModel) @@ -239,54 +235,3 @@ def _hf_expert_suffix(self, base_name: str) -> str: return ".weight" return "" - - def maybe_modify_converted_hf_weight( - self, - task, - converted_weights_dict: Dict[str, torch.Tensor], - hf_state_dict, - ) -> Dict[str, torch.Tensor]: - if not isinstance(task.mapping, (GLMExpertGateUpProjMapping, GLMExpertDownProjMapping)): - return converted_weights_dict - - if not converted_weights_dict: - return {} - - text_config = getattr(self._hf_config, "text_config", self._hf_config) - num_experts = text_config.n_routed_experts - ep_size = parallel_state.get_expert_model_parallel_world_size() - experts_per_rank = num_experts // ep_size - - try: - local_expert_number = extract_expert_number_from_param(task.param_name) % experts_per_rank - except ValueError: - return converted_weights_dict - - if not hasattr(self, "hf_weights_cache"): - self.hf_weights_cache = {} - - for key, value in converted_weights_dict.items(): - if key not in self.hf_weights_cache: - self.hf_weights_cache[key] = {} - - if ep_size == 1: - self.hf_weights_cache[key][local_expert_number] = value - else: - if value.shape[0] != ep_size: - raise ValueError(f"Expected EP dim {ep_size} for {key}, got {value.shape}.") - for i, exp_val in enumerate(value): - global_expert_number = local_expert_number + (i * experts_per_rank) - self.hf_weights_cache[key][global_expert_number] = exp_val - - if len(self.hf_weights_cache[key]) == num_experts: - merged = torch.stack([self.hf_weights_cache[key][i] for i in range(num_experts)], dim=0) - if key in hf_state_dict: - expected = hf_state_dict[key].shape - if merged.shape != expected and merged.transpose(-1, -2).shape == expected: - merged = merged.transpose(-1, -2).contiguous() - del self.hf_weights_cache[key] - return {key: merged} - - return {} - - return {} diff --git a/src/megatron/bridge/models/gpt_oss/gpt_oss_bridge.py b/src/megatron/bridge/models/gpt_oss/gpt_oss_bridge.py index a08d8a3e30..9e3b9b9dc4 100644 --- a/src/megatron/bridge/models/gpt_oss/gpt_oss_bridge.py +++ b/src/megatron/bridge/models/gpt_oss/gpt_oss_bridge.py @@ -12,19 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging import math from dataclasses import fields from typing import Dict, Mapping, Optional, Tuple, Union import torch import torch.nn as nn -from megatron.core import parallel_state from megatron.core.models.gpt.gpt_model import GPTModel from transformers import GptOssForCausalLM from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry -from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge, WeightConversionTask +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge from megatron.bridge.models.conversion.param_mapping import ( AutoMapping, QKVMapping, @@ -55,13 +53,6 @@ class GPTOSSBridge(MegatronModelBridge): >>> provider = bridge.to_megatron_provider() """ - def __init__(self): - super().__init__() - # gpt-oss HF weights has one weight for all the experts, but megatron has one for each expert - # We need to cache the weights during import to load and dequantize the expert weights only once. - # and we need to merge the weights of multiple experts during export. - self.hf_weights_cache = {} - def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> GPTModelProvider: """Convert HuggingFace config to GPTModelProvider.""" provider = super().provider_bridge(hf_pretrained) @@ -121,74 +112,27 @@ def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> GPTModelProvider def maybe_modify_loaded_hf_weight( self, hf_param: str | dict[str, str], hf_state_dict: Mapping[str, torch.Tensor] ) -> torch.Tensor: - """Load weights from HuggingFace state dict and dequantize if necessary.""" + """Load weights from HuggingFace state dict with MXFP4 dequantization support. + + GPT-OSS stores fused expert weights as [num_experts, hidden, intermediate]. + Megatron expects the per-expert slice as [intermediate, hidden], so we + transpose the 3D tensor once here rather than per-expert in hf_to_megatron. + """ if isinstance(hf_param, str): - if hf_param in self.hf_weights_cache: - return self.hf_weights_cache[hf_param] if hf_param in hf_state_dict: hf_weights = hf_state_dict[hf_param] - if ".mlp.experts." in hf_param and len(hf_weights.shape) == 3: + if ".mlp.experts." in hf_param and hf_weights.ndim == 3: hf_weights = hf_weights.transpose(-1, -2) - self.hf_weights_cache[hf_param] = hf_weights - else: - blocks_key = hf_param + "_blocks" - scales_key = hf_param + "_scales" - if blocks_key in hf_state_dict and scales_key in hf_state_dict: - hf_weights = _dequantize_mxfp4(hf_state_dict[blocks_key], hf_state_dict[scales_key]) - self.hf_weights_cache[hf_param] = hf_weights - else: - raise KeyError( - f"Cannot locate weights for '{hf_param}'. Missing both de-quantized tensor and " - f"quantized representation (blocks='{blocks_key}', scales='{scales_key}')." - ) - else: - hf_weights = {k: hf_state_dict[v] for k, v in hf_param.items()} - return hf_weights - - def maybe_modify_converted_hf_weight( - self, - task: WeightConversionTask, - converted_weights_dict: Dict[str, torch.Tensor], - hf_state_dict: Mapping[str, torch.Tensor], - ) -> Dict[str, torch.Tensor]: - num_experts = self.hf_config.num_local_experts - ep_size = parallel_state.get_expert_model_parallel_world_size() - experts_per_rank = num_experts // ep_size - - try: - local_expert_number = extract_expert_number_from_param(task.param_name) % experts_per_rank - except ValueError: - # not an expert weight - return converted_weights_dict - - assert len(converted_weights_dict) == 1, ( - f"There should be only one key in the converted_weights_dict, got keys: {converted_weights_dict.keys()}" - ) - for key, value in converted_weights_dict.items(): - if key not in self.hf_weights_cache: - self.hf_weights_cache[key] = {} - - # we end up with ep_size many weights to add to the cache - # unpack the weights and re-index - if ep_size == 1: - self.hf_weights_cache[key][local_expert_number] = value - else: - assert value.shape[0] == ep_size - for i, exp_val in enumerate(value): - global_expert_number = local_expert_number + (i * experts_per_rank) - self.hf_weights_cache[key][global_expert_number] = exp_val - if len(self.hf_weights_cache[key]) == num_experts: - logging.debug(f"All experts are loaded for {key}") - # all experts are loaded - merged_hf_weights = torch.cat( - [self.hf_weights_cache[key][i].unsqueeze(0) for i in range(num_experts)], dim=0 - ) - del self.hf_weights_cache[key] - return {key: merged_hf_weights} - else: - # not all experts are loaded yet, return empty dict - logging.debug(f"{len(self.hf_weights_cache[key])}/{num_experts} experts are loaded for {key}") - return {} + return hf_weights + blocks_key = hf_param + "_blocks" + scales_key = hf_param + "_scales" + if blocks_key in hf_state_dict and scales_key in hf_state_dict: + return _dequantize_mxfp4(hf_state_dict[blocks_key], hf_state_dict[scales_key]) + raise KeyError( + f"Cannot locate weights for '{hf_param}'. Missing both de-quantized tensor and " + f"quantized representation (blocks='{blocks_key}', scales='{scales_key}')." + ) + return {k: hf_state_dict[v] for k, v in hf_param.items()} def mapping_registry(self) -> MegatronMappingRegistry: """ @@ -272,42 +216,49 @@ def mapping_registry(self) -> MegatronMappingRegistry: class GPTOSSMLPDownProjMapping(AutoMapping): - """ - MLPDownProj for expert weights GPT-OSS models. + """MLPDownProj for expert weights in GPT-OSS models. + + GPT-OSS stores fc2 weight transposed vs Megatron when using BF16. """ + is_grouped_export = True + def __init__(self, megatron_param: str, hf_param: str, permute_dims: Optional[Tuple[int, ...]] = None): super().__init__(megatron_param, hf_param, permute_dims) self.allow_hf_name_mismatch = True + @property + def group_key(self) -> str: + return self.hf_param + def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -> torch.Tensor: global_expert_number = extract_expert_number_from_param(self.megatron_param) return super().hf_to_megatron(hf_weights[global_expert_number], megatron_module) def megatron_to_hf(self, megatron_weights: torch.Tensor, megatron_module: nn.Module) -> Dict[str, torch.Tensor]: - # only bf16 export is supported currently if megatron_weights is None: return super().megatron_to_hf(megatron_weights, megatron_module) - - # GPT-OSS stores fc2 weight transposed vs Megatron when using BF16. if len(megatron_weights.shape) == 2: megatron_weights = megatron_weights.transpose(0, 1) return super().megatron_to_hf(megatron_weights.contiguous(), megatron_module) - def _validate_patterns(self, *args, **kwargs): - # allow number of wildcards to mismatch in this mapping - pass - class GPTOSSMLPGateUpProjMapping(AutoMapping): - """ - MLPGateUpProj for expert weights GPT-OSS models. + """MLPGateUpProj for expert weights in GPT-OSS models. + + GPT-OSS uses alternating row interleaving for gate/up projections. """ + is_grouped_export = True + def __init__(self, megatron_param: str, hf_param: str, permute_dims: Optional[Tuple[int, ...]] = None): super().__init__(megatron_param, hf_param, permute_dims) self.allow_hf_name_mismatch = True + @property + def group_key(self) -> str: + return self.hf_param + @staticmethod def _interleave(gate_up_proj): return torch.cat((gate_up_proj[::2, ...], gate_up_proj[1::2, ...]), dim=0) @@ -324,19 +275,13 @@ def hf_to_megatron(self, hf_weights: Union[torch.Tensor, Dict], megatron_module: return super().hf_to_megatron(self._interleave(hf_weights[global_expert_number]), megatron_module) def megatron_to_hf(self, megatron_weights: torch.Tensor, megatron_module: nn.Module) -> Dict[str, torch.Tensor]: - # only bf16 export is supported currently if megatron_weights is None: return super().megatron_to_hf(megatron_weights, megatron_module) - megatron_weights = self._uninterleave(megatron_weights) if len(megatron_weights.shape) == 2: megatron_weights = megatron_weights.transpose(0, 1) return super().megatron_to_hf(megatron_weights.contiguous(), megatron_module) - def _validate_patterns(self, *args, **kwargs): - # allow number of wildcards to mismatch in this mapping - pass - def _dequantize_mxfp4( blocks: torch.Tensor, diff --git a/src/megatron/bridge/models/qwen/qwen3_next_bridge.py b/src/megatron/bridge/models/qwen/qwen3_next_bridge.py index 1ec2eeaa3f..7b237e6e81 100755 --- a/src/megatron/bridge/models/qwen/qwen3_next_bridge.py +++ b/src/megatron/bridge/models/qwen/qwen3_next_bridge.py @@ -13,6 +13,9 @@ # limitations under the License. import torch +from megatron.core.models.gpt.experimental_attention_variant_module_specs import ( + get_transformer_block_with_experimental_attention_variant_spec, +) from megatron.core.models.gpt.gpt_model import GPTModel from transformers import Qwen3NextForCausalLM @@ -27,67 +30,69 @@ ReplicatedMapping, RMSNorm2ZeroCenteredRMSNormMapping, ) -from megatron.bridge.models.conversion.transformers_compat import rope_theta_from_hf -from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM from megatron.bridge.models.qwen.qwen_provider import Qwen3NextModelProvider -@MegatronModelBridge.register_bridge(source=Qwen3NextForCausalLM, target=GPTModel) +@MegatronModelBridge.register_bridge(source=Qwen3NextForCausalLM, target=GPTModel, model_type="qwen3_next") class Qwen3NextBridge(MegatronModelBridge): """ - Megatron Hub Bridge for Qwen3 MoE Causal LM. + Megatron Bridge for Qwen3-Next Causal LM. - This bridge handles the conversion between HuggingFace Qwen3MoeForCausalLM - and Megatron-Core GPTModel formats. Qwen3 MoE models use mixture of experts - architecture with QK layernorm. + This bridge handles the conversion between HuggingFace Qwen3NextForCausalLM + and Megatron-Core GPTModel formats. Qwen3-Next uses a hybrid architecture + combining gated delta net linear attention with standard softmax attention, + mixture of experts with shared experts, and zero-centered RMSNorm. Example: >>> from megatron.bridge import AutoBridge - >>> bridge = AutoBridge.from_hf_pretrained("Qwen/Qwen3-235B-A22B") + >>> bridge = AutoBridge.from_hf_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct") >>> provider = bridge.to_megatron_provider() """ - def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> Qwen3NextModelProvider: + PROVIDER_CLASS = Qwen3NextModelProvider + + def provider_bridge(self, hf_pretrained): + """Convert HuggingFace Qwen3-Next config to GPTModelProvider.""" + provider = super().provider_bridge(hf_pretrained) hf_config = hf_pretrained.config - provider = Qwen3NextModelProvider( - num_layers=hf_config.num_hidden_layers, - hidden_size=hf_config.hidden_size, - ffn_hidden_size=hf_config.intermediate_size, - moe_ffn_hidden_size=hf_config.moe_intermediate_size, # Maps to moe_intermediate_size in HF - num_attention_heads=hf_config.num_attention_heads, - num_query_groups=hf_config.num_key_value_heads, - num_moe_experts=hf_config.num_experts, - moe_router_topk=hf_config.num_experts_per_tok, # Maps to num_experts_per_tok in HF - init_method_std=hf_config.initializer_range, - layernorm_epsilon=hf_config.rms_norm_eps, - gated_linear_unit=True, - make_vocab_size_divisible_by=self.make_vocab_size_divisible_by(hf_config.vocab_size), - rotary_base=rope_theta_from_hf(hf_config), - share_embeddings_and_output_weights=getattr(hf_config, "tie_word_embeddings", False), - vocab_size=hf_config.vocab_size, - seq_length=hf_config.max_position_embeddings, - fp16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16), - bf16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16), - params_dtype=self.dtype_from_hf(hf_config, default=torch.float32), - qk_layernorm=True, # Qwen3 MoE uses QK layernorm - moe_grouped_gemm=True, - kv_channels=hf_config.head_dim, - # New for Qwen3-Next - layernorm_zero_centered_gamma=True, - attention_output_gate=True, - experimental_attention_variant="gated_delta_net", - linear_attention_freq=hf_config.full_attention_interval, - rotary_percent=hf_config.partial_rotary_factor, - moe_shared_expert_intermediate_size=hf_config.shared_expert_intermediate_size, - moe_shared_expert_gate=True, - linear_conv_kernel_dim=hf_config.linear_conv_kernel_dim, - linear_key_head_dim=hf_config.linear_key_head_dim, - linear_value_head_dim=hf_config.linear_value_head_dim, - linear_num_key_heads=hf_config.linear_num_key_heads, - linear_num_value_heads=hf_config.linear_num_value_heads, - mtp_num_layers=None, # Set to 1 if need MTP - ) + # Standard GPT settings (shared with Qwen3 MoE) + provider.normalization = "RMSNorm" + provider.gated_linear_unit = True + provider.position_embedding_type = "rope" + provider.add_bias_linear = False + provider.add_qkv_bias = False + provider.hidden_dropout = 0.0 + provider.qk_layernorm = True + provider.autocast_dtype = torch.bfloat16 + + # MoE settings + provider.moe_grouped_gemm = True + provider.moe_router_load_balancing_type = "global_aux_loss" + provider.moe_aux_loss_coeff = 1e-3 + provider.moe_router_pre_softmax = False + provider.moe_token_dispatcher_type = "alltoall" + provider.moe_permute_fusion = True + provider.moe_shared_expert_gate = True + provider.moe_router_dtype = "fp32" + provider.moe_shared_expert_intermediate_size = hf_config.shared_expert_intermediate_size + + # Qwen3-Next: zero-centered RMSNorm and gated attention + provider.layernorm_zero_centered_gamma = True + provider.attention_output_gate = True + + # Qwen3-Next: hybrid gated delta net + standard attention + provider.transformer_layer_spec = get_transformer_block_with_experimental_attention_variant_spec + provider.experimental_attention_variant = "gated_delta_net" + provider.linear_attention_freq = hf_config.full_attention_interval + provider.linear_conv_kernel_dim = hf_config.linear_conv_kernel_dim + provider.linear_key_head_dim = hf_config.linear_key_head_dim + provider.linear_value_head_dim = hf_config.linear_value_head_dim + provider.linear_num_key_heads = hf_config.linear_num_key_heads + provider.linear_num_value_heads = hf_config.linear_num_value_heads + + # Heterogeneous checkpointing for mixed attention layers + provider.hetereogenous_dist_checkpoint = True return provider diff --git a/src/megatron/bridge/models/qwen_vl/qwen35_vl_bridge.py b/src/megatron/bridge/models/qwen_vl/qwen35_vl_bridge.py index fd26808bc5..c200ee1731 100644 --- a/src/megatron/bridge/models/qwen_vl/qwen35_vl_bridge.py +++ b/src/megatron/bridge/models/qwen_vl/qwen35_vl_bridge.py @@ -39,6 +39,8 @@ from megatron.bridge.models.conversion.param_mapping import ( AutoMapping, ConcatenatedQKVMapping, + FusedExpertMapping, + FusedGatedExpertMapping, GatedMLPMapping, GDNConv1dMapping, GDNLinearMappingSeparate, @@ -48,11 +50,7 @@ ) from megatron.bridge.models.hf_pretrained.vlm import PreTrainedVLM from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.model import Qwen3VLModel -from megatron.bridge.models.qwen_vl.qwen3_vl_bridge import ( - ExpertMLPDownProjMapping, - ExpertMLPGateUpProjMapping, - Qwen3VLMoEBridge, -) +from megatron.bridge.models.qwen_vl.qwen3_vl_bridge import Qwen3VLMoEBridge from megatron.bridge.models.qwen_vl.qwen35_vl_provider import ( Qwen35VLModelProvider, Qwen35VLMoEModelProvider, @@ -328,11 +326,11 @@ def mapping_registry(self) -> MegatronMappingRegistry: # Language Model: MoE Expert MLPs (routed experts) # Uses GatedMLPMapping for gate+up projection fusion # ============================================================= - ExpertMLPGateUpProjMapping( + FusedGatedExpertMapping( megatron_param="language_model.decoder.layers.*.mlp.experts.linear_fc1.weight*", hf_param="model.language_model.layers.*.mlp.experts.gate_up_proj", ), - ExpertMLPDownProjMapping( + FusedExpertMapping( megatron_param="language_model.decoder.layers.*.mlp.experts.linear_fc2.weight*", hf_param="model.language_model.layers.*.mlp.experts.down_proj", ), diff --git a/src/megatron/bridge/models/qwen_vl/qwen3_vl_bridge.py b/src/megatron/bridge/models/qwen_vl/qwen3_vl_bridge.py index 895bfed520..bb6004941e 100644 --- a/src/megatron/bridge/models/qwen_vl/qwen3_vl_bridge.py +++ b/src/megatron/bridge/models/qwen_vl/qwen3_vl_bridge.py @@ -12,29 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging -from typing import Dict, Mapping, Union - import torch -import torch.nn as nn -from megatron.core import parallel_state from transformers import Qwen3VLForConditionalGeneration, Qwen3VLMoeForConditionalGeneration from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry -from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge, WeightConversionTask +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge from megatron.bridge.models.conversion.param_mapping import ( AutoMapping, ConcatenatedQKVMapping, + FusedExpertMapping, + FusedGatedExpertMapping, GatedMLPMapping, QKVMapping, ReplicatedMapping, ) from megatron.bridge.models.conversion.transformers_compat import rope_theta_from_hf -from megatron.bridge.models.conversion.utils import get_module_and_param_from_name from megatron.bridge.models.hf_pretrained.vlm import PreTrainedVLM from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.model import Qwen3VLModel from megatron.bridge.models.qwen_vl.qwen3_vl_provider import Qwen3VLModelProvider, Qwen3VLMoEModelProvider -from megatron.bridge.utils.common_utils import extract_expert_number_from_param @MegatronModelBridge.register_bridge( @@ -232,10 +227,6 @@ class Qwen3VLMoEBridge(MegatronModelBridge): >>> provider = bridge.to_megatron_provider() """ - def __init__(self): - super().__init__() - self.hf_weights_cache: Dict[str, Dict[int, torch.Tensor]] = {} - def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> Qwen3VLMoEModelProvider: hf_config = hf_pretrained.config text_config = hf_config.text_config @@ -377,11 +368,11 @@ def mapping_registry(self) -> MegatronMappingRegistry: k="model.language_model.layers.*.self_attn.k_proj.bias", v="model.language_model.layers.*.self_attn.v_proj.bias", ), - ExpertMLPGateUpProjMapping( + FusedGatedExpertMapping( megatron_param="language_model.decoder.layers.*.mlp.experts.linear_fc1.weight*", hf_param="model.language_model.layers.*.mlp.experts.gate_up_proj", ), - ExpertMLPDownProjMapping( + FusedExpertMapping( megatron_param="language_model.decoder.layers.*.mlp.experts.linear_fc2.weight*", hf_param="model.language_model.layers.*.mlp.experts.down_proj", ), @@ -406,145 +397,3 @@ def mapping_registry(self) -> MegatronMappingRegistry: ) return MegatronMappingRegistry(*mapping_list) - - def maybe_modify_converted_hf_weight( - self, - task: WeightConversionTask, - converted_weights_dict: Dict[str, torch.Tensor], - hf_state_dict: Mapping[str, torch.Tensor], - ) -> Dict[str, torch.Tensor]: - num_experts = self.hf_config.text_config.num_experts - ep_size = parallel_state.get_expert_model_parallel_world_size() - experts_per_rank = num_experts // ep_size - - try: - local_expert_number = extract_expert_number_from_param(task.param_name) % experts_per_rank - except ValueError: - # not an expert weight - return converted_weights_dict - - assert len(converted_weights_dict) == 1, ( - f"There should be only one key in the converted_weights_dict, got keys: {converted_weights_dict.keys()}" - ) - for key, value in converted_weights_dict.items(): - if key not in self.hf_weights_cache: - self.hf_weights_cache[key] = {} - - # we end up with ep_size many weights to add to the cache - # unpack the weights and re-index - if ep_size == 1: - self.hf_weights_cache[key][local_expert_number] = value - else: - assert value.shape[0] == ep_size - for i, exp_val in enumerate(value): - global_expert_number = local_expert_number + (i * experts_per_rank) - self.hf_weights_cache[key][global_expert_number] = exp_val - if len(self.hf_weights_cache[key]) == num_experts: - logging.debug(f"All experts are loaded for {key}") - merged = torch.cat([self.hf_weights_cache[key][i].unsqueeze(0) for i in range(num_experts)], dim=0) - del self.hf_weights_cache[key] - return {key: merged} - else: - # not all experts are loaded yet, return empty dict - logging.debug(f"{len(self.hf_weights_cache[key])}/{num_experts} experts are loaded for {key}") - return {} - - -def _align_weight_to_shape(weight: torch.Tensor, target_shape: torch.Size, name: str) -> torch.Tensor: - """Auto-detect whether a transpose is needed to match the Megatron target shape. - - Transformers <5.0 stored fused expert weights transposed as - [num_experts, hidden_size, 2*intermediate_size], while transformers 5.0+ - uses the standard nn.Linear convention [num_experts, 2*intermediate_size, hidden_size]. - This helper accepts either layout and transposes only when necessary, so the - bridge works with both real checkpoints (old format) and toy models or new - checkpoints created with transformers 5.0+. - """ - if tuple(weight.shape) == tuple(target_shape): - return weight - if weight.ndim == 2 and tuple(weight.t().shape) == tuple(target_shape): - return weight.t().contiguous() - raise ValueError(f"Unexpected {name} shape {tuple(weight.shape)}; expected {tuple(target_shape)}.") - - -class ExpertMLPDownProjMapping(AutoMapping): - """Mapping for expert MLP down projection weights between HF and Megatron formats. - - Uses _align_weight_to_shape so both pre-5.0 (transposed) and 5.0+ - (standard) HF expert weight layouts are handled transparently. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -> torch.Tensor: - global_expert_number = extract_expert_number_from_param(self.megatron_param) - expert_weight = hf_weights[global_expert_number] if hf_weights.ndim >= 3 else hf_weights - - normalized_param = self._normalize_expert_param_name(self.megatron_param) - _, target_param = get_module_and_param_from_name(megatron_module, normalized_param) - expert_weight = _align_weight_to_shape(expert_weight, target_param.shape, "down_proj") - return super().hf_to_megatron(expert_weight, megatron_module) - - def _validate_patterns(self, *args, **kwargs): - pass - - -class ExpertMLPGateUpProjMapping(AutoMapping): - """Mapping for expert MLP gate+up projection using shared GatedMLPMapping logic. - - Uses _align_weight_to_shape so both pre-5.0 (transposed) and 5.0+ - (standard) HF expert weight layouts are handled transparently. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - GatedMLPMapping._validate_patterns = lambda *args, **kwargs: None - - self._gated_mapping = GatedMLPMapping( - megatron_param=self.megatron_param, - gate=f"{self.hf_param}.gate", - up=f"{self.hf_param}.up", - ) - - def hf_to_megatron(self, hf_weights: Union[torch.Tensor, Dict], megatron_module: nn.Module) -> torch.Tensor: - global_expert_number = extract_expert_number_from_param(self.megatron_param) - expert_weight = hf_weights[global_expert_number] if hf_weights.ndim >= 3 else hf_weights - - normalized_param = self._normalize_expert_param_name(self.megatron_param) - _, target_param = get_module_and_param_from_name(megatron_module, normalized_param) - target_shape = target_param.shape - gate_target_shape = (target_shape[0] // 2, target_shape[1]) - - if target_shape[0] % 2 != 0: - raise ValueError(f"Expected even fused dim for {self.megatron_param}, got {target_shape}.") - - if expert_weight.ndim == 3 and expert_weight.shape[0] == 2: - gate = _align_weight_to_shape(expert_weight[0], gate_target_shape, "gate") - up = _align_weight_to_shape(expert_weight[1], gate_target_shape, "up") - else: - expert_weight = _align_weight_to_shape(expert_weight, target_shape, "gate_up") - gate, up = torch.chunk(expert_weight, 2, dim=0) - - return self._gated_mapping.hf_to_megatron({"gate": gate, "up": up}, megatron_module) - - def megatron_to_hf(self, megatron_weights: torch.Tensor, megatron_module: nn.Module) -> Dict[str, torch.Tensor]: - converted = self._gated_mapping.megatron_to_hf(megatron_weights, megatron_module) - if not converted: - return {} - - fused: Dict[str, torch.Tensor] = {} - for name, tensor in converted.items(): - if not name.endswith(".gate"): - continue - base_name = name[: -len(".gate")] - up_tensor = converted.get(f"{base_name}.up") - if up_tensor is None: - continue - concat_dim = 0 if tensor.ndim == 2 else 1 - fused[base_name] = torch.cat([tensor, up_tensor], dim=concat_dim) - return fused - - def _validate_patterns(self, *args, **kwargs): - pass diff --git a/tests/unit_tests/models/qwen/test_qwen3_next_bridge.py b/tests/unit_tests/models/qwen/test_qwen3_next_bridge.py index e2aef7ff26..210e41d91d 100644 --- a/tests/unit_tests/models/qwen/test_qwen3_next_bridge.py +++ b/tests/unit_tests/models/qwen/test_qwen3_next_bridge.py @@ -80,6 +80,17 @@ def mock_qwen3_next_config(self, qwen3_next_80b_a3b_config_dict): config = Mock() for key, value in qwen3_next_80b_a3b_config_dict.items(): setattr(config, key, value) + for null_attr in ( + "q_lora_rank", + "kv_lora_rank", + "qk_nope_head_dim", + "qk_rope_head_dim", + "v_head_dim", + "n_routed_experts", + "num_local_experts", + "num_nextn_predict_layers", + ): + setattr(config, null_attr, None) return config @pytest.fixture @@ -220,6 +231,17 @@ def test_provider_bridge_dtype_handling(self, qwen3_next_80b_a3b_config_dict): config = Mock() for key, value in qwen3_next_80b_a3b_config_dict.items(): setattr(config, key, value) + for null_attr in ( + "q_lora_rank", + "kv_lora_rank", + "qk_nope_head_dim", + "qk_rope_head_dim", + "v_head_dim", + "n_routed_experts", + "num_local_experts", + "num_nextn_predict_layers", + ): + setattr(config, null_attr, None) config.torch_dtype = "bfloat16" mock_pretrained = Mock(spec=PreTrainedCausalLM) @@ -284,6 +306,17 @@ def test_provider_bridge_80b_a3b_config(self, qwen3_next_80b_a3b_config_dict): config = Mock() for key, value in qwen3_next_80b_a3b_config_dict.items(): setattr(config, key, value) + for null_attr in ( + "q_lora_rank", + "kv_lora_rank", + "qk_nope_head_dim", + "qk_rope_head_dim", + "v_head_dim", + "n_routed_experts", + "num_local_experts", + "num_nextn_predict_layers", + ): + setattr(config, null_attr, None) mock_pretrained = Mock(spec=PreTrainedCausalLM) mock_pretrained.config = config