Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/conversion/hf_to_megatron_generate_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,9 @@ def main(args) -> None:
model_provider.initialize_model_parallel(seed=0)
model = model_provider.provide_distributed_model(wrap_with_ddp=False)

# TEMP FIX for inference failure when mtp_num_layers is not None
for m in model:
m.config.mtp_num_layers = None
if hasattr(m, "mtp_process"):
m.mtp_process = False
Comment on lines +171 to +172
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify where MTP is actually controlled in this repository
echo "=== Searching for mtp_process usage ==="
rg -n -C2 '\bmtp_process\b' --type=py

echo -e "\n=== Searching for mtp_num_layers usage ==="
rg -n -C2 '\bmtp_num_layers\b' --type=py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 31680


🏁 Script executed:

cat -n examples/conversion/hf_to_megatron_generate_text.py | head -200 | tail -60

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 2666


🏁 Script executed:

# Check what grad_scale_func is and where it's used
echo "=== Searching for grad_scale_func ==="
rg -n 'grad_scale_func' --type=py | head -20

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1656


🏁 Script executed:

# Check if hf_to_megatron_generate_text.py handles any Qwen VL specific logic
echo "=== Context around MTP disable in generate_text.py ==="
rg -B10 -A5 'mtp_process' examples/conversion/hf_to_megatron_generate_text.py

echo -e "\n=== Compare with generate_vlm.py approach ==="
rg -B5 -A5 'mtp_num_layers = None' examples/conversion/hf_to_megatron_generate_vlm.py

echo -e "\n=== Check file imports to see what models are supported ==="
head -50 examples/conversion/hf_to_megatron_generate_text.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 3162


🏁 Script executed:

# Check if Qwen VL is the only model that uses mtp_process
echo "=== Check what models/classes have mtp_process attribute ==="
rg -B5 'self.mtp_process\s*=' --type=py | head -40

echo -e "\n=== Verify if all models have config.mtp_num_layers ==="
rg 'config.mtp_num_layers' examples/conversion --type=py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 957


🏁 Script executed:

# Final check: verify grad_scale_func is also needed for inference
echo "=== Check inference paths that set grad_scale_func to None ==="
rg -B3 -A3 'grad_scale_func = None' examples/conversion src/megatron/bridge/inference --type=py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1934


Set MTP config to None for inference; mtp_process alone is insufficient.

Lines 171–172 only toggle a model-specific attribute. For models where config.mtp_num_layers is set, MTP remains enabled at the config level, which training/inference logic checks to determine whether to use MTP. Additionally, grad_scale_func must be set to None to prevent NCCL collective hangs during inference.

Apply the following fix to match the pattern used in hf_to_megatron_generate_vlm.py and compare_hf_and_megatron/compare.py:

🔧 Proposed fix
 for m in model:
+    if hasattr(m, "config"):
+        m.config.mtp_num_layers = None
+        m.config.grad_scale_func = None
     if hasattr(m, "mtp_process"):
         m.mtp_process = False
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if hasattr(m, "mtp_process"):
m.mtp_process = False
if hasattr(m, "config"):
m.config.mtp_num_layers = None
m.config.grad_scale_func = None
if hasattr(m, "mtp_process"):
m.mtp_process = False
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/conversion/hf_to_megatron_generate_text.py` around lines 171 - 172,
the current change only flips the model instance flag m.mtp_process, but you
must also disable MTP at the config level and clear the mixed-precision grad
scaling hook to avoid NCCL hangs: in the block that checks hasattr(m,
"mtp_process") and sets m.mtp_process = False, also set m.config.mtp_num_layers
= None and m.config.grad_scale_func = None, guarding with hasattr(m, "config")
before assignment to avoid attribute errors; update the same function/section
that handles m.mtp_process so all three changes are applied together, matching
the pattern already used in hf_to_megatron_generate_vlm.py.


model = [m.cuda() for m in model]
for m in model:
Expand Down
4 changes: 4 additions & 0 deletions src/megatron/bridge/models/conversion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from megatron.bridge.models.conversion.param_mapping import (
AutoMapping,
ColumnParallelMapping,
FusedExpertMapping,
FusedGatedExpertMapping,
GatedMLPMapping,
MegatronParamMapping,
QKVMapping,
Expand All @@ -33,6 +35,8 @@
"MegatronMappingRegistry",
"MegatronModelBridge",
"ColumnParallelMapping",
"FusedExpertMapping",
"FusedGatedExpertMapping",
"GatedMLPMapping",
"MegatronParamMapping",
"QKVMapping",
Expand Down
98 changes: 88 additions & 10 deletions src/megatron/bridge/models/conversion/model_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def hf_config_to_provider_kwargs(self, hf_config) -> dict:
else:
value = getattr(hf_config, hf_name, None)
has_value = hasattr(hf_config, hf_name)
if has_value:
if has_value and value is not None:
provider_kwargs[megatron_name] = value

# Extract rotary_base via compat function (handles both legacy rope_theta
Expand Down Expand Up @@ -714,6 +714,62 @@ def maybe_modify_converted_hf_weight(
"""
return converted_weights_dict

    def _accumulate_grouped_export(
        self,
        task: "WeightConversionTask",
        converted_weights_dict: Dict[str, torch.Tensor],
        model_config,
        grouped_buffers: Dict[str, Dict[int, torch.Tensor]],
        hf_state_dict: Mapping[str, torch.Tensor],
    ) -> Optional[Dict[str, torch.Tensor]]:
        """Accumulate per-expert weights for grouped export, return merged result when complete.

        For fused-expert MoE models where one HF tensor contains all experts, this method
        collects individual expert weights produced by per-expert ``megatron_to_hf`` calls
        and returns the stacked result once all experts have been accumulated.

        Args:
            task: The current per-expert conversion task; its mapping supplies
                ``group_key`` and its ``param_name`` encodes the local expert index.
            converted_weights_dict: Output of ``task.mapping.megatron_to_hf`` for
                one expert (already gathered across TP by the mapping).
            model_config: Megatron model config; only ``num_moe_experts`` is read.
            grouped_buffers: Cross-call accumulator owned by the caller, keyed by
                ``group_key`` then by global expert index.
            hf_state_dict: Reference HF state dict, used only to sanity-check the
                merged tensor's shape (and transpose if the layouts differ).

        Returns:
            Merged weights dict when the group is complete, ``None`` otherwise.
        """
        # Imported locally to avoid a potential import cycle at module load time.
        from megatron.bridge.utils.common_utils import extract_expert_number_from_param

        group_key = task.mapping.group_key
        if group_key not in grouped_buffers:
            grouped_buffers[group_key] = {}

        ep_size = parallel_state.get_expert_model_parallel_world_size()
        num_experts = model_config.num_moe_experts
        # NOTE(review): integer division assumes num_experts is divisible by
        # ep_size — presumably guaranteed upstream; confirm.
        experts_per_rank = num_experts // ep_size

        try:
            # The Megatron param name carries a global expert number; reduce it
            # to this rank's local index (0..experts_per_rank-1).
            local_expert_number = extract_expert_number_from_param(task.param_name) % experts_per_rank
        except ValueError:
            # Param name has no expert number — not an expert weight; nothing to accumulate.
            return None

        for _, value in converted_weights_dict.items():
            if ep_size == 1:
                # No expert parallelism: local index is already the global index.
                grouped_buffers[group_key][local_expert_number] = value
            else:
                if value.ndim > 0 and value.shape[0] == ep_size:
                    # The mapping gathered one slice per EP rank into a leading
                    # dim of size ep_size; fan each slice out to its global
                    # expert slot (rank i owns experts [i*experts_per_rank, ...)).
                    for i in range(ep_size):
                        global_expert_number = local_expert_number + (i * experts_per_rank)
                        grouped_buffers[group_key][global_expert_number] = value[i]
                else:
                    # Already a single expert's tensor (no EP gather dim).
                    grouped_buffers[group_key][local_expert_number] = value

        if len(grouped_buffers[group_key]) == num_experts:
            # All experts collected: stack in global-expert order into the fused
            # HF layout [num_experts, ...].
            merged = torch.stack([grouped_buffers[group_key][i] for i in range(num_experts)], dim=0)

            if group_key in hf_state_dict:
                # Transformers <5.0 and 5.0+ disagree on expert matrix layout;
                # transpose the trailing dims if (and only if) that matches the
                # reference checkpoint's shape.
                expected_shape = hf_state_dict[group_key].shape
                if merged.shape != expected_shape and merged.transpose(-1, -2).shape == expected_shape:
                    merged = merged.transpose(-1, -2).contiguous()

            # Free the buffer so the (potentially large) per-expert tensors can
            # be reclaimed immediately.
            del grouped_buffers[group_key]
            return {group_key: merged}

        return None

def load_weights_hf_to_megatron(
self,
hf_pretrained: HFPreTrained,
Expand Down Expand Up @@ -777,12 +833,20 @@ def load_weights_hf_to_megatron(
hf_state_dict: Mapping[str, torch.Tensor] = hf_pretrained.state if hasattr(hf_pretrained, "state") else {}

description = f"Loading from {hf_pretrained.model_name_or_path}"
_hf_import_cache: Dict[str, torch.Tensor] = {}
for task in self._with_progress_tracking(hf_to_megatron_tasks, description):
# None means megatron module not on current rank, skip if this task is not going to happen
if task.megatron_module is None:
continue
# 1) Fetch source tensor(s) from HF state dict
hf_weights = self.maybe_modify_loaded_hf_weight(task.mapping.hf_param, hf_state_dict)
# 1) Fetch source tensor(s) from HF state dict, with caching for grouped mappings
hf_param_key = str(task.mapping.hf_param)
is_grouped = getattr(task.mapping, "is_grouped_export", False)
if is_grouped and hf_param_key in _hf_import_cache:
hf_weights = _hf_import_cache[hf_param_key]
else:
hf_weights = self.maybe_modify_loaded_hf_weight(task.mapping.hf_param, hf_state_dict)
if is_grouped:
_hf_import_cache[hf_param_key] = hf_weights

# 2) Delegate conversion & distribution to the bridge
converted_weights = task.mapping.hf_to_megatron(hf_weights, task.megatron_module)
Expand Down Expand Up @@ -969,22 +1033,40 @@ def stream_weights_megatron_to_hf(

hf_state_dict: Mapping[str, torch.Tensor] = hf_pretrained.state if hasattr(hf_pretrained, "state") else {}

# Pre-compute expected expert counts for grouped export mappings
_grouped_task_counts: Dict[str, int] = {}
for task in megatron_to_hf_tasks:
if task is not None and getattr(task.mapping, "is_grouped_export", False):
gk = task.mapping.group_key
_grouped_task_counts[gk] = _grouped_task_counts.get(gk, 0) + 1
_grouped_buffers: Dict[str, Dict[int, torch.Tensor]] = {}

for task in self._with_progress_tracking(megatron_to_hf_tasks, "Converting to HuggingFace", show_progress):
converted_weights_dict = task.mapping.megatron_to_hf(task.param_weight, task.megatron_module)

# --- Grouped export path: accumulate per-expert weights, yield when complete ---
if getattr(task.mapping, "is_grouped_export", False):
merged_result = self._accumulate_grouped_export(
task, converted_weights_dict, model_config, _grouped_buffers, hf_state_dict
)
if merged_result is not None:
for hf_name, tensor in merged_result.items():
yield HFWeightTuple(hf_name, tensor.cpu() if cpu else tensor)
continue

# --- Standard export path ---
converted_weights_dict = self.maybe_modify_converted_hf_weight(
task,
converted_weights_dict,
hf_state_dict,
) # dict will be none except for one expert;
# All ranks get the full tensor
)

adapter_tasks = None
if merge_adapter_weights and "to_wrap.weight" in task.global_param_name:
task_global_base_prefix, _, _ = task.global_param_name.partition(".to_wrap.weight")
adapter_tasks = adapter_tasks_by_base.get(task_global_base_prefix)
if merge_adapter_weights and adapter_tasks:
adapter_weights = self.materialize_adapter_weights(adapter_tasks)
# Merge LoRA adapter weights back into the base tensor for HF export
converted_weights_dict = self._merge_lora_adapter_weights(
megatron_model,
converted_weights_dict,
Expand All @@ -1000,21 +1082,17 @@ def stream_weights_megatron_to_hf(
# Handle tied embeddings case
# TODO(yuya): fix this hard coded naming
if embeddings_are_tied and hf_name == "model.embed_tokens.weight":
# Yield the embedding weight
yield HFWeightTuple(hf_name, final_tensor)

# Also yield as lm_head.weight if it's expected
if hasattr(hf_pretrained, "state") and hasattr(hf_pretrained.state, "source"):
expected_keys = hf_pretrained.state.source.get_all_keys()
if "lm_head.weight" in expected_keys:
yield HFWeightTuple("lm_head.weight", final_tensor.clone().detach())
elif embeddings_are_tied and hf_name == "lm_head.weight":
# This should not happen when embeddings are tied - assert error
raise ValueError(
"Encountered lm_head.weight when embeddings are tied. This indicates a mapping error."
)
else:
# Regular case - yield the tensor normally
yield HFWeightTuple(hf_name, final_tensor)

def dtype_from_hf(self, config, default=None):
Expand Down
161 changes: 151 additions & 10 deletions src/megatron/bridge/models/conversion/param_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,13 @@ def _count_wildcard_groups(self, pattern: str) -> int:
return count

def _validate_patterns(self):
"""Validate wildcard consistency between patterns."""
"""Validate wildcard consistency between patterns.

Skipped automatically for grouped-export mappings where the megatron
side intentionally has more wildcards than the HF side.
"""
if getattr(self, "is_grouped_export", False):
return
megatron_param_wildcards = self._count_wildcard_groups(self.megatron_param)
if isinstance(self.hf_param, str):
hf_param_wildcards = self._count_wildcard_groups(self.hf_param)
Expand Down Expand Up @@ -805,16 +811,17 @@ def hf_to_megatron(
if hf_weights is None:
raise ValueError("hf_weights should not be None on rank 0")

# For MCore MambaMixer, A_log is initialized in FP32 but cast to BF16 when
# saving ckpts, including the ckpt uploaded to HF. Without this cast,
# self.scatter_to_tp_ranks will try to scatter the HF A_log weights in BF16 to
# the Megatron tensor which is in FP32. This will error. So we cast before the scatter.
# Dtype may differ (e.g. MambaMixer A_log is FP32 in MCore but BF16
# in HF checkpoints). Cast to match the Megatron parameter so the
# scatter doesn't fail on dtype mismatch.
if hf_weights.dtype != target_param.dtype:
logger.warning(
f"WARNING: Dtype mismatch between HuggingFace weights and Megatron module. "
f"HF dtype: {hf_weights.dtype}. Megatron dtype: {target_param.dtype}. "
f"Casting HF weights to Megatron dtype. THIS MAY RESULT IN A LOSS OF PRECISION. "
)
if not getattr(ColumnParallelMapping, "_dtype_warned", False):
ColumnParallelMapping._dtype_warned = True
logger.warning(
f"Dtype mismatch: HF weights are {hf_weights.dtype} but Megatron "
f"module uses {target_param.dtype}. Casting all mismatched weights "
f"to {target_param.dtype} (further warnings suppressed)."
)
hf_weights = hf_weights.to(target_param.dtype)

# For bias (1D), we still split along dim 0
Expand Down Expand Up @@ -2212,6 +2219,140 @@ def megatron_to_hf(self, megatron_weights: torch.Tensor, megatron_module: nn.Mod
return {key: value}


def _align_expert_weight_to_shape(weight: torch.Tensor, target_shape: torch.Size, name: str) -> torch.Tensor:
"""Auto-detect whether a transpose is needed to match the Megatron target shape.

Handles both transformers <5.0 (transposed) and 5.0+ (standard) expert weight layouts.
"""
if tuple(weight.shape) == tuple(target_shape):
return weight
if weight.ndim == 2 and tuple(weight.t().shape) == tuple(target_shape):
return weight.t().contiguous()
raise ValueError(f"Unexpected {name} shape {tuple(weight.shape)}; expected {tuple(target_shape)}.")


class _LooseGatedMLPMapping(GatedMLPMapping):
    """GatedMLPMapping that skips wildcard validation for fused expert mappings."""

    # Flagging the mapping as a grouped export makes _validate_patterns()
    # return early: in the fused-expert case the megatron pattern legitimately
    # has more wildcards than the single fused HF pattern.
    is_grouped_export = True


class FusedExpertMapping(AutoMapping):
    """Bridge one fused HF expert tensor to N per-expert Megatron tensors.

    HF side: a single tensor of shape ``[num_experts, ...]`` holding every
    expert. Megatron side: one parameter per expert.

    Import: the expert belonging to this mapping is sliced out of the fused
    tensor, its layout auto-aligned, and ``AutoMapping`` handles the TP
    distribution. Export: ``AutoMapping`` gathers each expert over TP/EP and
    the conversion loop reassembles the fused tensor via the
    ``is_grouped_export`` protocol.

    This replaces per-model expert mapping classes and removes the need for
    ``maybe_modify_converted_hf_weight`` / ``hf_weights_cache`` on bridges.
    """

    is_grouped_export = True

    def __init__(self, megatron_param: str, hf_param: str, permute_dims: Optional[Tuple[int, ...]] = None):
        super().__init__(megatron_param, hf_param, permute_dims)
        # The HF name refers to the fused all-experts tensor, not this single expert.
        self.allow_hf_name_mismatch = True

    @property
    def group_key(self) -> str:
        """Tasks sharing the same group_key are merged during export."""
        return self.hf_param

    def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -> torch.Tensor:
        from megatron.bridge.utils.common_utils import extract_expert_number_from_param

        # Resolve which expert this mapping targets (raises if the param name
        # carries no expert number, matching prior behavior).
        expert_idx = extract_expert_number_from_param(self.megatron_param)

        # A fused tensor has a leading expert dimension; anything below 3-D is
        # already a single expert's weight.
        if hf_weights.ndim >= 3:
            expert_weight = hf_weights[expert_idx]
        else:
            expert_weight = hf_weights

        normalized_param = self._normalize_expert_param_name(self.megatron_param)
        _, target_param = get_module_and_param_from_name(megatron_module, normalized_param)
        aligned = _align_expert_weight_to_shape(expert_weight, target_param.shape, "expert_weight")
        return super().hf_to_megatron(aligned, megatron_module)


class FusedGatedExpertMapping(AutoMapping):
    """Mapping for fused gated expert weights (gate+up projection).

    HF side: Single tensor with shape [num_experts, 2*intermediate, hidden]
    Megatron side: Per-expert linear_fc1 tensors (with gate+up interleaved)

    Import: Extracts single expert, splits into gate+up, delegates to
    GatedMLPMapping for interleaved TP distribution.
    Export: GatedMLPMapping handles TP/EP gathering, gate+up are fused back,
    conversion loop merges all experts via the ``is_grouped_export`` protocol.
    """

    is_grouped_export = True

    def __init__(self, megatron_param: str, hf_param: str, permute_dims: Optional[Tuple[int, ...]] = None):
        super().__init__(megatron_param, hf_param, permute_dims)
        # The HF name refers to the fused all-experts tensor, not this expert.
        self.allow_hf_name_mismatch = True
        # Delegate gate/up interleaving to a GatedMLPMapping whose wildcard
        # validation is relaxed; synthetic ".gate"/".up" suffixes name the two
        # halves internally (they never appear in the exported HF keys).
        self._gated_mapping = _LooseGatedMLPMapping(
            megatron_param=self.megatron_param,
            gate=f"{self.hf_param}.gate",
            up=f"{self.hf_param}.up",
        )

    @property
    def group_key(self) -> str:
        """Tasks sharing the same group_key are merged during export."""
        return self.hf_param

    def hf_to_megatron(self, hf_weights: torch.Tensor, megatron_module: nn.Module) -> torch.Tensor:
        """Slice this expert from the fused HF tensor, split gate/up, and load it."""
        from megatron.bridge.utils.common_utils import extract_expert_number_from_param

        expert_idx = extract_expert_number_from_param(self.megatron_param)
        # Fused tensors carry a leading expert dim; below 3-D it is already one expert.
        expert_weight = hf_weights[expert_idx] if hf_weights.ndim >= 3 else hf_weights

        normalized_param = self._normalize_expert_param_name(self.megatron_param)
        _, target_param = get_module_and_param_from_name(megatron_module, normalized_param)
        target_shape = target_param.shape

        # linear_fc1 stacks gate and up along dim 0, so it must split evenly.
        if target_shape[0] % 2 != 0:
            raise ValueError(f"Expected even fused dim for {self.megatron_param}, got {target_shape}.")

        gate_target_shape = (target_shape[0] // 2, target_shape[1])

        if expert_weight.ndim == 3 and expert_weight.shape[0] == 2:
            # Layout A: gate and up already separated on a leading dim of 2.
            gate = _align_expert_weight_to_shape(expert_weight[0], gate_target_shape, "gate")
            up = _align_expert_weight_to_shape(expert_weight[1], gate_target_shape, "up")
        else:
            # Layout B: gate and up concatenated along dim 0; align then halve.
            expert_weight = _align_expert_weight_to_shape(expert_weight, target_shape, "gate_up")
            gate, up = torch.chunk(expert_weight, 2, dim=0)

        return self._gated_mapping.hf_to_megatron({"gate": gate, "up": up}, megatron_module)

    def megatron_to_hf(
        self,
        megatron_weights: Optional[torch.Tensor],
        megatron_module: Optional[nn.Module],
    ) -> Dict[str, torch.Tensor]:
        """Gather this expert's gate/up halves and re-fuse them into one tensor.

        Returns an empty dict on ranks where the delegate produced nothing.
        """
        converted = self._gated_mapping.megatron_to_hf(megatron_weights, megatron_module)
        if not converted:
            return {}

        fused: Dict[str, torch.Tensor] = {}
        for name, tensor in converted.items():
            # Drive the pairing from the ".gate" entries; skip everything else.
            if not name.endswith(".gate"):
                continue
            base_name = name[: -len(".gate")]
            up_tensor = converted.get(f"{base_name}.up")
            if up_tensor is None:
                # Matching ".up" half missing — presumably on another rank;
                # skip rather than emit a half-fused tensor.
                continue
            # 2-D (single expert): concat rows along dim 0; otherwise the
            # leading dim is the EP-gather dim, so concat along dim 1.
            concat_dim = 0 if tensor.ndim == 2 else 1
            fused[base_name] = torch.cat([tensor, up_tensor], dim=concat_dim)
        return fused

    def resolve(self, captures: Tuple[str, ...]) -> "MegatronParamMapping":
        """Materialize a concrete mapping from wildcard captures (re-runs __init__)."""
        resolved_megatron_param, resolved_hf_param = self._resolve_names(captures)
        return type(self)(resolved_megatron_param, resolved_hf_param, self.permute_dims)


def merge_qkv_biases(config: TransformerConfig, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""Merge separate Q, K, V bias vectors into Megatron's interleaved QKV format.

Expand Down
Loading
Loading