diff --git a/examples/nemo_gym/grpo_wordle_nemotron_nano_v2_9b.yaml b/examples/nemo_gym/grpo_wordle_nemotron_nano_v2_9b.yaml new file mode 100644 index 0000000000..f7c6380f20 --- /dev/null +++ b/examples/nemo_gym/grpo_wordle_nemotron_nano_v2_9b.yaml @@ -0,0 +1,278 @@ +# GRPO Training Configuration for Wordle with Nemotron Nano V2 9B +# +# This configuration trains nvidia/NVIDIA-Nemotron-Nano-9B-v2 on the Wordle +# word-guessing game using GRPO (Group Relative Policy Optimization). +# +# Backend: DTensor V2 / Automodel +# Hardware: 4x A100 80GB GPUs +# +# Usage: +# uv run python examples/run_grpo_nemo_gym.py \ +# --config examples/nemo_gym/grpo_wordle_nemotron_nano_v2_9b.yaml + +grpo: + max_num_epochs: 999999 # Effectively unlimited, controlled by max_num_steps + num_prompts_per_step: 64 + num_generations_per_prompt: 16 + max_rollout_turns: 6 # Wordle is 1 turn but allows up to 6 tool-calling steps + max_num_steps: 1000000 + normalize_rewards: true + use_leave_one_out_baseline: true + val_period: 10 + val_at_start: true + overlong_filtering: false + max_val_samples: null + val_batch_size: null + seed: 88 + use_dynamic_sampling: false + dynamic_sampling_max_gen_batches: 10 + batch_multiplier: 1 + reward_shaping: + enabled: false + overlong_buffer_length: 128 + overlong_buffer_penalty: 1 + max_response_length: ${policy.max_total_sequence_length} + reward_scaling: + enabled: false + source_min: 0.0 + source_max: 2.0 # Max win reward is 2.0 + target_min: 0.0 + target_max: 1.0 + skip_reference_policy_logprobs_calculation: false # Need ref logprobs for KL penalty + +loss_fn: + reference_policy_kl_penalty: 0.01 # Small KL penalty to prevent model collapse + reference_policy_kl_type: "k3" + kl_input_clamp_value: 20.0 + kl_output_clamp_value: 10.0 + ratio_clip_min: 0.2 + ratio_clip_max: 0.2 + ratio_clip_c: null + use_on_policy_kl_approximation: false + truncated_importance_sampling_ratio: null + use_importance_sampling_correction: false + token_level_loss: true + +checkpointing: + enabled: true + checkpoint_dir: "results/grpo-wordle-nemo-gym" + # Validation metrics (accuracy = mean reward): + # - accuracy: Mean reward (Win=1.0-2.0, Loss=0.0) + # - wordle_simple_agent/won/mean: Win rate + # - wordle_simple_agent/turns_if_won/sum / won/sum = avg turns to win + metric_name: "val:accuracy" + higher_is_better: true + keep_top_k: 1 + save_period: 5 + checkpoint_must_save_by: null + +policy: + model_name: "nvidia/NVIDIA-Nemotron-Nano-9B-v2" + tokenizer: + name: ${policy.model_name} + chat_template_kwargs: null + hf_config_overrides: {} + train_global_batch_size: ${mul:${grpo.num_prompts_per_step}, ${grpo.num_generations_per_prompt}} + train_micro_batch_size: 2 + logprob_batch_size: 4 + generation_batch_size: 32 + max_total_sequence_length: 2048 + precision: "bfloat16" + logprob_chunk_size: null + + # DTensor V2 / Automodel Configuration + dtensor_cfg: + _v2: true # Enable DTensor V2 / Automodel + enabled: true + cpu_offload: false + sequence_parallel: false + activation_checkpointing: true + tensor_parallel_size: 1 + context_parallel_size: 1 + custom_parallel_plan: null + clear_cache_every_n_steps: null + + # Disable Megatron (using DTensor V2 instead) + megatron_cfg: + enabled: false + bias_activation_fusion: false + tensor_model_parallel_size: 1 + empty_unused_memory_level: 0 + activation_checkpointing: true + train_iters: 100000 + expert_tensor_parallel_size: 1 + expert_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + num_layers_in_first_pipeline_stage: null + 
num_layers_in_last_pipeline_stage: null + context_parallel_size: 1 + pipeline_dtype: ${policy.precision} + sequence_parallel: false + freeze_moe_router: true + moe_router_dtype: "fp64" + moe_router_load_balancing_type: "none" + moe_router_bias_update_rate: 0.0 + apply_rope_fusion: true + defer_fp32_logits: false + moe_permute_fusion: false + + optimizer: + optimizer: "adam" + lr: 5.0e-6 + min_lr: 5.0e-7 + weight_decay: 0.01 + bf16: true + fp16: false + params_dtype: "float32" + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_eps: 1e-8 + sgd_momentum: 0.9 + use_distributed_optimizer: true + use_precision_aware_optimizer: true + optimizer_cpu_offload: false + optimizer_offload_fraction: 0.0 + clip_grad: ${policy.max_grad_norm} + + scheduler: + start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + weight_decay_incr_style: "constant" + lr_decay_style: "constant" + lr_decay_iters: 100000 + lr_warmup_iters: 13 + lr_warmup_init: 5.0e-7 + override_opt_param_scheduler: true + + distributed_data_parallel_config: + grad_reduce_in_fp32: false + overlap_grad_reduce: true + overlap_param_gather: true + use_custom_fsdp: false + data_parallel_sharding_strategy: "optim_grads_params" + + env_vars: null + + dynamic_batching: + enabled: false + train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} + logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} + sequence_length_round: 64 + + sequence_packing: + enabled: false + train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}} + logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}} + algorithm: "modified_first_fit_decreasing" + sequence_length_round: 64 + + make_sequence_length_divisible_by: 1 + max_grad_norm: 1.0 + offload_optimizer_for_logprob: false + + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 5.0e-6 + weight_decay: 0.01 + betas: [0.9, 0.999] + eps: 1e-8 + foreach: False + fused: False + + scheduler: + - name: "torch.optim.lr_scheduler.ConstantLR" + kwargs: + factor: 1.0 + total_iters: 10000000000 + - milestones: [] + + generation: + backend: "vllm" + max_new_tokens: ${policy.max_total_sequence_length} + temperature: 1.0 + top_p: 1.0 + top_k: null + stop_token_ids: null + stop_strings: null + vllm_cfg: + async_engine: true + precision: ${policy.precision} + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + enable_expert_parallel: false + expert_parallel_size: 1 + gpu_memory_utilization: 0.7 # Lower to leave room for training + max_model_len: ${policy.max_total_sequence_length} + enforce_eager: false + use_deep_gemm: false + num_last_layers_in_bf16: 0 + num_first_layers_in_bf16: 0 + kv_cache_dtype: "auto" + expose_http_server: true + skip_tokenizer_init: false + tool_parser_plugin: nemo_rl/models/generation/vllm/nemotron_json_tool_parser.py + http_server_serving_chat_kwargs: + enable_auto_tools: true + tool_parser: nemotron_json + chat_template: null # Use model's default template; custom template not needed with source code workarounds + vllm_kwargs: + compilation_config: + use_inductor: false + # Required for Nemotron Nano v2 + mamba_ssm_cache_dtype: "float32" + colocated: + enabled: true + resources: + gpus_per_node: null + num_nodes: null + +data: + train_jsonl_fpath: 3rdparty/Gym-workspace/Gym/resources_servers/wordle/data/train.jsonl + validation_jsonl_fpath: 
3rdparty/Gym-workspace/Gym/resources_servers/wordle/data/validation.jsonl + agent_name: wordle_simple_agent + shuffle: true + num_workers: 0 + +env: + should_use_nemo_gym: true + should_log_nemo_gym_responses: true + nemo_gym: + config_paths: + - responses_api_models/vllm_model/configs/vllm_model_for_training.yaml + - resources_servers/wordle/configs/wordle.yaml + wordle_simple_agent: + responses_api_agents: + wordle: + max_steps: 12 # 6 turns × 2 tool calls max per turn + policy_model: + responses_api_models: + vllm_model: + # Disable reasoning! + uses_reasoning_parser: false + extra_body: + chat_template_kwargs: + enable_thinking: false + +logger: + log_dir: "logs/grpo-wordle-nemotron-nano-v2-9b" + num_val_samples_to_print: 5 # Print some validation samples to see game play + wandb_enabled: true + tensorboard_enabled: false + mlflow_enabled: false + swanlab_enabled: false + monitor_gpus: true + wandb: + project: "grpo-wordle" + name: "nemotron-nano-v2-9b-wordle" + tensorboard: {} + mlflow: + experiment_name: "grpo-wordle" + run_name: "nemotron-nano-v2-9b-wordle" + gpu_monitoring: + collection_interval: 10 + flush_interval: 10 + +cluster: + gpus_per_node: 4 # 4 A100 80GB GPUs + num_nodes: 1 diff --git a/examples/run_grpo_wordle.py b/examples/run_grpo_wordle.py new file mode 100644 index 0000000000..a1c98b829f --- /dev/null +++ b/examples/run_grpo_wordle.py @@ -0,0 +1,302 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GRPO training script for Wordle game. + +Trains an LLM to play Wordle using GRPO with dense reward shaping. +The model learns to guess 5-letter words based on feedback patterns. 
+ +Usage: + # Smoke test with small model + uv run python examples/run_grpo_wordle.py \ + grpo.max_num_steps=5 \ + grpo.num_prompts_per_step=4 \ + grpo.num_generations_per_prompt=4 \ + policy.model_name="Qwen/Qwen2.5-0.5B-Instruct" \ + cluster.gpus_per_node=1 + + # Full training with Nemotron-Nano-9B-v2 + uv run python examples/run_grpo_wordle.py \ + logger.wandb_enabled=True \ + logger.wandb.name="wordle-nemotron-9b" +""" + +import argparse +import itertools +import os +import pprint +import random +from typing import Any, Iterator + +from omegaconf import OmegaConf +from torch.utils.data import IterableDataset +from transformers import AutoTokenizer + +from nemo_rl.algorithms.grpo import MasterConfig, grpo_train, setup +from nemo_rl.algorithms.utils import get_tokenizer, set_seed +from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType +from nemo_rl.distributed.virtual_cluster import init_ray +from nemo_rl.environments.games.wordle import ( + WordleConfig, + WordleEnv, + WordleGameLogic, + WordleMetadata, +) +from nemo_rl.models.generation import configure_generation_config +from nemo_rl.utils.config import load_config, parse_hydra_overrides +from nemo_rl.utils.logger import get_next_experiment_dir + +OmegaConf.register_new_resolver("mul", lambda a, b: a * b) + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Run GRPO training for Wordle") + parser.add_argument( + "--config", type=str, default=None, help="Path to YAML config file" + ) + args, overrides = parser.parse_known_args() + return args, overrides + + +def generate_wordle_datum( + tokenizer: AutoTokenizer, + game_config: WordleConfig, + max_turns: int, + task_name: str, + idx: int, + add_system_prompt: bool, +) -> DatumSpec: + """Generate a single Wordle game datum (prompt and metadata).""" + # Generate initial game state + initial_game_state = WordleGameLogic.generate(game_config) + target_word = initial_game_state["target_word"] + welcome_message = WordleGameLogic.init(initial_game_state) + + # Create initial prompt - end with to prime the model + prompt_content = f"{welcome_message}\nYour guess: " + + # Apply chat template + initial_prompt_content = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt_content}], + tokenize=False, + add_system_prompt=add_system_prompt, + add_generation_prompt=True, + add_special_tokens=False, + ).strip() + + # Tokenize + tokenized_prompt = tokenizer( + initial_prompt_content, return_tensors="pt", add_special_tokens=False + )["input_ids"][0] + + # Create message log + message_log: LLMMessageLogType = [ + { + "role": "user", + "content": initial_prompt_content, + "token_ids": tokenized_prompt, + } + ] + + # Create initial metadata + metadata = WordleMetadata( + target_word=target_word, + guesses=[], + feedback_history=[], + known_greens={}, + known_yellows=set(), + eliminated_letters=set(), + turn=0, + max_turns=max_turns, + ) + + datum: DatumSpec = { + "message_log": message_log, + "length": len(tokenized_prompt), + "extra_env_info": metadata, + "loss_multiplier": 1.0, + "idx": idx, + "task_name": task_name, + "stop_strings": [""], + } + + return datum + + +class IterableWordleDataset(IterableDataset): + """An IterableDataset that generates Wordle games indefinitely.""" + + def __init__( + self, + tokenizer: AutoTokenizer, + game_config: WordleConfig, + max_turns: int, + task_name: str, + add_system_prompt: bool, + length: int, + ): + super().__init__() + self.tokenizer = tokenizer + self.game_config = game_config 
+ self.max_turns = max_turns + self.task_name = task_name + self.add_system_prompt = add_system_prompt + self.length = length + + def __iter__(self) -> Iterator[DatumSpec]: + print("Starting IterableWordleDataset (indefinite generation).") + for i in itertools.count(): + yield generate_wordle_datum( + tokenizer=self.tokenizer, + game_config=self.game_config, + max_turns=self.max_turns, + task_name=self.task_name, + idx=i, + add_system_prompt=self.add_system_prompt, + ) + + def __len__(self): + return self.length + + +def setup_wordle_data( + tokenizer: AutoTokenizer, + env_cfg: dict[str, Any], + task_name: str, + length: int, + val_length: int, + add_system_prompt: bool, +) -> tuple[IterableDataset, IterableDataset | None, dict, dict]: + """Set up the iterable data generator and env map for Wordle.""" + print("Setting up Wordle iterable data and environment...") + env_config = env_cfg[task_name] + + print(f"Instantiating environment for task '{task_name}'...") + env = WordleEnv.options(num_gpus=0).remote(cfg=dict(env_config["cfg"])) + task_to_env = {task_name: env} + print(f"Environment '{task_name}' created.") + + print("Creating Wordle dataset...") + training_dataset = IterableWordleDataset( + tokenizer=tokenizer, + game_config=dict(env_config["cfg"]["game_config"]), + max_turns=env_config["cfg"]["max_turns"], + task_name=task_name, + add_system_prompt=add_system_prompt, + length=length, + ) + print("Wordle training dataset created.") + + validation_dataset = IterableWordleDataset( + tokenizer=tokenizer, + game_config=dict(env_config["cfg"]["game_config"]), + max_turns=env_config["cfg"]["max_turns"], + task_name=task_name, + add_system_prompt=add_system_prompt, + length=val_length, + ) + print("Wordle validation dataset created.") + + val_task_to_env = task_to_env + + return training_dataset, validation_dataset, task_to_env, val_task_to_env + + +def main(): + """Main entry point.""" + args, overrides = parse_args() + + # Default config path + if not args.config: + args.config = os.path.join( + os.path.dirname(__file__), "configs", "grpo_wordle.yaml" + ) + + config = load_config(args.config) + print(f"Loaded configuration from: {args.config}") + + if overrides: + print(f"Overrides: {overrides}") + config = parse_hydra_overrides(config, overrides) + + config: MasterConfig = OmegaConf.to_container(config, resolve=True) + print("Applied CLI overrides") + + # Print config + print("Final config:") + pprint.pprint(config) + + # Get the next experiment directory + config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"]) + print(f"Using log directory: {config['logger']['log_dir']}") + if config["checkpointing"]["enabled"]: + print(f"Using checkpoint directory: {config['checkpointing']['checkpoint_dir']}") + + init_ray() + + set_seed(config["grpo"]["seed"]) + + # Setup tokenizer + tokenizer = get_tokenizer(config["policy"]["tokenizer"]) + config["policy"]["generation"] = configure_generation_config( + config["policy"]["generation"], tokenizer + ) + + # Setup data & env map + ds_length = ( + config["grpo"]["num_prompts_per_step"] + * config["grpo"]["num_generations_per_prompt"] + * config["grpo"]["max_num_steps"] + ) + dataset, val_dataset, task_to_env, val_task_to_env = setup_wordle_data( + tokenizer=tokenizer, + env_cfg=config["env"], + task_name="wordle_game", + length=ds_length, + val_length=config["grpo"]["max_val_samples"], + add_system_prompt=config["data"]["add_system_prompt"], + ) + + ( + policy, + policy_generation, + cluster, + dataloader, + val_dataloader, + 
+        loss_fn,
+        logger,
+        checkpointer,
+        grpo_state,
+        master_config,
+    ) = setup(config, tokenizer, dataset, val_dataset)
+
+    grpo_train(
+        policy,
+        policy_generation,
+        dataloader,
+        val_dataloader,
+        tokenizer,
+        loss_fn,
+        task_to_env,
+        val_task_to_env,
+        logger,
+        checkpointer,
+        grpo_state,
+        master_config,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/nemo_rl/models/generation/vllm/nemotron_json_tool_parser.py b/nemo_rl/models/generation/vllm/nemotron_json_tool_parser.py
new file mode 100644
index 0000000000..eb1d433da9
--- /dev/null
+++ b/nemo_rl/models/generation/vllm/nemotron_json_tool_parser.py
@@ -0,0 +1,104 @@
+# SPDX-License-Identifier: Apache-2.0
+# Source: https://huggingface.co/nvidia/NVIDIA-Nemotron-Nano-9B-v2/blob/main/nemotron_toolcall_parser_no_streaming.py
+
+import json
+import re
+from collections.abc import Sequence
+from typing import Union
+
+from vllm.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    DeltaMessage,
+    ExtractedToolCallInformation,
+    FunctionCall,
+    ToolCall,
+)
+from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
+    ToolParser,
+    ToolParserManager,
+)
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@ToolParserManager.register_module("nemotron_json")
+class NemotronJSONToolParser(ToolParser):
+
+    def __init__(self, tokenizer):
+        super().__init__(tokenizer)
+
+        self.current_tool_name_sent: bool = False
+        self.prev_tool_call_arr: list[dict] = []
+        self.current_tool_id: int = -1
+        self.streamed_args_for_tool: list[str] = []
+
+        self.tool_call_start_token: str = "<TOOLCALL>"
+        self.tool_call_end_token: str = "</TOOLCALL>"
+
+        self.tool_call_regex = re.compile(r"<TOOLCALL>(.*?)</TOOLCALL>", re.DOTALL)
+
+    def extract_tool_calls(
+        self,
+        model_output: str,
+        request: ChatCompletionRequest,
+    ) -> ExtractedToolCallInformation:
+
+        if self.tool_call_start_token not in model_output:
+            return ExtractedToolCallInformation(
+                tools_called=False,
+                tool_calls=[],
+                content=model_output,
+            )
+
+        else:
+
+            try:
+                str_tool_calls = self.tool_call_regex.findall(model_output)[0].strip()
+                if not str_tool_calls.startswith("["):
+                    str_tool_calls = "[" + str_tool_calls
+                if not str_tool_calls.endswith("]"):
+                    str_tool_calls = str_tool_calls + "]"
+                json_tool_calls = json.loads(str_tool_calls)
+                tool_calls = []
+                for tool_call in json_tool_calls:
+                    try:
+                        tool_calls.append(ToolCall(
+                            type="function",
+                            function=FunctionCall(
+                                name=tool_call["name"],
+                                arguments=json.dumps(tool_call["arguments"], ensure_ascii=False)
+                                if isinstance(tool_call["arguments"], dict) else tool_call["arguments"],
+                            ),
+                        ))
+                    except Exception:
+                        continue
+
+                content = model_output[:model_output.rfind(self.tool_call_start_token)]
+
+                return ExtractedToolCallInformation(
+                    tools_called=True,
+                    tool_calls=tool_calls,
+                    content=content if content else None,
+                )
+
+            except Exception:
+                logger.exception(f"Error in extracting tool call from response. Response: {model_output}")
+                return ExtractedToolCallInformation(
+                    tools_called=False,
+                    tool_calls=[],
+                    content=model_output,
+                )
+
+    def extract_tool_calls_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+        request: ChatCompletionRequest,
+    ) -> Union[DeltaMessage, None]:
+
+        raise NotImplementedError("Tool calling is not supported in streaming mode!")
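
For reference, a minimal sketch of the model output that NemotronJSONToolParser is written against. It assumes the <TOOLCALL>[...]</TOOLCALL> wrapper used by the upstream Nemotron tool-call parser; the "submit_guess" tool name and the surrounding text are illustrative only, not taken from the Wordle resources server.

    import json
    import re

    # Same pattern the parser compiles in __init__
    TOOL_CALL_RE = re.compile(r"<TOOLCALL>(.*?)</TOOLCALL>", re.DOTALL)

    # Hypothetical completion: free-form text followed by a JSON list of tool calls
    model_output = (
        "A vowel-heavy opener should narrow things down. "
        '<TOOLCALL>[{"name": "submit_guess", "arguments": {"word": "crane"}}]</TOOLCALL>'
    )

    payload = TOOL_CALL_RE.findall(model_output)[0].strip()
    tool_calls = json.loads(payload)  # -> [{"name": "submit_guess", "arguments": {"word": "crane"}}]
    content = model_output[: model_output.rfind("<TOOLCALL>")].strip()
    print(tool_calls[0]["name"], "|", content)

With enable_auto_tools: true and tool_parser: nemotron_json set under http_server_serving_chat_kwargs in the YAML above, vLLM's chat endpoint should route each completion through this extractor, so the Wordle agent receives structured tool calls rather than raw text.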