From 2dd50479da1be82aca6f18bf0c2b95c43fc6b658 Mon Sep 17 00:00:00 2001
From: root
Date: Tue, 3 Feb 2026 15:59:02 -0800
Subject: [PATCH] Initial commit.

---
 nemo_rl/models/megatron/__init__.py           |  10 +
 nemo_rl/models/megatron/recipe_config.py      | 141 +++++++++
 nemo_rl/models/megatron/setup.py              | 219 +++++++++++-------
 .../policy/workers/megatron_policy_worker.py  |  13 ++
 4 files changed, 305 insertions(+), 78 deletions(-)
 create mode 100644 nemo_rl/models/megatron/recipe_config.py

diff --git a/nemo_rl/models/megatron/__init__.py b/nemo_rl/models/megatron/__init__.py
index 4fc25d0d3c..f7ce1ab003 100644
--- a/nemo_rl/models/megatron/__init__.py
+++ b/nemo_rl/models/megatron/__init__.py
@@ -11,3 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from nemo_rl.models.megatron.recipe_config import (
+    get_available_recipes,
+    get_recipe_function,
+)
+
+__all__ = [
+    "get_available_recipes",
+    "get_recipe_function",
+]
diff --git a/nemo_rl/models/megatron/recipe_config.py b/nemo_rl/models/megatron/recipe_config.py
new file mode 100644
index 0000000000..891129c503
--- /dev/null
+++ b/nemo_rl/models/megatron/recipe_config.py
@@ -0,0 +1,141 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Recipe-based configuration for NeMo-RL Megatron integration.
+
+This module provides a clean integration with Megatron-Bridge recipes,
+allowing NeMo-RL to use pre-configured training recipes as a base and
+layer RL-specific settings on top.
+
+Example usage:
+    from nemo_rl.models.megatron.recipe_config import (
+        get_available_recipes,
+        get_recipe_function,
+    )
+
+    recipe_fn = get_recipe_function("meta-llama/Llama-3.1-8B-Instruct")
+    if recipe_fn is not None:
+        megatron_cfg = recipe_fn()  # ConfigContainer with the recipe's defaults
+
+    print(get_available_recipes())  # list of supported model-name patterns
+
+RL-specific settings (checkpointing, optimizer, scheduler, tokenizer, ...) are
+layered on top of the recipe config by setup_model_config /
+_update_megatron_config in nemo_rl.models.megatron.setup when use_recipe=True;
+fields that are not explicitly overridden keep the recipe's values.
+"""
+
+import warnings
+from typing import Any, Callable, Optional
+
+import torch
+from megatron.bridge import AutoBridge
+from megatron.bridge.training.config import (
+    CheckpointConfig,
+    ConfigContainer,
+    DistributedDataParallelConfig,
+    LoggerConfig,
+    OptimizerConfig,
+    SchedulerConfig,
+    TokenizerConfig,
+    TrainingConfig,
+)
+
+from nemo_rl.models.policy import PolicyConfig
+
+
+# =============================================================================
+# RECIPE DISCOVERY
+# =============================================================================
+
+def _import_llama_recipes():
+    """Import Llama recipes from Megatron-Bridge."""
+    try:
+        from megatron.bridge.recipes.llama.llama3 import (
+            llama31_8b_pretrain_config,
+            llama31_70b_pretrain_config,
+            llama31_405b_pretrain_config,
+            llama3_8b_pretrain_config,
+            llama3_70b_pretrain_config,
+            llama32_1b_pretrain_config,
+            llama32_3b_pretrain_config,
+        )
+        return {
+            "llama-3.2-1b": llama32_1b_pretrain_config,
+            "llama-3.2-3b": llama32_3b_pretrain_config,
+            "llama-3-8b": llama3_8b_pretrain_config,
+            "llama-3.1-8b": llama31_8b_pretrain_config,
+            "meta-llama-3-8b": llama3_8b_pretrain_config,
+            "meta-llama-3.1-8b": llama31_8b_pretrain_config,
+            "llama-3-70b": llama3_70b_pretrain_config,
+            "llama-3.1-70b": llama31_70b_pretrain_config,
+            "llama-3.1-405b": llama31_405b_pretrain_config,
+        }
+    except ImportError:
+        return {}
+
+
+def _import_qwen_recipes():
+    """Import Qwen recipes from Megatron-Bridge."""
+    try:
+        from megatron.bridge.recipes.qwen.qwen3 import (
+            qwen3_600m_pretrain_config,
+            qwen3_1p7b_pretrain_config,
+            qwen3_4b_pretrain_config,
+            qwen3_8b_pretrain_config,
+        )
+        return {
+            "qwen3-0.6b": qwen3_600m_pretrain_config,
+            "qwen3-1.7b": qwen3_1p7b_pretrain_config,
+            "qwen3-4b": qwen3_4b_pretrain_config,
+            "qwen3-8b": qwen3_8b_pretrain_config,
+        }
+    except ImportError:
+        return {}
+
+
+def get_recipe_function(hf_model_name: str) -> Optional[Callable[..., ConfigContainer]]:
+    """
+    Get the appropriate Megatron-Bridge recipe function for a model.
+
+    Args:
+        hf_model_name: HuggingFace model name or path
+
+    Returns:
+        Recipe function or None if no matching recipe found
+    """
+    model_lower = hf_model_name.lower().replace("/", "-").replace("_", "-")
+
+    # Load recipes lazily
+    all_recipes = {}
+    all_recipes.update(_import_llama_recipes())
+    all_recipes.update(_import_qwen_recipes())
+
+    # Try to match the normalized model name against known recipe patterns
+    for pattern, recipe_fn in all_recipes.items():
+        if pattern in model_lower:
+            return recipe_fn
+
+    return None
+
+
+def get_available_recipes() -> list[str]:
+    """Return a list of available recipe patterns."""
+    all_recipes = {}
+    all_recipes.update(_import_llama_recipes())
+    all_recipes.update(_import_qwen_recipes())
+    return list(all_recipes.keys())
+
+
diff --git a/nemo_rl/models/megatron/setup.py b/nemo_rl/models/megatron/setup.py
index 24bfdb0605..14dcbabcb3 100644
--- a/nemo_rl/models/megatron/setup.py
+++ b/nemo_rl/models/megatron/setup.py
@@ -31,6 +31,7 @@
     CheckpointConfig,
     ConfigContainer,
     DistributedDataParallelConfig,
+    DistributedInitConfig,
     LoggerConfig,
     OptimizerConfig,
     SchedulerConfig,
@@ -68,6 +69,9 @@
 from nemo_rl.distributed.named_sharding import NamedSharding
 from nemo_rl.models.megatron.community_import import import_model_from_hf_name
 from nemo_rl.models.megatron.config import ModelAndOptimizerState, RuntimeConfig
+from nemo_rl.models.megatron.recipe_config import (
+    get_recipe_function,
+)
 from nemo_rl.models.policy import PolicyConfig
 from nemo_rl.models.policy.utils import (
     configure_dynamo_cache,
@@ -225,7 +229,13 @@ def validate_and_set_config(
     )
 
     megatron_cfg, model_cfg = setup_model_config(
-        config, rank, dtype, hf_model_name, pretrained_path, weights_path
+        config=config,
+        rank=rank,
+        dtype=dtype,
+        hf_model_name=hf_model_name,
+        pretrained_path=pretrained_path,
+        weights_path=weights_path,
+        use_recipe=True,
     )
 
     final_padded_vocab_size = calculate_padded_vocab_size(
@@ -262,7 +272,6 @@ def validate_model_paths(config: PolicyConfig) -> tuple[str, str, bool]:
 
     return hf_model_name, pretrained_path, pt_checkpoint_exists
 
-
 def setup_model_config(
     config: PolicyConfig,
     rank,
@@ -270,40 +279,53 @@ def setup_model_config(
     hf_model_name: str,
     pretrained_path: str,
     weights_path: Optional[str] = None,
+    use_recipe: bool = True,
 ) -> tuple[ConfigContainer, Any]:
-    """Handle all the model configuration logic."""
-    # Load pretrained run config
-    pretrained_run_config = os.path.join(
-        pretrained_path, "iter_0000000/run_config.yaml"
-    )
-
-    if not os.path.exists(pretrained_run_config):
-        raise FileNotFoundError(
-            f"Pretrained run config not found at {pretrained_run_config} on rank={rank}. "
-            "This usually means that the one-time HF->mcore conversion on rank=0 saved to a directory "
-            "not being mounted on this node. Please check"
+    """Set up the model configuration, preferring a Megatron-Bridge recipe when available."""
+    model_cfg = None
+    recipe_fn = get_recipe_function(hf_model_name) if use_recipe else None
+
+    if recipe_fn is not None:
+        # Use a Megatron-Bridge golden recipe as the base configuration.
+        print(f"[INFO] Using Megatron-Bridge recipe-based config for {hf_model_name}")
+        # The recipe returns a fully populated ConfigContainer; NeMo-RL-specific
+        # overrides (checkpoint, optimizer, scheduler, tokenizer, ...) are
+        # layered on top below via _update_megatron_config.
+
+        megatron_cfg = recipe_fn()
+        model_cfg = megatron_cfg.model
+    else:
+        # Load pretrained run config
+        pretrained_run_config = os.path.join(
+            pretrained_path, "iter_0000000/run_config.yaml"
         )
-
-    try:
-        cfg_from_pretrained = ConfigContainer.from_yaml(
-            pretrained_run_config, mode=InstantiationMode.STRICT
-        )
-    except Exception as e:
-        # Add helpful context as a note to the exception
-        e.add_note(
-            f"\n{'=' * 80}\n"
-            f"NOTE: A common cause of this error is when the HF->mcore converted checkpoint is\n"
-            f"created with an older version of megatron-bridge.\n"
-            f"If this checkpoint is old or was generated by a different code version,\n"
-            f"try deleting it and rerunning the code.\n"
-            f"The checkpoint will be automatically regenerated with the current version.\n\n"
-            f"Checkpoint location: {pretrained_path}\n"
-            f"{'=' * 80}"
-        )
-        raise
+
+        if not os.path.exists(pretrained_run_config):
+            raise FileNotFoundError(
+                f"Pretrained run config not found at {pretrained_run_config} on rank={rank}. "
+                "This usually means that the one-time HF->mcore conversion on rank=0 saved to a directory "
+                "not being mounted on this node. Please check that the directory is mounted on every node."
+            )
+
+        try:
+            megatron_cfg = ConfigContainer.from_yaml(
+                pretrained_run_config, mode=InstantiationMode.STRICT
+            )
+        except Exception as e:
+            # Add helpful context as a note to the exception
+            e.add_note(
+                f"\n{'=' * 80}\n"
+                f"NOTE: A common cause of this error is when the HF->mcore converted checkpoint is\n"
+                f"created with an older version of megatron-bridge.\n"
+                f"If this checkpoint is old or was generated by a different code version,\n"
+                f"try deleting it and rerunning the code.\n"
+                f"The checkpoint will be automatically regenerated with the current version.\n\n"
+                f"Checkpoint location: {pretrained_path}\n"
+                f"{'=' * 80}"
+            )
+            raise
 
-    model_cfg = cfg_from_pretrained.model
-    cfg_from_pretrained.logger = LoggerConfig()
+        model_cfg = megatron_cfg.model
 
     # Apply parallelism settings
     _apply_parallelism_config(model_cfg, config)
@@ -333,10 +355,8 @@
     # Validate training configuration
     _validate_training_config(config, model_cfg)
 
-    # Create final megatron config
-    megatron_cfg = _create_megatron_config(
-        model_cfg, checkpoint_config, config, hf_model_name, dtype
-    )
+    # Update megatron config with checkpoint, optimizer, scheduler, etc.
+    _update_megatron_config(megatron_cfg, checkpoint_config, config, hf_model_name)
 
     _validate_dtype_config(dtype, megatron_cfg.model, megatron_cfg.optimizer)
 
@@ -570,51 +590,94 @@
 )
 
 
-def _create_megatron_config(
-    model_cfg: Any,
+def _update_dataclass_fields(target: Any, updates: dict) -> None:
+    """Update a dataclass instance with values from a dictionary.
+
+    Only fields present in the updates dict and defined on the target are set;
+    unknown keys are ignored, and all other fields retain their original values.
+
+    Args:
+        target: A dataclass instance to update
+        updates: Dictionary of field names to new values
+    """
+    for key, value in updates.items():
+        if hasattr(target, key):
+            setattr(target, key, value)
+
+
+def _update_megatron_config(
+    megatron_cfg: ConfigContainer,
     checkpoint_config: CheckpointConfig,
     config: PolicyConfig,
     hf_model_name: str,
-    dtype: torch.dtype,
-) -> ConfigContainer:
-    """Create the final Megatron configuration container."""
-    return ConfigContainer(
-        model=model_cfg,
-        checkpoint=checkpoint_config,
-        logger=LoggerConfig(logging_level=0),
-        train=TrainingConfig(
-            micro_batch_size=1,  # ignored
-            global_batch_size=config["train_global_batch_size"],  # ignored
-            train_iters=config["megatron_cfg"]["train_iters"],
-        ),
-        optimizer=OptimizerConfig(**config["megatron_cfg"]["optimizer"]),
-        ddp=DistributedDataParallelConfig(
-            check_for_nan_in_grad=True,
-            grad_reduce_in_fp32=config["megatron_cfg"][
-                "distributed_data_parallel_config"
-            ]["grad_reduce_in_fp32"],
-            overlap_grad_reduce=config["megatron_cfg"][
-                "distributed_data_parallel_config"
-            ]["overlap_grad_reduce"],
-            overlap_param_gather=config["megatron_cfg"][
-                "distributed_data_parallel_config"
-            ]["overlap_param_gather"],
-            # we need to set average_in_collective=False with calculate_per_token_loss=T
-            # otherwise, mcore throws an assertion error.
-            average_in_collective=False,  # Required with calculate_per_token_loss=True
-            use_distributed_optimizer=config["megatron_cfg"]["optimizer"][
-                "use_distributed_optimizer"
-            ],
-            data_parallel_sharding_strategy=config["megatron_cfg"][
-                "distributed_data_parallel_config"
-            ]["data_parallel_sharding_strategy"],
-        ),
-        scheduler=SchedulerConfig(**config["megatron_cfg"]["scheduler"]),
-        dataset=None,
-        tokenizer=TokenizerConfig(
-            tokenizer_type="HuggingFaceTokenizer",
-            tokenizer_model=hf_model_name,
-        ),
+) -> None:
+    """Update the existing ConfigContainer with checkpoint, optimizer, scheduler, and other settings.
+
+    This modifies megatron_cfg in-place. For sub-configs (optimizer, ddp, scheduler, etc.),
+    only fields explicitly provided in the NeMo-RL config are updated; other fields retain
+    their original values from the recipe or checkpoint.
+    """
+    megatron_cfg_dict = config.get("megatron_cfg", {})
+
+    # Ensure dist config is initialized (required for validate())
+    if megatron_cfg.dist is None:
+        megatron_cfg.dist = DistributedInitConfig()
+
+    # Always replace checkpoint config (NeMo-RL manages checkpoints)
+    megatron_cfg.checkpoint = checkpoint_config
+
+    # Always set logger
+    megatron_cfg.logger = LoggerConfig(logging_level=0)
+
+    # Update training config - these are NeMo-RL specific
+    if megatron_cfg.train is None:
+        megatron_cfg.train = TrainingConfig()
+    megatron_cfg.train.micro_batch_size = 1  # ignored by NeMo-RL
+    megatron_cfg.train.global_batch_size = config.get("train_global_batch_size", 1)  # ignored by NeMo-RL
+    if "train_iters" in megatron_cfg_dict:
+        megatron_cfg.train.train_iters = megatron_cfg_dict["train_iters"]
+
+    # Update optimizer config - merge with existing
+    optimizer_overrides = megatron_cfg_dict.get("optimizer", {})
+    if optimizer_overrides:
+        if megatron_cfg.optimizer is None:
+            megatron_cfg.optimizer = OptimizerConfig(**optimizer_overrides)
+        else:
+            _update_dataclass_fields(megatron_cfg.optimizer, optimizer_overrides)
+
+    # Update DDP config - merge with existing
+    ddp_overrides = megatron_cfg_dict.get("distributed_data_parallel_config", {})
+    if megatron_cfg.ddp is None:
+        megatron_cfg.ddp = DistributedDataParallelConfig()
+
+    # Apply explicit DDP overrides from config
+    if ddp_overrides:
+        _update_dataclass_fields(megatron_cfg.ddp, ddp_overrides)
+
+    # NeMo-RL required DDP settings (always set)
+    megatron_cfg.ddp.check_for_nan_in_grad = True
+    # Required with calculate_per_token_loss=True, otherwise mcore throws assertion error
+    megatron_cfg.ddp.average_in_collective = False
+
+    # Sync use_distributed_optimizer between optimizer and ddp
+    if megatron_cfg.optimizer is not None:
+        megatron_cfg.ddp.use_distributed_optimizer = megatron_cfg.optimizer.use_distributed_optimizer
+
+    # Update scheduler config - merge with existing
+    scheduler_overrides = megatron_cfg_dict.get("scheduler", {})
+    if scheduler_overrides:
+        if megatron_cfg.scheduler is None:
+            megatron_cfg.scheduler = SchedulerConfig(**scheduler_overrides)
+        else:
+            _update_dataclass_fields(megatron_cfg.scheduler, scheduler_overrides)
+
+    # NeMo-RL handles data separately
+    megatron_cfg.dataset = None
+
+    # Update tokenizer config - always set for HuggingFace tokenizer
+    megatron_cfg.tokenizer = TokenizerConfig(
+        tokenizer_type="HuggingFaceTokenizer",
+        tokenizer_model=hf_model_name,
     )
 
 
diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py
index 1e68859bd0..be69045401 100644
--- a/nemo_rl/models/policy/workers/megatron_policy_worker.py
+++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py
@@ -278,6 +278,19 @@ def __init__(
             self.model,
             self.optimizer,
         )
+        print("HELLO")
+        # Dump ConfigContainer to YAML for inspection (only on rank 0)
+        if self.rank == 0:
+            config_dump_path = "/lustre/fsw/portfolios/coreai/users/sfawzy/final_megatron_config_6.yaml"
+            try:
+                self.megatron_cfg.to_yaml(config_dump_path)
+                print(f"[DEBUG] Saved final ConfigContainer to: {config_dump_path}")
+            except Exception as e:
+                print(f"[WARNING] Failed to save ConfigContainer to YAML: {e}")
+            # Exit early after dumping config for inspection
+            import sys
+            print("[DEBUG] Exiting after ConfigContainer dump")
+            sys.exit(0)
 
         # vars used for refit
         ## will be initialized in prepare_refit_info