diff --git a/examples/conversion/compare_hf_and_megatron/compare.py b/examples/conversion/compare_hf_and_megatron/compare.py
index 882ea8ff94..fd7ad48fd0 100644
--- a/examples/conversion/compare_hf_and_megatron/compare.py
+++ b/examples/conversion/compare_hf_and_megatron/compare.py
@@ -91,7 +91,6 @@
 """
 
 import argparse
-import gc
 import importlib
 import os
 import sys
@@ -319,13 +318,7 @@ def vlm_forward_step(data_iterator, model, **kwargs) -> torch.Tensor:
     def loss_func(x, **kwargs):
         return x
 
-    model_output = model(**forward_args)
-    if isinstance(model_output, tuple):
-        output_tensor, _ = model_output
-    else:
-        output_tensor = model_output
-
-    return output_tensor, loss_func
+    return model(**forward_args), loss_func
 
 
 def load_image(image_path: str) -> Image.Image:
@@ -616,11 +609,8 @@ def _load_megatron_model(args):
     model_provider.finalize()
     megatron_model = model_provider.provide_distributed_model(wrap_with_ddp=False)
 
-    # Workaround: disable MTP for inference (causes hangs on NCCL collectives)
     for m in megatron_model:
         m.config.mtp_num_layers = None
-        m.config.grad_scale_func = None
-
    model_components = [m.eval() for m in megatron_model]
 
     # Register debug hooks if enabled
@@ -727,27 +717,29 @@ def compare_models_one_step(args) -> None:
     )
 
     del hf_model
-    gc.collect()
     torch.cuda.empty_cache()
 
-    # Broadcast HF results to all ranks
+    # Broadcast HF results to all ranks after Megatron initialization
+    # (following the pattern from generate_from_hf.py)
     if torch.distributed.is_initialized():
-        # Create tensors for broadcasting if they don't exist on non-rank-0
+        # Ensure consistent dtype across ranks: rank 0 has bfloat16 logits from the HF model,
+        # so all ranks must use the same dtype for NCCL broadcast to work correctly.
+        if hf_logits is not None:
+            hf_logits = hf_logits.float()
+
         if hf_next_token is None:
             hf_next_token = torch.zeros(1, device=input_ids.device, dtype=torch.long)
         if hf_logits is None:
-            # Get vocab size from tokenizer for proper tensor size
             vocab_size = getattr(
                 tokenizer, "vocab_size", len(tokenizer.vocab) if hasattr(tokenizer, "vocab") else 32000
             )
             hf_logits = torch.zeros(vocab_size, device=input_ids.device, dtype=torch.float32)
 
-        # Ensure consistent dtype across ranks before broadcast
-        hf_logits = hf_logits.float()
-
         # Broadcast from rank 0 to all ranks
         torch.distributed.broadcast(hf_next_token, 0)
         torch.distributed.broadcast(hf_logits, 0)
+        torch.distributed.barrier()
+        print_rank_0("HF results broadcast complete.")
 
     # Run Megatron model forward pass
     print_rank_0("=== RUNNING MEGATRON MODEL (1-STEP) ===")
@@ -792,10 +784,7 @@ def compare_models_one_step(args) -> None:
         megatron_logits = megatron_output[0, -1, :]
         megatron_next_token = torch.argmax(megatron_logits, dim=-1)
 
-        if not torch.distributed.is_initialized() or (
-            parallel_state.get_tensor_model_parallel_rank() == 0
-            and parallel_state.get_expert_model_parallel_rank() == 0
-        ):
+        if not torch.distributed.is_initialized() or (parallel_state.get_tensor_model_parallel_rank() == 0 and parallel_state.get_expert_model_parallel_rank() == 0):
             print(f"Megatron output shape: {megatron_output.shape}")
             print(f"Megatron logits stats - mean: {megatron_logits.mean():.4f}, std: {megatron_logits.std():.4f}")
             print(
@@ -807,27 +796,26 @@ def compare_models_one_step(args) -> None:
             top5_tokens = [tokenizer.decode([idx]) for idx in top5_ids]
             print(f"Megatron Top 5: {list(zip(top5_tokens, top5_vals.tolist()))}")
 
-            # Compare outputs (only where we have valid Megatron results)
+            # Megatron may pad vocab_size for GPU kernel efficiency; truncate
+            # to the HF vocab size so logits are directly comparable.
+            hf_vocab_size = hf_logits.shape[0]
+            megatron_logits_cmp = megatron_logits[:hf_vocab_size]
+            megatron_next_token_cmp = torch.argmax(megatron_logits_cmp, dim=-1)
+
+            # Compare outputs
             print("=== COMPARISON ===")
-            token_match = hf_next_token.item() == megatron_next_token.item()
+            token_match = hf_next_token.item() == megatron_next_token_cmp.item()
             token_status_emoji = "✅" if token_match else "❌"
             print(f"Token match: {token_match} {token_status_emoji}")
 
-            # Compare logits if shapes match
-            if hf_logits.shape == megatron_logits.shape:
-                diff = (hf_logits - megatron_logits).abs()
-                print(f"Logits diff - max: {diff.max():.6f}, mean: {diff.mean():.6f}")
-                cosine_sim = torch.cosine_similarity(hf_logits.unsqueeze(0), megatron_logits.unsqueeze(0))
-                cos_val = cosine_sim.item()
-                percent = cos_val * 100.0
-                status_emoji = "✅" if cos_val >= SIMILARITY_THRESHOLD else "❌"
-                tolerance_text = "within ±2%" if cos_val >= SIMILARITY_THRESHOLD else "outside ±2%"
-                print(
-                    f"Cosine similarity: {cos_val:.6f} ({percent:.2f}%) {status_emoji} ({tolerance_text} tolerance)"
-                )
-            else:
-                print(f"Shape mismatch: HF {hf_logits.shape} vs Megatron {megatron_logits.shape}")
-                print("Cannot compare logits directly due to shape mismatch")
+            diff = (hf_logits - megatron_logits_cmp).abs()
+            print(f"Logits diff - max: {diff.max():.6f}, mean: {diff.mean():.6f}")
+            cosine_sim = torch.cosine_similarity(hf_logits.unsqueeze(0), megatron_logits_cmp.unsqueeze(0))
+            cos_val = cosine_sim.item()
+            percent = cos_val * 100.0
+            status_emoji = "✅" if cos_val >= SIMILARITY_THRESHOLD else "❌"
+            tolerance_text = "within ±2%" if cos_val >= SIMILARITY_THRESHOLD else "outside ±2%"
+            print(f"Cosine similarity: {cos_val:.6f} ({percent:.2f}%) {status_emoji} ({tolerance_text} tolerance)")
             print("=== COMPARISON COMPLETE ===")
         else: