diff --git a/examples/configs/recipes/llm/grpo-nano-v3-2n8g-mxfp8-e2e.yaml b/examples/configs/recipes/llm/grpo-nano-v3-2n8g-mxfp8-e2e.yaml
new file mode 100644
index 0000000000..249d994c28
--- /dev/null
+++ b/examples/configs/recipes/llm/grpo-nano-v3-2n8g-mxfp8-e2e.yaml
@@ -0,0 +1,47 @@
+defaults: ../../grpo_math_1B.yaml
+loss_fn:
+  use_importance_sampling_correction: true
+grpo:
+  max_num_steps: 30
+checkpointing:
+  checkpoint_dir: results/grpo-nano-v3-2n8g-mxfp8-e2e
+policy:
+  model_name: nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16
+  tokenizer:
+    name: nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16
+  optimizer: null
+  megatron_cfg:
+    enabled: true
+    bias_activation_fusion: false
+    tensor_model_parallel_size: 2
+    context_parallel_size: 2
+    expert_model_parallel_size: 8
+    sequence_parallel: true
+    moe_router_dtype: fp32
+    fp8_cfg:
+      enabled: true
+      fp8: e4m3
+      fp8_recipe: mxfp8
+      fp8_param: false
+  dtensor_cfg:
+    enabled: false
+  make_sequence_length_divisible_by: 1
+  generation:
+    vllm_cfg:
+      precision: fp8
+      gpu_memory_utilization: 0.5
+      fp8_cfg:
+        is_mx: true
+        dynamic_weight_quant: false
+data:
+  max_input_seq_length: 512
+logger:
+  log_dir: logs/grpo-nano-v3-2n8g-mxfp8-e2e
+  wandb_enabled: true
+  tensorboard_enabled: true
+  wandb:
+    project: nemo-rl
+    name: grpo-nano-v3-2n8g-mxfp8-e2e
+cluster:
+  gpus_per_node: 8
+  num_nodes: 2
diff --git a/nemo_rl/models/generation/vllm/quantization/fp8.py b/nemo_rl/models/generation/vllm/quantization/fp8.py
index 9505f42524..a72cc85395 100644
--- a/nemo_rl/models/generation/vllm/quantization/fp8.py
+++ b/nemo_rl/models/generation/vllm/quantization/fp8.py
@@ -19,7 +19,7 @@
 import ray
 import torch
 from accelerate import init_empty_weights
-from transformers import AutoConfig, AutoModel
+from transformers import AutoConfig, AutoModelForCausalLM
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.linear import LinearBase
 from vllm.triton_utils import tl, triton
@@ -31,6 +31,15 @@
     "fmt": "e4m3",
     "quant_method": "fp8",
     "weight_block_size": [128, 128],
+    "is_mx": False,
 }
+
+
+MXFP8_BLOCK_QUANT_KWARGS = {
+    "activation_scheme": "dynamic",
+    "weight_scheme": "dynamic",
+    "quant_method": "fp8",
+    "is_mx": True,
+}
@@ -43,6 +52,8 @@ class FP8Config:
     model_parallel_size: int = None
     kv_cache_dtype: str = "auto"
     use_fp8_weights: bool = True  # Whether model weights are quantized to FP8
+    is_mx: bool = False
+    dynamic_weight_quant: bool = False
 
 
 @dataclass()
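Note: MXFP8 stores e4m3 values with one shared power-of-two (e8m0) scale per 32-element block, in contrast to the [128, 128] block scales of the existing fp8 recipe. The real quantization entry point is vllm's `mxfp8_e4m3_quantize`; the following pure-PyTorch sketch is only illustrative (the function name and the flat block-32 layout are assumptions, and it requires the tensor size to be a multiple of 32):

```python
import torch

FP8_E4M3_MAX = 448.0  # max representable magnitude of torch.float8_e4m3fn

def quantize_mx_block32(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Illustrative MX-style quantization: block-32 e4m3 with e8m0 scales."""
    orig_shape = x.shape
    x = x.reshape(-1, 32).to(torch.float32)  # assumes numel divisible by 32
    amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    # MX scales are e8m0, i.e. pure powers of two
    scale = torch.exp2(torch.ceil(torch.log2(amax / FP8_E4M3_MAX)))
    q = (x / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
    return q.reshape(orig_shape), scale.reshape(*orig_shape[:-1], -1)
```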
@@ -154,7 +165,7 @@ def apply_fp8_patches(self, fp8_config):
 
 def init_fp8(vllm_cfg, model_name, model_parallel_size):
-    config = AutoConfig.from_pretrained(model_name)
+    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     global global_fp8_config
     # Determine if we're using FP8 weights based on precision setting
     use_fp8_weights = vllm_cfg.get("precision") == "fp8"
@@ -173,21 +184,40 @@ def init_fp8(vllm_cfg, model_name, model_parallel_size):
             "FP8 KV cache can only be used together with FP8 model weights."
         )
 
-    global_fp8_config = FP8Config(
-        use_weight_pow2_scale=vllm_cfg.get("pow2_weight_scaling_factors", False),
-        use_activation_pow2_scale=vllm_cfg.get(
+    if use_fp8_weights:
+        is_mx = vllm_cfg["fp8_cfg"]["is_mx"]
+        dynamic_weight_quant = vllm_cfg["fp8_cfg"]["dynamic_weight_quant"]
+    else:
+        is_mx = False
+        dynamic_weight_quant = False
+    fp8_config_kwargs = {
+        "num_first_layers_in_bf16": vllm_cfg.get("num_first_layers_in_bf16", 0),
+        "num_last_layers_in_bf16": vllm_cfg.get("num_last_layers_in_bf16", 0),
+        "model_parallel_size": model_parallel_size,
+        "kv_cache_dtype": kv_cache_dtype,
+        "use_fp8_weights": use_fp8_weights,
+    }
+    if is_mx:
+        fp8_config_kwargs["dynamic_weight_quant"] = dynamic_weight_quant
+        fp8_config_kwargs["is_mx"] = True
+        if "pow2_weight_scaling_factors" in vllm_cfg:
+            print("Warning: pow2_weight_scaling_factors is ignored when is_mx=True.")
+        if "pow2_activation_scaling_factors" in vllm_cfg:
+            print("Warning: pow2_activation_scaling_factors is ignored when is_mx=True.")
+    else:
+        fp8_config_kwargs["is_mx"] = False
+        fp8_config_kwargs["use_weight_pow2_scale"] = vllm_cfg.get("pow2_weight_scaling_factors", False)
+        fp8_config_kwargs["use_activation_pow2_scale"] = vllm_cfg.get(
             "pow2_activation_scaling_factors", False
-        ),
-        num_first_layers_in_bf16=vllm_cfg.get("num_first_layers_in_bf16", 0),
-        num_last_layers_in_bf16=vllm_cfg.get("num_last_layers_in_bf16", 0),
-        model_parallel_size=model_parallel_size,
-        kv_cache_dtype=kv_cache_dtype,
-        use_fp8_weights=use_fp8_weights,
-    )
+        )
+    global_fp8_config = FP8Config(**fp8_config_kwargs)
 
     if vllm_cfg.get("use_deep_gemm", False):
-        os.environ["VLLM_USE_DEEP_GEMM"] = "1"
-        os.environ["VLLM_USE_DEEP_GEMM_E8M0"] = "0"
+        if is_mx:
+            print("Warning: use_deep_gemm is not supported with mxfp8 and will be ignored.")
+        else:
+            os.environ["VLLM_USE_DEEP_GEMM"] = "1"
+            os.environ["VLLM_USE_DEEP_GEMM_E8M0"] = (
+                "1" if vllm_cfg.get("use_deep_gemm_e8m0", False) else "0"
+            )
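Note: to make the config plumbing concrete, here is a hypothetical `vllm_cfg` mirroring the recipe's `generation.vllm_cfg` above; the `use_deep_gemm` and `pow2_weight_scaling_factors` keys are not in the recipe and are included only to show which keys the MX branch warns about and ignores:

```python
# Hypothetical input to init_fp8 under the mxfp8 recipe.
vllm_cfg = {
    "precision": "fp8",                   # enables use_fp8_weights
    "async_engine": False,
    "fp8_cfg": {"is_mx": True, "dynamic_weight_quant": False},
    "use_deep_gemm": True,                # ignored with a warning when is_mx=True
    "pow2_weight_scaling_factors": True,  # likewise ignored under MX scaling
}
# -> FP8Config(is_mx=True, dynamic_weight_quant=False, use_fp8_weights=True, ...)
```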
 
     if vllm_cfg["async_engine"]:
         # for async engine, vllm spawns a process for each DP, so we patch
@@ -201,17 +231,21 @@ def init_fp8(vllm_cfg, model_name, model_parallel_size):
 
     # create fp8 kwargs for vllm's LLM(...)
     num_first_layers_in_bf16 = vllm_cfg.get("num_first_layers_in_bf16", 0)
     num_last_layers_in_bf16 = vllm_cfg.get("num_last_layers_in_bf16", 0)
-    fp8_block_quant_kwargs = dict(FP8_BLOCK_QUANT_KWARGS)
+    fp8_block_quant_kwargs = dict(MXFP8_BLOCK_QUANT_KWARGS) if is_mx else dict(FP8_BLOCK_QUANT_KWARGS)
+    if is_mx:
+        fp8_block_quant_kwargs["weight_scheme"] = "dynamic" if dynamic_weight_quant else "static"
+    elif dynamic_weight_quant:
+        print("Warning: dynamic_weight_quant is ignored when is_mx=False.")
 
     if num_first_layers_in_bf16 > 0 or num_last_layers_in_bf16 > 0:
         with init_empty_weights():
-            model = AutoModel.from_config(config)
+            model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
         param_names = [name for name, _ in model.named_parameters()]
 
         bf16_params = []
         if num_first_layers_in_bf16 > 0:
             layers = [l for l in range(num_first_layers_in_bf16)]
-            bf16_params.append(_get_params_in_layers(param_names, layers))
+            bf16_params.extend(_get_params_in_layers(param_names, layers))
 
         if num_last_layers_in_bf16 > 0:
             layers = [
@@ -221,17 +255,21 @@ def init_fp8(vllm_cfg, model_name, model_parallel_size):
                     config.num_hidden_layers,
                 )
             ]
-            bf16_params.append(_get_params_in_layers(param_names, layers))
+            bf16_params.extend(_get_params_in_layers(param_names, layers))
 
         fp8_block_quant_kwargs["ignored_layers"] = bf16_params
 
     quantization_ignored_layer_kws = vllm_cfg.get("quantization_ignored_layer_kws", [])
     if len(quantization_ignored_layer_kws):
         with init_empty_weights():
-            model = AutoModel.from_config(config)
+            model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
         param_names = [
             f"model.{name}".removesuffix(".weight")
             for name, _ in model.named_parameters()
         ]
+        # Hybrid (backbone-based) checkpoints are not nested under vllm's
+        # extra "model." prefix, so strip it for their parameters.
+        param_names = [
+            name.replace("model.backbone.", "backbone.")
+            for name in param_names
+        ]
         ignored_layers = [
             n
             for n in param_names
@@ -259,8 +297,8 @@ def is_fp8_model(vllm_config):
     if hasattr(vllm_config, "quant_config") and isinstance(
         vllm_config.quant_config, Fp8Config
     ):
-        assert vllm_config.quant_config.weight_block_size is not None, (
-            "Only block scaling is currently supported in NeMo-RL!"
+        assert vllm_config.quant_config.weight_block_size is not None or vllm_config.quant_config.is_mx, (
+            "Only block scaling or mxfp8 is currently supported in NeMo-RL!"
         )
         return True
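Note: the backbone-rename lines above (and the matching change in `_get_params_in_layers` below) implement one name normalization. To make it concrete, a small sketch derived directly from that logic; the helper name is hypothetical, and the second example assumes a Nemotron-style hybrid checkpoint with parameters under `backbone.`:

```python
def to_vllm_module_name(hf_param_name: str) -> str:
    """HF parameter name -> vllm module name, as done in the hunks above."""
    name = f"model.{hf_param_name}".removesuffix(".weight")
    # Backbone-based models are not wrapped with vllm's extra "model." prefix.
    return name.replace("model.backbone.", "backbone.")

assert to_vllm_module_name("layers.0.mlp.down_proj.weight") == "model.layers.0.mlp.down_proj"
assert to_vllm_module_name("backbone.layers.0.mixer.in_proj.weight") == "backbone.layers.0.mixer.in_proj"
```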
@@ -293,7 +331,9 @@ def _get_params_in_layers(param_names, layers):
         ):
             # Convert the param name into vllm's module name
             # Vllm wraps the model with an extra 'model'
-            params.append(f"model.{name}".removesuffix(".weight"))
+            _name = f"model.{name}".removesuffix(".weight")
+            _name = _name.replace("model.backbone.", "backbone.")
+            params.append(_name)
     return params
@@ -311,6 +351,13 @@ def _get_module_from_param_name(model, name: str):
     }
     if module_path[-1] in reversed_mapping.keys():
         module_path[-1] = reversed_mapping[module_path[-1]]
+    if hasattr(model, "hf_to_vllm_mapper") and hasattr(model.hf_to_vllm_mapper, "orig_to_new_prefix"):
+        if module_path[0] in model.hf_to_vllm_mapper.orig_to_new_prefix:
+            module_path[0] = model.hf_to_vllm_mapper.orig_to_new_prefix[module_path[0]]
+    if hasattr(model, "hf_to_vllm_mapper") and hasattr(model.hf_to_vllm_mapper, "orig_to_new_substr"):
+        for i in range(len(module_path)):
+            if module_path[i] in model.hf_to_vllm_mapper.orig_to_new_substr:
+                module_path[i] = model.hf_to_vllm_mapper.orig_to_new_substr[module_path[i]]
 
     current_module = model
     try:
@@ -348,6 +395,16 @@ def _is_fp8_weight(name, model):
 
 
 def load_weights(weights, model_runner):
+    global global_fp8_config
+
+    if global_fp8_config.is_mx and global_fp8_config.dynamic_weight_quant:
+        # fall back to the default way of loading high-precision weights
+        model_runner.model.load_weights(weights)
+        # synchronize to ensure the weights are loaded
+        torch.cuda.synchronize()
+        return
+
     weights_quantized = []
     model = model_runner.model
@@ -356,15 +413,29 @@
             weights_quantized.append((k, v))
             continue
         # Cast the weight into fp8 and its scale factor
-        param_lp, param_scale = cast_tensor_to_fp8_blockwise(
-            v.to(torch.float),
-            weight_block_size=FP8_BLOCK_QUANT_KWARGS["weight_block_size"],
-        )
+        if global_fp8_config.is_mx:
+            param_lp, param_scale = cast_tensor_to_mxfp8(v)
+        else:
+            param_lp, param_scale = cast_tensor_to_fp8_blockwise(
+                v.to(torch.float),
+                weight_block_size=FP8_BLOCK_QUANT_KWARGS["weight_block_size"],
+            )
         param_scale = torch.squeeze(param_scale, dim=-1)
-        weights_quantized.append([k, param_lp])
-        weights_quantized.append([k + "_scale_inv", param_scale])
+        # mxfp8 checkpoints name their scales "_scale"; block fp8 uses "_scale_inv"
+        scale_suffix = "_scale" if global_fp8_config.is_mx else "_scale_inv"
+        weights_quantized.append([k, param_lp])
+        weights_quantized.append([k + scale_suffix, param_scale])
 
     # Finally load the weights into vllm
     model.load_weights(weights_quantized)
+    # synchronize to ensure the weights are loaded
+    torch.cuda.synchronize()
+
+
+def cast_tensor_to_mxfp8(data_hp):
+    """Quantize a tensor to MX FP8 (e4m3 values with block-32 e8m0 scales)."""
+    from vllm.model_executor.layers.quantization.utils.mxfp8_utils import mxfp8_e4m3_quantize
+
+    return mxfp8_e4m3_quantize(data_hp)
 
 
 def cast_tensor_to_fp8_blockwise(
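Note: the refit path in `load_weights` above now has three cases. A condensed sketch (names from the hunk; control flow simplified, and it omits the `_is_fp8_weight` skip and the scale squeeze):

```python
import torch

def quantize_for_refit(k, v, cfg):
    """Return the (name, tensor) entries to hand to vllm for one weight."""
    if cfg.is_mx and cfg.dynamic_weight_quant:
        return [(k, v)]  # keep high precision; vllm re-quantizes after loading
    if cfg.is_mx:
        q, s = cast_tensor_to_mxfp8(v)
        return [(k, q), (k + "_scale", s)]           # mxfp8 uses "_scale"
    q, s = cast_tensor_to_fp8_blockwise(v.to(torch.float), weight_block_size=[128, 128])
    return [(k, q), (k + "_scale_inv", s)]           # block fp8 uses "_scale_inv"
```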
@@ -499,22 +570,39 @@ def process_weights_after_loading(self, layer) -> None:
         from vllm.model_executor.layers.quantization.utils.fp8_utils import (
             process_fp8_weight_block_strategy,
         )
+        from vllm.model_executor.layers.quantization.utils.quant_utils import swizzle_blockscale
+        from vllm.model_executor.layers.quantization.utils.mxfp8_utils import mxfp8_e4m3_quantize
 
-        assert self.block_quant and self.quant_config.is_checkpoint_fp8_serialized
+        assert (self.quant_config.is_mx or self.block_quant) and self.quant_config.is_checkpoint_fp8_serialized and not self.use_marlin, (
+            "Only mxfp8 and block-wise fp8 are currently supported in NeMo-RL"
+        )
         assert self.quant_config.activation_scheme == "dynamic"
 
-        weight_scale = layer.weight_scale_inv
-        weight, weight_scale = process_fp8_weight_block_strategy(layer.weight, weight_scale)
-        layer.weight.data = weight.data
-        if hasattr(layer, "weight_scale"):
-            # Not the first time to call this function, just need to update the data
-            layer.weight_scale.copy_(weight_scale.data)
-        else:
-            # The first time to call this function, create a new parameter and update the tp status
-            layer.weight_scale = torch.nn.Parameter(weight_scale.data, requires_grad=False)
-            layer.update_param_tp_status()
+        layer.input_scale = None
+
+        if self.quant_config.is_mx:
+            if self.quant_config.weight_scheme == "dynamic":
+                weight, weight_scale = mxfp8_e4m3_quantize(layer.weight.data.to(torch.bfloat16))
+            else:
+                weight = layer.weight
+                weight_scale = layer.weight_scale
+            if not hasattr(layer, "weight_for_apply"):
+                layer.weight_for_apply = torch.nn.Parameter(weight.data, requires_grad=False)
+                layer.weight_scale_for_apply = torch.nn.Parameter(swizzle_blockscale(weight_scale), requires_grad=False)
+            else:
+                layer.weight_for_apply.copy_(weight)
+                layer.weight_scale_for_apply.copy_(swizzle_blockscale(weight_scale))
+        else:
+            weight_scale = layer.weight_scale_inv
+            weight, weight_scale = process_fp8_weight_block_strategy(layer.weight, weight_scale)
+            layer.weight.data = weight.data
+            if hasattr(layer, "weight_scale"):
+                # Not the first time to call this function, just need to update the data
+                layer.weight_scale.copy_(weight_scale.data)
+            else:
+                # The first time to call this function, create a new parameter and update the tp status
+                layer.weight_scale = torch.nn.Parameter(weight_scale.data, requires_grad=False)
+                layer.update_param_tp_status()
 
-        maybe_post_process_fp8_weight_block(layer)
+            maybe_post_process_fp8_weight_block(layer)
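Note: both branches above follow the same refit-safe pattern: allocate the derived parameter on the first call, then `copy_` into it on later refits so the tensor storage (and anything vllm may have captured, such as CUDA graph pointers) stays valid. A generic sketch of that pattern, with a hypothetical helper name:

```python
import torch

def set_or_copy_param(layer: torch.nn.Module, name: str, value: torch.Tensor) -> None:
    """First call creates the Parameter; later calls update it in place."""
    if not hasattr(layer, name):
        setattr(layer, name, torch.nn.Parameter(value, requires_grad=False))
    else:
        getattr(layer, name).copy_(value)
```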
 
     def process_weights_after_loading_moe(self, layer) -> None:
@@ -534,12 +622,114 @@ def process_weights_after_loading_moe(self, layer) -> None:
         from vllm.utils.deep_gemm import (
             is_deep_gemm_e8m0_used,
         )
+        from vllm.model_executor.layers.quantization.utils.mxfp8_utils import mxfp8_e4m3_quantize, dequant_mxfp8_to_bf16
+
         self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
-        assert self.block_quant and self.quant_config.is_checkpoint_fp8_serialized
         assert self.quant_config.activation_scheme == "dynamic"
 
+        if self.quant_config.is_mx:
+            def maybe_quantize(weight, weight_scale_attr):
+                """Quantize dynamically or use prequantized weights."""
+                if self.quant_config.weight_scheme == "dynamic":
+                    return mxfp8_e4m3_quantize(weight.data.to(torch.bfloat16))
+                else:
+                    return weight, weight_scale_attr
+
+            if self.flashinfer_moe_backend is not None:
+                from vllm.model_executor.layers.quantization.fp8 import round_up, pad_to
+                from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
+                    FlashinferMoeBackend,
+                    rotate_flashinfer_fp8_moe_weights,
+                )
+
+                # This is a hack for mxfp8 only
+                assert (
+                    self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
+                )
+
+                layer.intermediate_size_per_partition = round_up(
+                    layer.intermediate_size_per_partition, 128
+                )
+
+                tmp_w13_weight, tmp_w13_weight_scale = maybe_quantize(
+                    layer.w13_weight, layer.w13_weight_scale
+                )
+                tmp_w2_weight, tmp_w2_weight_scale = maybe_quantize(
+                    layer.w2_weight, layer.w2_weight_scale
+                )
+
+                tmp_w13_weight = pad_to(
+                    tmp_w13_weight, 1, layer.intermediate_size_per_partition
+                ).contiguous()
+                tmp_w13_weight_scale = pad_to(
+                    tmp_w13_weight_scale, 1, layer.intermediate_size_per_partition
+                ).to(dtype=torch.uint8).contiguous()
+                tmp_w2_weight = pad_to(
+                    tmp_w2_weight, 2, layer.intermediate_size_per_partition
+                ).contiguous()
+                tmp_w2_weight_scale = pad_to(
+                    tmp_w2_weight_scale, 2, layer.intermediate_size_per_partition // 32
+                ).to(dtype=torch.uint8).contiguous()
+
+                gemm1_w, gemm2_w, gemm1_s, gemm2_s = (
+                    rotate_flashinfer_fp8_moe_weights(
+                        tmp_w13_weight,
+                        tmp_w2_weight,
+                        tmp_w13_weight_scale,
+                        tmp_w2_weight_scale,
+                    )
+                )
+                if not hasattr(layer, "w13_weight_shuffled"):
+                    layer.w13_weight_shuffled = torch.nn.Parameter(gemm1_w, requires_grad=False)
+                    layer.w2_weight_shuffled = torch.nn.Parameter(gemm2_w, requires_grad=False)
+                    layer.w13_scales_shuffled = torch.nn.Parameter(gemm1_s, requires_grad=False)
+                    layer.w2_scales_shuffled = torch.nn.Parameter(gemm2_s, requires_grad=False)
+                    layer.w13_weight.data = torch.narrow(layer.w13_weight_shuffled.data, 1, 0, layer.w13_weight.shape[1])
+                    layer.w2_weight.data = torch.narrow(layer.w2_weight_shuffled.data, 2, 0, layer.w2_weight.shape[2])
+                else:
+                    layer.w13_weight_shuffled.copy_(gemm1_w)
+                    layer.w2_weight_shuffled.copy_(gemm2_w)
+                    layer.w13_scales_shuffled.copy_(gemm1_s)
+                    layer.w2_scales_shuffled.copy_(gemm2_s)
+            else:
+                # -------------------------
+                # w13 processing
+                # -------------------------
+                w13_q, w13_scale = maybe_quantize(layer.w13_weight, layer.w13_weight_scale)
+
+                dq_w13 = dequant_mxfp8_to_bf16(w13_q, w13_scale).contiguous()
+                layer.w13_weight.copy_(dq_w13)
+
+                # -------------------------
+                # w2 processing
+                # -------------------------
+                w2_q, w2_scale_full = maybe_quantize(layer.w2_weight, layer.w2_weight_scale)
+
+                # Select this rank's expert block of scales
+                # (ceil-div by the MX block size of 32)
+                blk = (w2_q.shape[-1] + 31) // 32
+                start = layer.ep_rank * blk
+                end = (layer.ep_rank + 1) * blk
+                w2_scale = w2_scale_full[..., start:end]
+
+                dq_w2 = dequant_mxfp8_to_bf16(w2_q, w2_scale).contiguous()
+                layer.w2_weight.copy_(dq_w2)
+
+            return
+
         if self.flashinfer_moe_backend is not None:
             w13_weight = swap_w13_to_w31(layer.w13_weight.data)
             w13_weight_scale_inv = swap_w13_to_w31(layer.w13_weight_scale_inv.data)
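Note: the flashinfer branch pads weights to a 128-multiple but keeps `layer.w13_weight` and `layer.w2_weight` pointing into the padded buffers via `torch.narrow`, so subsequent refits write into the same storage. The aliasing behavior that relies on, shown in isolation:

```python
import torch

buf = torch.zeros(4, 8)            # stands in for a padded, shuffled weight buffer
view = torch.narrow(buf, 1, 0, 6)  # unpadded window; shares storage with buf
view.fill_(1.0)                    # writing through the view mutates the buffer
assert buf[:, :6].eq(1.0).all() and buf[:, 6:].eq(0.0).all()
```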
""" buffer = None weights = None @@ -194,23 +194,31 @@ def update_weights_via_ipc_zmq(self) -> bool: buffer = None self.zmq_socket.send(IPCProtocol.ACK.value.encode()) + torch.cuda.synchronize() # Process weights after loading for FP8 KV cache + from vllm.model_executor.model_loader.utils import ( + process_weights_after_loading, + ) + + process_weights_after_loading( + self.model_runner.model, self.model_config, self.device + ) self._maybe_process_fp8_kv_cache() gc.collect() torch.cuda.empty_cache() - return True + return True, None except Exception as e: print( f"Error in VllmInternalWorkerExtension.update_weights_via_ipc_zmq: {e}.\n" f"{traceback.format_exc()}" ) - return False + return False, e @wrap_with_nvtx_name( "vllm_internal_worker_extension/update_weights_from_collective" ) - def update_weights_from_collective(self) -> bool: + def update_weights_from_collective(self) -> Tuple[bool, Optional[Exception]]: """Update the model weights from collective communication.""" assert self.state_dict_info is not None, ( "state_dict_info is not prepared. " @@ -252,9 +260,11 @@ def _load_model_weights(weights, model_runner): print( f"Error in VllmInternalWorkerExtension.update_weights_from_collective: {e}" ) - return False + return False, e - return True + gc.collect() + torch.cuda.empty_cache() + return True, None def cleanup(self) -> None: """Shutdown and cleanup resources.""" diff --git a/nemo_rl/models/generation/vllm/vllm_worker.py b/nemo_rl/models/generation/vllm/vllm_worker.py index 9238533cd2..3f9f3a77cb 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker.py +++ b/nemo_rl/models/generation/vllm/vllm_worker.py @@ -745,9 +745,9 @@ def update_weights_via_ipc_zmq(self) -> bool: ) worker_result = result_or_coro[0] - if not worker_result: + if not worker_result[0]: print( - f"Error: Worker failed to update weights. Result: {worker_result}" + f"Error: Worker failed to update weights. Result: {worker_result[1]}" ) return False return True @@ -776,9 +776,9 @@ def update_weights_from_collective(self) -> bool: ) worker_result = result_or_coro[0] - if not worker_result: + if not worker_result[0]: print( - f"Error: Worker failed to update weights. Result: {worker_result}" + f"Error: Worker failed to update weights. Result: {worker_result[1]}" ) return False return True diff --git a/nemo_rl/models/generation/vllm/vllm_worker_async.py b/nemo_rl/models/generation/vllm/vllm_worker_async.py index 0e4ea5cdeb..a83144bffb 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker_async.py +++ b/nemo_rl/models/generation/vllm/vllm_worker_async.py @@ -1012,9 +1012,9 @@ async def update_weights_via_ipc_zmq_async( worker_result = worker_results[0] - if not worker_result: + if not worker_result[0]: print( - f"Error: Worker failed to update weights. Result: {worker_result}" + f"Error: Worker failed to update weights. Result: {worker_result[1]}" ) return False return True @@ -1048,9 +1048,9 @@ async def update_weights_from_collective_async(self) -> bool: worker_result = worker_results[0] - if not worker_result: + if not worker_result[0]: print( - f"Error: Worker failed to update weights. Result: {worker_result}" + f"Error: Worker failed to update weights. 
Result: {worker_result[1]}" ) return False return True diff --git a/tests/test_suites/llm/grpo-nano-v3-2n8g-mxfp8-e2e.sh b/tests/test_suites/llm/grpo-nano-v3-2n8g-mxfp8-e2e.sh new file mode 100644 index 0000000000..1718cfaf23 --- /dev/null +++ b/tests/test_suites/llm/grpo-nano-v3-2n8g-mxfp8-e2e.sh @@ -0,0 +1,43 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=2 +STEPS_PER_RUN=30 +MAX_STEPS=30 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=60 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_grpo.py \ + --config $CONFIG_PATH \ + grpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + uv run tests/check_metrics.py $JSON_METRICS \ + 'median(data["train/token_mult_prob_error"]) < 1.05' \ + 'data["train/token_mult_prob_error"]["30"] < 1.05' \ + 'data["train/reward"]["30"] > 0.4' + + # Clean up checkpoint directory after successful run to save space. + rm -rf "$CKPT_DIR" +fi