diff --git a/.gitmodules b/.gitmodules index 796f7b17c3..81d066b8b0 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,7 +1,7 @@ [submodule "3rdparty/Megatron-LM"] path = 3rdparty/Megatron-LM-workspace/Megatron-LM - url = https://github.com/terrykong/Megatron-LM.git - branch = yuya/nemo-rl-use-dev + url = https://github.com/yaoyu-33/Megatron-LM.git + branch = main shallow = true [submodule "3rdparty/Megatron-Bridge"] path = 3rdparty/Megatron-Bridge-workspace/Megatron-Bridge diff --git a/3rdparty/Megatron-Bridge-workspace/Megatron-Bridge b/3rdparty/Megatron-Bridge-workspace/Megatron-Bridge index 1e9a459b43..a3fc5d57e2 160000 --- a/3rdparty/Megatron-Bridge-workspace/Megatron-Bridge +++ b/3rdparty/Megatron-Bridge-workspace/Megatron-Bridge @@ -1 +1 @@ -Subproject commit 1e9a459b43aa1f62ca1356e554d2b0196ebdd546 +Subproject commit a3fc5d57e26fdb0d998b768bf870425a8581c2fb diff --git a/3rdparty/Megatron-Bridge-workspace/setup.py b/3rdparty/Megatron-Bridge-workspace/setup.py index 9aec2e6481..9ff7cb51e1 100644 --- a/3rdparty/Megatron-Bridge-workspace/setup.py +++ b/3rdparty/Megatron-Bridge-workspace/setup.py @@ -26,7 +26,8 @@ bridge_package_name = "megatron.bridge" CACHED_DEPENDENCIES = [ - "transformers>=4.57.1", + "accelerate", + "transformers==4.57.1", "datasets", "omegaconf>=2.3.0", "tensorboard>=2.19.0", @@ -40,7 +41,7 @@ "hydra-core>1.3,<=1.3.2", "megatron-core[dev,mlm]>=0.15.0a0,<0.17.0", "qwen-vl-utils", - "transformer-engine[pytorch]>=2.9.0a0,<2.10.0", + "transformer-engine[pytorch]>=2.10.0a0,<2.12.0", "mamba-ssm", "nvidia-resiliency-ext", "causal-conv1d", diff --git a/3rdparty/Megatron-LM-workspace/Megatron-LM b/3rdparty/Megatron-LM-workspace/Megatron-LM index b73ae5cdab..11dcbaca31 160000 --- a/3rdparty/Megatron-LM-workspace/Megatron-LM +++ b/3rdparty/Megatron-LM-workspace/Megatron-LM @@ -1 +1 @@ -Subproject commit b73ae5cdab9d409fcface2b2f3c375710abe6911 +Subproject commit 11dcbaca317133cc5c77c8bc4f54ed71d3b5d656 diff --git a/3rdparty/Megatron-LM-workspace/setup.py b/3rdparty/Megatron-LM-workspace/setup.py index 0a088b393e..2874fe10b7 100644 --- a/3rdparty/Megatron-LM-workspace/setup.py +++ b/3rdparty/Megatron-LM-workspace/setup.py @@ -44,30 +44,30 @@ CACHED_DEPENDENCIES = [ # Default dependencies from pyproject.toml "torch", - "numpy<2.0.0", + "numpy", "packaging>=24.2", # Dev dependencies from pyproject.toml - "nvidia-modelopt[torch]>=0.33.0a0,<0.34.0; sys_platform != 'darwin'", - "transformer-engine[pytorch]>=2.9.0a0,<2.10.0", - "nvidia-resiliency-ext>=0.4.0a0,<0.5.0", + "nvidia-modelopt[torch]; sys_platform != 'darwin'", + "transformer-engine[pytorch,core_cu13]>=2.9.0a0,<2.12.0", + "nvidia-resiliency-ext", "tqdm", "einops~=0.8", "tensorstore~=0.1,!=0.1.46,!=0.1.72", "nvtx~=0.2", "multi-storage-client~=0.27", "opentelemetry-api~=1.33.1", - "setuptools<80.0.0", "mamba-ssm~=2.2", "causal-conv1d~=1.5", "nv-grouped-gemm~=1.1", "megatron-energon[av_decode]~=6.0", - "av<16.0.0", - "flashinfer-python", + "av", + "flashinfer-python~=0.5.0", "wget", "onnxscript", - "flash-linear-attention~=0.3.2", # VCS dependency - must match pyproject.toml [tool.uv.sources] "emerging_optimizers @ git+https://github.com/NVIDIA-NeMo/Emerging-Optimizers.git@v0.1.0", + "datasets", + "fastapi~=0.50", ] diff --git a/examples/configs/grpo_math_1B_megatron.yaml b/examples/configs/grpo_math_1B_megatron.yaml index 1a14b8ce64..a8ab972f07 100644 --- a/examples/configs/grpo_math_1B_megatron.yaml +++ b/examples/configs/grpo_math_1B_megatron.yaml @@ -143,14 +143,17 @@ policy: top_p: 1.0 top_k: null mcore_generation_config: 
- buffer_size_gb: 20 # Total GPU memory (in GB) allocated for KV cache buffers + buffer_size_gb: 25 # Total GPU memory (in GB) allocated for KV cache buffers buffer_guaranteed_fraction: 0.1 # Fraction of buffer reserved for guaranteed active requests - num_cuda_graphs: 16 # Number of CUDA graphs to pre-compile for different batch sizes + num_cuda_graphs: 6 # Number of CUDA graphs to pre-compile for different batch sizes block_size_tokens: 256 # Size of each KV cache block in tokens (affects memory granularity) use_cuda_graphs_for_non_decode_steps: true # Enable CUDA graphs for prefill/context processing - enable_chunked_prefill: true # Split long prefills into chunks for better memory management - unified_memory_level: 0 # Unified memory usage level (0=disabled, higher values enable more aggressive paging) + unified_memory_level: 1 # Unified memory usage level (0=disabled, 1+=enables unified memory with static tensor addresses) max_tokens: 16384 # Maximum number of tokens to use in a single step. Analogous to vllm's max_num_batched_tokens + reset_cuda_graphs: true + offload_kv_cache_during_training: true # Move KV cache to CPU during training + enable_cuda_graph: true + enable_chunked_prefill: false vllm_cfg: tensor_parallel_size: 1 @@ -178,8 +181,8 @@ logger: swanlab_enabled: false # Disable SwanLab logging monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard wandb: - project: "grpo-dev" - name: "sj_megatron_1B" + project: "qwen_30b_final" + name: "none" swanlab: project: "grpo-dev" name: "sj_megatron_1B" diff --git a/examples/configs/grpo_math_qwen30ba3b_megatron.yaml b/examples/configs/grpo_math_qwen30ba3b_megatron.yaml index 37616e32b0..8856892772 100644 --- a/examples/configs/grpo_math_qwen30ba3b_megatron.yaml +++ b/examples/configs/grpo_math_qwen30ba3b_megatron.yaml @@ -2,18 +2,18 @@ defaults: "grpo_math_1B_megatron.yaml" grpo: - num_prompts_per_step: 64 - num_generations_per_prompt: 32 + num_prompts_per_step: 16 + num_generations_per_prompt: 8 policy: model_name: "Qwen/Qwen3-30B-A3B" tokenizer: name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default - train_global_batch_size: 512 + train_global_batch_size: 64 train_micro_batch_size: 1 generation_batch_size: 32 # Only used when generating using HF backend logprob_batch_size: 4 - max_total_sequence_length: 4096 + max_total_sequence_length: 1024 precision: "bfloat16" dtensor_cfg: @@ -68,7 +68,7 @@ policy: stop_token_ids: null stop_strings: null vllm_cfg: - tensor_parallel_size: 4 + tensor_parallel_size: 16 gpu_memory_utilization: 0.7 enforce_eager: false max_model_len: ${policy.max_total_sequence_length} diff --git a/mcore_dp_inference_cooredinator.md b/mcore_dp_inference_cooredinator.md new file mode 100644 index 0000000000..614bdc98f5 --- /dev/null +++ b/mcore_dp_inference_cooredinator.md @@ -0,0 +1,740 @@ +## MCORE NEMO RL INTEGRATION + +### RUNNING THE APPLICATION + +**QWEN 1.5B** + +wandb: https://wandb.ai/shanmugamr/mcore_vllm_latest?nw=nwusershanmugamr + +``` +srun --gpus-per-node 8 --time 04:00:00 --account coreai_dlalgo_llm --job-name coreai_dlalgo_llm:inference --partition interactive --container-image nvcr.io/nvidian/nemo-rl:nightly --container-mounts /lustre/fsw/portfolios/coreai/users/shanmugamr/RL:/opt/nemo-rl,/lustre/fsw/portfolios/coreai/users/shanmugamr/:/lustre/fsw/portfolios/coreai/users/shanmugamr/ --account coreai_dlalgo_llm -N 1 -J coreai_dlalgo_llm-multimodal:debug --gpus-per-node=8 --no-container-mount-home -p interactive 
--pty bash
+
+export HF_HOME=/lustre/fsw/portfolios/coreai/users/shanmugamr/rl_data_new/
+export TORCH_CUDA_ARCH_LIST='9.0 10.0'
+export HF_TOKEN=
+wandb login
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+
+NRL_FORCE_REBUILD_VENVS=true uv run examples/run_grpo_math.py --config examples/configs/grpo_math_1B_megatron.yaml policy.generation.backend=megatron grpo.max_num_steps=50 logger.wandb_enabled=True logger.wandb.name=mcore_tp2_dp2 logger.wandb.project=mcore_vllm_latest cluster.gpus_per_node=4 policy.megatron_cfg.tensor_model_parallel_size=2
+```
+
+**QWEN 30B**
+
+wandb: https://wandb.ai/shanmugamr/qwen_30b_final/workspace?nw=nwusershanmugamr
+
+```
+cd /lustre/fsw/portfolios/coreai/users/shanmugamr/RL
+
+NUM_ACTOR_NODES=2 \
+CONTAINER=/lustre/fsw/portfolios/coreai/users/shanmugamr/RL/rl_nightly.sqsh \
+MOUNTS="/lustre/fsw/portfolios/coreai/users/shanmugamr/RL:/opt/nemo-rl,/lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_llm:/lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_llm,/lustre/fsw/portfolios/coreai/users/shanmugamr:/lustre/fsw/portfolios/coreai/users/shanmugamr" \
+sbatch \
+    --nodes=2 \
+    --account=coreai_dlalgo_llm \
+    --job-name=coreai_dlalgo_llm:inference \
+    --partition=interactive \
+    --time=4:0:0 \
+    --gres=gpu:8 \
+    ray.sub
+
+cd /opt/nemo-rl
+
+export HF_HOME=/lustre/fsw/portfolios/coreai/users/shanmugamr/rl_data_new/
+export TORCH_CUDA_ARCH_LIST='9.0 10.0'
+export HF_TOKEN=
+wandb login
+export CUDA_DEVICE_MAX_CONNECTIONS=11
+
+
+uv run examples/run_grpo_math.py --config examples/configs/grpo_math_qwen30ba3b_megatron.yaml policy.generation.backend=megatron grpo.max_num_steps=10 cluster.num_nodes=2 logger.wandb.name=mcore_uvm1 logger.wandb_enabled=True
+```
+
+
+### WORKING CONFIGURATION FOR QWEN 30B
+```
+uvm: 1
+reset_cuda_graphs: true
+offload_kv_cache: true
+
+grpo:
+  num_prompts_per_step: 16
+  num_generations_per_prompt: 8
+policy:
+  train_global_batch_size: 64
+  max_total_sequence_length: 1024
+```
+
+### FLAKY STUFF
+* Not sure why `uvm: 1` alone doesn't work (in principle you shouldn't also need to offload the KV cache). The current setup is also very inefficient: it redoes CUDA graph warmup every iteration (`reset_cuda_graphs` deletes the CUDA graphs, which may be what frees the memory, but it also recaptures them each time), so this is flaky.
+* If I set `cuda_graph_scope` to `full_iteration`, CUDA graphs are used properly (verified with an assert I added in the `transformer_block.py __call__` function) and it is significantly faster (1 min). If I leave it as `None`, the graphs are instead invoked in the `transformer_layer.py __call__` function (1 min 46 seconds).
+* Sometimes I get this engine-suspended error:
+```
+========================= Step 9/10 =========================
+▶ Preparing batch...
+▶ Generating responses for batch of size 128...
+(MegatronPolicyWorker[rank=8] pid=4154549) GPU Memory before optimizer offload: 18.61GB allocated, 63.65GB reserved +(MegatronPolicyWorker[rank=14] pid=4155351) [Rank 14] Suspended inference engine [repeated 15x across cluster] +(MegatronPolicyWorker[rank=0] pid=1432916, ip=10.65.30.3) GPU Memory after optimizer offload: 0.84GB allocated, 28.07GB reserved +(MegatronPolicyWorker[rank=0] pid=1432916, ip=10.65.30.3) GPU Memory after refit complete: 0.84GB allocated, 28.07GB reserved +(MegatronPolicyWorker[rank=0] pid=1432916, ip=10.65.30.3) [INFO] Restoring KV cache (25.00 GB) to GPU +(MegatronPolicyWorker[rank=9] pid=4154537) [Rank 9] Resumed inference engine +(MegatronPolicyWorker[rank=9] pid=4154537) GPU : 0, participating in engine loop (no data to submit) +(MegatronPolicyWorker[rank=9] pid=4154537) [Rank 9] Participating in engine loop only (not submitting requests) +(MegatronPolicyWorker[rank=3] pid=1432934, ip=10.65.30.3) GPU Memory before optimizer offload: 18.61GB allocated, 63.65GB reserved [repeated 15x across cluster] +(MegatronPolicyWorker[rank=0] pid=1432916, ip=10.65.30.3) GPU : 0, input_ids: torch.Size([128, 158]) +(MegatronPolicyWorker[rank=0] pid=1432916, ip=10.65.30.3) [Rank 0] Submitting 128 requests to coordinator +(MegatronPolicyWorker[rank=9] pid=4154537) ERROR:megatron.core.utils:utils.py:2296: Exception in async function run_engine_with_coordinator: 7782 +(MegatronPolicyWorker[rank=9] pid=4154537) Traceback (most recent call last): +(MegatronPolicyWorker[rank=9] pid=4154537) File "/lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_llm/users/shanmugamr/RL/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/utils.py", line 2294, in wrapper +(MegatronPolicyWorker[rank=9] pid=4154537) return await fn(*args, **kwargs) +(MegatronPolicyWorker[rank=9] pid=4154537) ^^^^^^^^^^^^^^^^^^^^^^^^^ +(MegatronPolicyWorker[rank=9] pid=4154537) File "/lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_llm/users/shanmugamr/RL/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/inference/engines/dynamic_engine.py", line 1660, in run_engine_with_coordinator +(MegatronPolicyWorker[rank=9] pid=4154537) await self.async_step() +(MegatronPolicyWorker[rank=9] pid=4154537) File "/lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_llm/users/shanmugamr/RL/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/inference/engines/dynamic_engine.py", line 1356, in async_step +(MegatronPolicyWorker[rank=9] pid=4154537) last_step_data = await self.async_forward() +(MegatronPolicyWorker[rank=9] pid=4154537) ^^^^^^^^^^^^^^^^^^^^^^^^^^ +(MegatronPolicyWorker[rank=9] pid=4154537) File "/lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_llm/users/shanmugamr/RL/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/inference/engines/dynamic_engine.py", line 1122, in async_forward +(MegatronPolicyWorker[rank=9] pid=4154537) raise EngineSuspendedError(self.step_count) +(MegatronPolicyWorker[rank=9] pid=4154537) megatron.core.inference.engines.dynamic_engine.EngineSuspendedError: 7782 +``` +* Sometimes I see timeout issues (Maybe due to improper synchronization) + +### CHANGES TO BE DONE IN OTHER LIBRARIES +1. Megatron bridge - remove import statement +``` +--- a/src/megatron/bridge/training/setup.py ++++ b/src/megatron/bridge/training/setup.py +@@ -22,7 +22,7 @@ import torch +-from megatron.core.jit import disable_jit_fuser +``` + +2. 
Unified memory: skip the signal-based compile timeout outside the main thread (we run the coordinator in a thread, where signals don't work)
+```
+FILE: /megatron/core/inference/unified_memory.py
+index bc0a9c7..85c3517 100644
+--- a/megatron/core/inference/unified_memory.py
++++ b/megatron/core/inference/unified_memory.py
+@@ -69,6 +69,12 @@ def _compile_timeout(timeout_s: int):
+                "Please clean up your stale cache and try again."
+            )
+
++    # Signal-based timeout only works in the main thread.
++    # In non-main threads (e.g., Ray actors), skip the timeout mechanism.
++    if threading.current_thread() is not threading.main_thread():
++        yield
++        return
+
+```
+
+3. Hugging Face tokenizer: handle out-of-vocab token ids (needed because we pad the vocab size)
+```
+FILE: megatron/core/tokenizers/text/libraries/huggingface_tokenizer.py
+        """Converts list of ids to text."""
+        tokens = self.ids_to_tokens(ids)
+        if remove_special_tokens:
+-            tokens_clean = [t for t in tokens if t not in self.tokenizer.all_special_tokens]
++            tokens_clean = [t for t in tokens if t is not None and t not in self.tokenizer.all_special_tokens]
+        else:
+-            tokens_clean = tokens
++            tokens_clean = [t for t in tokens if t is not None]
+        text = self.tokens_to_text(tokens_clean)
+        return text
+```
+
+4. Distributed data parallel fix (see the DDP FIX EXPLANATION section below)
+```
+@@ -369,12 +369,15 @@ class DistributedDataParallel(_BaseDataParallel):
+        Skip synchronous param all-gather if `param_sync` is False.
+        """
+        assert self.use_forward_hook
++        for module, handle in list(self.remove_forward_pre_hook_handles.items()):
++            handle.remove()
++        self.remove_forward_pre_hook_handles.clear()
+
+-        for module in self.module.modules():
+-            assert self.remove_forward_pre_hook_handles[module] is not None
+-            self.remove_forward_pre_hook_handles[module].remove()
+-            del self.remove_forward_pre_hook_handles[module]
+-        assert len(self.remove_forward_pre_hook_handles) == 0
+
+        # Force synchronize parameters.
+        if param_sync:
+```
+5. Sometimes I see a timeout issue (I think proper synchronization is still needed).
+
+I hit this issue at step 6:
+```
+========================= Step 6/50 =========================
+▶ Preparing batch...
+▶ Generating responses for batch of size 128...
+(MegatronPolicyWorker[rank=11] pid=3098847) GPU Memory before optimizer offload: 18.42GB allocated, 67.79GB reserved +(MegatronPolicyWorker[rank=15] pid=3098871) [Rank 15] paused inference engine [repeated 15x across cluster] +(MegatronPolicyWorker[rank=1] pid=1892466, ip=10.65.5.217) GPU Memory after optimizer offload: 0.65GB allocated, 32.23GB reserved +(MegatronPolicyWorker[rank=1] pid=1892466, ip=10.65.5.217) GPU Memory after refit complete: 0.65GB allocated, 32.23GB reserved +(MegatronPolicyWorker[rank=0] pid=1892406, ip=10.65.5.217) [INFO] Restoring KV cache (30.00 GB) to GPU +(MegatronPolicyWorker[rank=0] pid=1892406, ip=10.65.5.217) [Rank 0] Resumed inference engine +(MegatronPolicyWorker[rank=0] pid=1892406, ip=10.65.5.217) GPU : 0, input_ids: torch.Size([128, 495]) +(MegatronPolicyWorker[rank=0] pid=1892406, ip=10.65.5.217) [Rank 0] Submitting 128 requests to coordinator +(MegatronPolicyWorker[rank=3] pid=1892419, ip=10.65.5.217) GPU Memory before optimizer offload: 18.43GB allocated, 67.83GB reserved [repeated 15x across cluster] +(MegatronPolicyWorker[rank=10] pid=3098492) GPU Memory after optimizer offload: 0.65GB allocated, 32.22GB reserved [repeated 15x across cluster] +(MegatronPolicyWorker[rank=10] pid=3098492) GPU Memory after refit complete: 0.65GB allocated, 32.22GB reserved [repeated 15x across cluster] +(MegatronPolicyWorker[rank=3] pid=1892419, ip=10.65.5.217) GPU : 0, participating in engine loop (no data to submit) +(MegatronPolicyWorker[rank=3] pid=1892419, ip=10.65.5.217) [Rank 3] Participating in engine loop only (not submitting requests) +(MegatronPolicyWorker[rank=9] pid=3098846) [rank9]:[E129 16:38:47.714693059 ProcessGroupNCCL.cpp:683] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=49387, OpType=_ALLGATHER_BASE, NumelIn=442368, NumelOut=884736, Timeout(ms)=600000) ran for 600002 milliseconds before timing out. +(MegatronPolicyWorker[rank=9] pid=3098846) [rank9]:[E129 16:38:47.714849575 ProcessGroupNCCL.cpp:2241] [PG ID 5 PG GUID 37(TENSOR_MODEL_PARALLEL_GROUP) Rank 1] failure detected by watchdog at work sequence id: 49387 PG status: last enqueued work: 49406, last completed work: 49386 +(MegatronPolicyWorker[rank=9] pid=3098846) [rank9]:[E129 16:38:47.714860440 ProcessGroupNCCL.cpp:730] Stack trace of the failed collective not found, potentially because FlightRecorder is disabled. You can enable it by setting TORCH_NCCL_TRACE_BUFFER_SIZE to a non-zero value. +(MegatronPolicyWorker[rank=9] pid=3098846) [rank9]:[E129 16:38:47.714884178 ProcessGroupNCCL.cpp:2573] [PG ID 5 PG GUID 37(TENSOR_MODEL_PARALLEL_GROUP) Rank 1] First PG on this rank to signal dumping. +(MegatronPolicyWorker[rank=9] pid=3098846) NCCL version 2.27.5+cuda12.9 +(MegatronPolicyWorker[rank=15] pid=3098871) [Rank 15] Resumed inference engine [repeated 15x across cluster] +(MegatronPolicyWorker[rank=15] pid=3098871) GPU : 0, participating in engine loop (no data to submit) [repeated 14x across cluster] +(MegatronPolicyWorker[rank=15] pid=3098871) [Rank 15] Participating in engine loop only (not submitting requests) [repeated 14x across cluster] +(MegatronPolicyWorker[rank=9] pid=3098846) [rank9]:[E129 16:38:47.769817695 ProcessGroupNCCL.cpp:683] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=52229, OpType=ALLTOALL_BASE, NumelIn=56623104, NumelOut=56623104, Timeout(ms)=600000) ran for 600059 milliseconds before timing out. 
+(MegatronPolicyWorker[rank=9] pid=3098846) [rank9]:[E129 16:38:47.769886452 ProcessGroupNCCL.cpp:2241] [PG ID 12 PG GUID 100(EXPERT_MODEL_PARALLEL_GROUP) Rank 1] failure detected by watchdog at work sequence id: 52229 PG status: last enqueued work: 52258, last completed work: 52228 +(MegatronPolicyWorker[rank=9] pid=3098846) [rank9]:[E129 16:38:47.769926064 ProcessGroupNCCL.cpp:730] Stack trace of the failed collective not found, potentially because FlightRecorder is disabled. You can enable it by setting TORCH_NCCL_TRACE_BUFFER_SIZE to a non-zero value. +(MegatronPolicyWorker[rank=8] pid=3098870) [rank8]:[E129 16:38:47.801786726 ProcessGroupNCCL.cpp:1858] [PG ID 0 PG GUID 0(default_pg) Rank 8] Received a dump signal due to a collective timeout from this local rank and we will try our best to dump the debug info. Last enqueued NCCL work: 62, last completed NCCL work: 62.This is most likely caused by incorrect usages of collectives, e.g., wrong sizes used across ranks, the order of collectives is not same for all ranks or the scheduled collective, for some reason, didn't run. Additionally, this can be caused by GIL deadlock or other reasons such as network errors or bugs in the communications library (e.g. NCCL), etc. +(MegatronPolicyWorker[rank=8] pid=3098870) [rank8]:[E129 16:38:47.802073551 ProcessGroupNCCL.cpp:1575] [PG ID 0 PG GUID 0(default_pg) Rank 8] ProcessGroupNCCL preparing to dump debug info. Include stack trace: 1 +(MegatronPolicyWorker[rank=0] pid=1892406, ip=10.65.5.217) [rank0]:[E129 16:38:47.634597716 ProcessGroupNCCL.cpp:1794] [PG ID 0 PG GUID 0(default_pg) Rank 0] Observed flight recorder dump signal from another rank via TCPStore. +(MegatronPolicyWorker[rank=0] pid=1892406, ip=10.65.5.217) [rank0]:[E129 16:38:47.634726597 ProcessGroupNCCL.cpp:1858] [PG ID 0 PG GUID 0(default_pg) Rank 0] Received a dump signal due to a collective timeout from rank 8 and we will try our best to dump the debug info. Last enqueued NCCL work: 62, last completed NCCL work: 62.This is most likely caused by incorrect usages of collectives, e.g., wrong sizes used across ranks, the order of collectives is not same for all ranks or the scheduled collective, for some reason, didn't run. Additionally, this can be caused by GIL deadlock or other reasons such as network errors or bugs in the communications library (e.g. NCCL), etc. +``` + +### DDP FIX EXPLANATION +## Why `disable_forward_pre_hook` is Called + +### The Context: `use_reference_model` + +The `use_reference_model` context manager (lines 1437-1486) temporarily **swaps the model weights** with the reference model weights: + +1. **On entry**: Copies the current model's state_dict to CPU, then loads the reference model's state_dict into the model +2. **On exit**: Restores the original model weights + +This allows running inference with the reference model's weights without having two full models in GPU memory. + +### What is the Forward Pre-Hook? + +Looking at the DDP code you attached, when **overlap_param_gather** is enabled with distributed optimizer: + +```376:386:/lustre/fsw/portfolios/coreai/users/shanmugamr/RL/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/distributed/distributed_data_parallel.py + def enable_forward_pre_hook(self): + """ + Enable forward pre-hooks needed for param all-gather overlap with forward compute. + """ + assert self.use_forward_hook + assert len(self.remove_forward_pre_hook_handles) == 0 + # Register forward pre-hook for all sub-modules. 
+ for module in self.module.modules(): + self.remove_forward_pre_hook_handles[module] = module.register_forward_pre_hook( + self._make_forward_pre_hook() + ) +``` + +The forward pre-hook is used to **overlap parameter all-gather with forward compute**. Here's how it works: + +1. With **distributed optimizer**, model parameters are **sharded across data-parallel ranks** (each rank only holds a portion of the parameters) +2. Before forward pass, parameters need to be **all-gathered** to reconstruct full parameters +3. The forward pre-hook intercepts each module's forward call to **wait for the all-gather to complete** for that module's parameters before executing + +```411:437:/lustre/fsw/portfolios/coreai/users/shanmugamr/RL/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/distributed/distributed_data_parallel.py + def hook(module, *unused): + // ... + # Make sure all parameters in this module have been all-gathered as necessary. + for param in module.parameters(recurse=False): + # Skip parameters without an associated buffer + if param not in self.param_to_bucket_group: + continue + // ... + self.param_to_bucket_group[param].finish_param_sync( + skip_next_bucket_dispatch=skip_next_bucket_dispatch + ) +``` + +### Why Disable It During Weight Swap? + +When swapping weights in `use_reference_model`: + +```1459:1459:/lustre/fsw/portfolios/coreai/users/shanmugamr/RL/nemo_rl/models/policy/workers/megatron_policy_worker.py + self.model.load_state_dict(self.reference_state_dict, strict=True) +``` + +**The forward pre-hook would interfere because:** + +1. The hook maintains state about **which parameters have been all-gathered** via `param_to_bucket_group` +2. When you swap weights using `load_state_dict`, the parameter tensors are replaced, but the **bucket mappings become stale/invalid** +3. If the hook tries to call `finish_param_sync` on parameters that weren't properly set up for the new weights, it could cause errors or incorrect behavior +4. Additionally, `disable_forward_pre_hook` calls `start_param_sync(force_sync=True)` which **synchronously gathers all parameters** first, ensuring a clean state before the weight swap + +After the weight swap is complete (in `finally` block), the hooks are re-enabled: + +```1484:1486:/lustre/fsw/portfolios/coreai/users/shanmugamr/RL/nemo_rl/models/policy/workers/megatron_policy_worker.py + ## re-enable overlap param gather after weight swap + if self.should_disable_forward_pre_hook: + self.enable_forward_pre_hook() +``` + +### Summary + +| Step | What Happens | +|------|-------------| +| 1. `disable_forward_pre_hook()` | Removes all forward hooks, forces a synchronous param all-gather to ensure all params are complete | +| 2. `load_state_dict()` | Safely swap weights without hooks interfering | +| 3. Forward passes | Run without overlap optimization (simpler, safer) | +| 4. `enable_forward_pre_hook()` | Re-register hooks after restoring original weights | + +This is essentially a safety measure: disable the async overlap optimization during the weight-swapping operation to avoid race conditions or stale state in the DDP infrastructure. + +```python + def enable_forward_pre_hook(self): + """ + Enable forward pre-hooks needed for param all-gather overlap with forward compute. + """ + assert self.use_forward_hook + assert len(self.remove_forward_pre_hook_handles) == 0 + # Register forward pre-hook for all sub-modules. 
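+        # modules() walks the whole module tree at registration time, so every submodule
+        # present at this point (including any CudaGraphManager) gets a handle stored in
+        # remove_forward_pre_hook_handles.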
+ for module in self.module.modules(): + self.remove_forward_pre_hook_handles[module] = module.register_forward_pre_hook( + self._make_forward_pre_hook() + ) +``` + +```python + def hook(module, *unused): + // ... + # Make sure all parameters in this module have been all-gathered as necessary. + for param in module.parameters(recurse=False): + # Skip parameters without an associated buffer + if param not in self.param_to_bucket_group: + continue + // ... + self.param_to_bucket_group[param].finish_param_sync( + skip_next_bucket_dispatch=skip_next_bucket_dispatch + ) +``` + +```python + self.model.load_state_dict(self.reference_state_dict, strict=True) +``` + +```python + ## re-enable overlap param gather after weight swap + if self.should_disable_forward_pre_hook: + self.enable_forward_pre_hook() +``` + +--- + +**User** + +When I run the application I get an error like this +``` + File "/opt/nemo-rl/examples/run_grpo_math.py", line 260, in + main() + File "/opt/nemo-rl/examples/run_grpo_math.py", line 243, in main + grpo_train( + File "/opt/nemo-rl/nemo_rl/algorithms/grpo.py", line 1339, in grpo_train + reference_logprobs = policy.get_reference_policy_logprobs( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/opt/nemo-rl/nemo_rl/models/policy/lm_policy.py", line 390, in get_reference_policy_logprobs + self.worker_group.get_all_worker_results(futures) + File "/opt/nemo-rl/nemo_rl/distributed/worker_groups.py", line 953, in get_all_worker_results + return future_bundle.get_results( + ^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/opt/nemo-rl/nemo_rl/distributed/worker_groups.py", line 103, in get_results + all_results = ray.get(object_refs) + ^^^^^^^^^^^^^^^^^^^^ + File "/opt/nemo_rl_venv/lib/python3.12/site-packages/ray/_private/auto_init_hook.py", line 22, in auto_init_wrapper + return fn(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^ + File "/opt/nemo_rl_venv/lib/python3.12/site-packages/ray/_private/client_mode_hook.py", line 104, in wrapper + return func(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^ + File "/opt/nemo_rl_venv/lib/python3.12/site-packages/ray/_private/worker.py", line 2882, in get + values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/opt/nemo_rl_venv/lib/python3.12/site-packages/ray/_private/worker.py", line 968, in get_objects + raise value.as_instanceof_cause() +ray.exceptions.RayTaskError(AssertionError): ray::MegatronPolicyWorker.get_reference_policy_logprobs() (pid=710294, ip=10.65.29.29, actor_id=65a88a2a253f3b406b563ea701000000, repr=MegatronPolicyWorker[rank=0]) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/opt/nemo-rl/nemo_rl/utils/nsys.py", line 88, in wrapper + ret = func(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^ + File "/opt/nemo-rl/nemo_rl/models/policy/workers/base_policy_worker.py", line 143, in get_reference_policy_logprobs + with self.use_reference_model(): + ^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/root/.local/share/uv/python/cpython-3.12.12-linux-x86_64-gnu/lib/python3.12/contextlib.py", line 137, in __enter__ + return next(self.gen) + ^^^^^^^^^^^^^^ + File "/opt/nemo-rl/nemo_rl/models/policy/workers/megatron_policy_worker.py", line 1446, in use_reference_model + self.disable_forward_pre_hook() + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/opt/nemo-rl/nemo_rl/models/policy/workers/megatron_policy_worker.py", line 944, in disable_forward_pre_hook + self.model.disable_forward_pre_hook(param_sync=param_sync) + File 
"/opt/nemo-rl/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/distributed/distributed_data_parallel.py", line 399, in disable_forward_pre_hook + assert len(self.remove_forward_pre_hook_handles) == 0, "remove_forward_pre_hook_handles is not empty, it is: " + str(self.remove_forward_pre_hook_handles) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +AssertionError: remove_forward_pre_hook_handles is not empty, it is: {CudaGraphManager(): , CudaGraphManager(): , CudaGraphManager(): , CudaGraphManager(): , CudaGraphManager(): , CudaGraphManager(): , CudaGraphManager(): , CudaGraphManager(): , CudaGraphManager(): , CudaGraphManager(): , CudaGraphManager(): , CudaGraphManager(): , CudaGraphManager(): , CudaGraphManager(): , CudaGraphManager(): , CudaGraphManager(): , CudaGraphManager(): , CudaGraphManager(): , CudaGraphManager(): , CudaGraphManager(): , CudaGraphManager(): , CudaGraphManager(): , CudaGraphManager(): , CudaGraphManager(): , CudaGraphManager(): , CudaGraphManager(): , CudaGraphManager(): , CudaGraphManager(): , CudaGraphManager(): } +``` +What does this error mean ? + +--- + +**Cursor** + +```plaintext +AssertionError: remove_forward_pre_hook_handles is not empty, it is: {CudaGraphManager(): ...} +``` + +```python +assert len(self.remove_forward_pre_hook_handles) == 0, "remove_forward_pre_hook_handles is not empty, it is: " + str(self.remove_forward_pre_hook_handles) +``` + +```python +def disable_forward_pre_hook(self, param_sync: bool = True): + """ + Disable forward pre-hooks needed for param all-gather overlap with forward compute. + Skip synchronous param all-gather if `param_sync` is False. + """ + assert self.use_forward_hook + # De-register forward pre-hook for all sub-modules. + for module in self.module.modules(): + assert self.remove_forward_pre_hook_handles[module] is not None + self.remove_forward_pre_hook_handles[module].remove() + del self.remove_forward_pre_hook_handles[module] + assert len(self.remove_forward_pre_hook_handles) == 0, "remove_forward_pre_hook_handles is not empty, it is: " + str(self.remove_forward_pre_hook_handles) + ... +``` + +This error is revealing a **mismatch between hook registration and deregistration** due to `CudaGraphManager` modules. Let me explain: + +## What the Error Means + +The assertion failure happens in `disable_forward_pre_hook`: + +```388:399:/lustre/fsw/portfolios/coreai/users/shanmugamr/RL/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/distributed/distributed_data_parallel.py + def disable_forward_pre_hook(self, param_sync: bool = True): + """ + Disable forward pre-hooks needed for param all-gather overlap with forward compute. + Skip synchronous param all-gather if `param_sync` is False. + """ + assert self.use_forward_hook + # De-register forward pre-hook for all sub-modules. + for module in self.module.modules(): + assert self.remove_forward_pre_hook_handles[module] is not None + self.remove_forward_pre_hook_handles[module].remove() + del self.remove_forward_pre_hook_handles[module] + assert len(self.remove_forward_pre_hook_handles) == 0, "remove_forward_pre_hook_handles is not empty, it is: " + str(self.remove_forward_pre_hook_handles) +``` + +### The Problem + +1. **When `enable_forward_pre_hook` was called**: It iterated through `self.module.modules()` and registered hooks for **ALL** modules, including `CudaGraphManager` objects + +2. **When `disable_forward_pre_hook` is called**: It again iterates through `self.module.modules()` to remove hooks + +3. 
**The mismatch**: The `CudaGraphManager` modules that were present during registration are **no longer being returned** by `self.module.modules()` during deregistration. This leaves their handles in the dictionary, causing the assertion to fail. + +### Why This Happens + +Looking at the leftover keys in the error: +``` +{CudaGraphManager(): , + CudaGraphManager(): , + ...} +``` + +These `CudaGraphManager` objects were registered as submodules when `enable_forward_pre_hook` ran, but they're either: +- Dynamically removed from the module tree after registration +- Not yielded by `modules()` during the second iteration for some reason +- Created/destroyed dynamically as part of CUDA graph management + +### Root Cause + +This is likely a **bug or incompatibility** between: +1. **CUDA Graphs** being enabled in your configuration (`cuda_graph_impl = "local"` as I saw earlier in the code) +2. **Overlap param gather** being enabled (`overlap_param_gather = True`) + +When CUDA graphs are used, the model may dynamically add/remove `CudaGraphManager` wrapper modules, which breaks the assumption that `modules()` returns the same set both times. + +### Potential Fixes + +**Option 1**: Disable CUDA graphs when using reference model weight swapping (if that's acceptable for your use case) + +**Option 2**: Fix the `disable_forward_pre_hook` to be more robust by clearing ALL handles, not just those from current `modules()`: + +```python +def disable_forward_pre_hook(self, param_sync: bool = True): + assert self.use_forward_hook + # Remove all registered handles + for module, handle in list(self.remove_forward_pre_hook_handles.items()): + handle.remove() + self.remove_forward_pre_hook_handles.clear() + + # Force synchronize parameters. + if param_sync: + self.start_param_sync(force_sync=True) +``` + +**Option 3**: Skip calling `disable_forward_pre_hook` when reference model doesn't need it (if `init_reference_model=False` was used or reference model was never initialized) + +Would you like me to help you implement one of these fixes? + +```python + def disable_forward_pre_hook(self, param_sync: bool = True): + """ + Disable forward pre-hooks needed for param all-gather overlap with forward compute. + Skip synchronous param all-gather if `param_sync` is False. + """ + assert self.use_forward_hook + # De-register forward pre-hook for all sub-modules. + for module in self.module.modules(): + assert self.remove_forward_pre_hook_handles[module] is not None + self.remove_forward_pre_hook_handles[module].remove() + del self.remove_forward_pre_hook_handles[module] + assert len(self.remove_forward_pre_hook_handles) == 0, "remove_forward_pre_hook_handles is not empty, it is: " + str(self.remove_forward_pre_hook_handles) +``` + +```plaintext +{CudaGraphManager(): , + CudaGraphManager(): , + ...} +``` + +```python +def disable_forward_pre_hook(self, param_sync: bool = True): + assert self.use_forward_hook + # Remove all registered handles + for module, handle in list(self.remove_forward_pre_hook_handles.items()): + handle.remove() + self.remove_forward_pre_hook_handles.clear() + + # Force synchronize parameters. + if param_sync: + self.start_param_sync(force_sync=True) +``` + +```python +def toggle_cuda_graphs(model, set_to="none", reset_cuda_graphs=True): + ... 
+ elif attribute == "cudagraph_manager": + for module in modules: + if set_to == "local": + if reset_cuda_graphs: + from megatron.core.transformer.cuda_graphs import CudaGraphManager + # If we are resetting cuda graphs we create a new cuda graph manager + setattr(module[0], attribute, CudaGraphManager(model.config)) + else: + # If we are not resetting cuda graphs we set it to its cached cuda graph + setattr(module[0], attribute, module[1]) + else: + for module in modules: + # If we are deleting the cuda graph, we delete its attribute + if hasattr(module[0], "cudagraph_manager"): + delattr(module[0], "cudagraph_manager") + ... +``` + +```python +setattr(module[0], attribute, CudaGraphManager(model.config)) +``` + +```python +delattr(module[0], "cudagraph_manager") +``` + +```python +toggle_cuda_graphs(lang_module, set_to="none", reset_cuda_graphs=True) + + +``` + +```plaintext +Two potential solutions emerge: +1. Disable hooks before structural module changes +2. Modify hook removal logic to track references independently of module hierarchy +``` + +Let me check where `toggle_cuda_graphs` is being called in relation to the hook operations: + +**Yes, this is exactly the cause!** Let me explain the problem: + +## Root Cause Analysis + +### The Sequence of Events + +1. **During `generate()`** (line 1875): + ```python + toggle_cuda_graphs(lang_module, set_to="local", reset_cuda_graphs=True) + ``` + This **creates new `CudaGraphManager` objects** and adds them as submodules: + ```python + # From toggle_cuda_graphs when set_to="local": + setattr(module[0], attribute, CudaGraphManager(model.config)) + ``` + +2. **At some point**, `enable_forward_pre_hook()` is called, which registers hooks for **ALL** current modules, including the newly created `CudaGraphManager` objects: + + ```376:386:/lustre/fsw/portfolios/coreai/users/shanmugamr/RL/3rdparty/Megatron-LM-workspace/Megatron-LM/megatron/core/distributed/distributed_data_parallel.py + def enable_forward_pre_hook(self): + # ... + for module in self.module.modules(): + self.remove_forward_pre_hook_handles[module] = module.register_forward_pre_hook( + self._make_forward_pre_hook() + ) + ``` + +3. **At end of `generate()`** (line 2045): + ```python + toggle_cuda_graphs(lang_module, set_to="none", reset_cuda_graphs=True) + ``` + This **DELETES the `CudaGraphManager` attributes** from their parent modules: + ```python + # From toggle_cuda_graphs when set_to="none": + if hasattr(module[0], "cudagraph_manager"): + delattr(module[0], "cudagraph_manager") # ← REMOVES from module tree! + ``` + +4. **When `get_reference_policy_logprobs` → `use_reference_model` → `disable_forward_pre_hook()` is called**: + It iterates through `self.module.modules()`, but the `CudaGraphManager` objects are **no longer in the module tree** (they were deleted in step 3). + +5. 
**Result**: The hook handles for `CudaGraphManager` objects remain in `remove_forward_pre_hook_handles` but can't be removed → **Assertion fails!** + +## Visual Timeline + +``` +┌─────────────────────────────────────────────────────────────────────┐ +│ generate() starts │ +│ └─ toggle_cuda_graphs(set_to="local") │ +│ └─ Creates CudaGraphManager as submodule │ +├─────────────────────────────────────────────────────────────────────┤ +│ enable_forward_pre_hook() called somewhere │ +│ └─ Registers hooks for ALL modules (including CudaGraphManager) │ +│ └─ remove_forward_pre_hook_handles[CudaGraphManager] = handle │ +├─────────────────────────────────────────────────────────────────────┤ +│ generate() ends │ +│ └─ toggle_cuda_graphs(set_to="none") │ +│ └─ delattr(module, "cudagraph_manager") ← DELETES IT! │ +├─────────────────────────────────────────────────────────────────────┤ +│ get_reference_policy_logprobs() │ +│ └─ use_reference_model() │ +│ └─ disable_forward_pre_hook() │ +│ └─ for module in self.module.modules(): ← CudaGraphManager│ +│ # is NOT here anymore! │ +│ └─ ASSERTION FAILS: handles dict not empty! │ +└─────────────────────────────────────────────────────────────────────┘ +``` + +## Solution Options + +**Option 1**: Fix `disable_forward_pre_hook` in DDP to be robust (best fix): + +```python +def disable_forward_pre_hook(self, param_sync: bool = True): + assert self.use_forward_hook + # Remove ALL registered handles (not just current modules) + for module, handle in list(self.remove_forward_pre_hook_handles.items()): + handle.remove() + self.remove_forward_pre_hook_handles.clear() + + if param_sync: + self.start_param_sync(force_sync=True) +``` + +**Option 2**: Call `disable_forward_pre_hook` BEFORE `toggle_cuda_graphs(set_to="none")` in your code. + +**Option 3**: Don't delete cuda graph managers, just disable them (modify `toggle_cuda_graphs`). + +Would you like me to implement Option 1 as a fix in the Megatron DDP code? + +```python + toggle_cuda_graphs(lang_module, set_to="local", reset_cuda_graphs=True) +``` + +```python + # From toggle_cuda_graphs when set_to="local": + setattr(module[0], attribute, CudaGraphManager(model.config)) +``` + +```python + def enable_forward_pre_hook(self): + # ... + for module in self.module.modules(): + self.remove_forward_pre_hook_handles[module] = module.register_forward_pre_hook( + self._make_forward_pre_hook() + ) +``` + +```python + toggle_cuda_graphs(lang_module, set_to="none", reset_cuda_graphs=True) +``` + +```python + # From toggle_cuda_graphs when set_to="none": + if hasattr(module[0], "cudagraph_manager"): + delattr(module[0], "cudagraph_manager") # ← REMOVES from module tree! +``` + +```plaintext +┌─────────────────────────────────────────────────────────────────────┐ +│ generate() starts │ +│ └─ toggle_cuda_graphs(set_to="local") │ +│ └─ Creates CudaGraphManager as submodule │ +├─────────────────────────────────────────────────────────────────────┤ +│ enable_forward_pre_hook() called somewhere │ +│ └─ Registers hooks for ALL modules (including CudaGraphManager) │ +│ └─ remove_forward_pre_hook_handles[CudaGraphManager] = handle │ +├─────────────────────────────────────────────────────────────────────┤ +│ generate() ends │ +│ └─ toggle_cuda_graphs(set_to="none") │ +│ └─ delattr(module, "cudagraph_manager") ← DELETES IT! 
│ +├─────────────────────────────────────────────────────────────────────┤ +│ get_reference_policy_logprobs() │ +│ └─ use_reference_model() │ +│ └─ disable_forward_pre_hook() │ +│ └─ for module in self.module.modules(): ← CudaGraphManager│ +│ # is NOT here anymore! │ +│ └─ ASSERTION FAILS: handles dict not empty! │ +└─────────────────────────────────────────────────────────────────────┘ +``` + +```python +def disable_forward_pre_hook(self, param_sync: bool = True): + assert self.use_forward_hook + # Remove ALL registered handles (not just current modules) + for module, handle in list(self.remove_forward_pre_hook_handles.items()): + handle.remove() + self.remove_forward_pre_hook_handles.clear() + + if param_sync: + self.start_param_sync(force_sync=True) +``` + +AFTER FIX 1 +### What Changed + +| Before | After | +|--------|-------| +| Iterated over `self.module.modules()` | Iterates over `self.remove_forward_pre_hook_handles.items()` directly | +| Asserted each module exists in dict | No assertion needed - we're iterating over what we have | +| Deleted entries one by one | Uses `clear()` after removing all handles | +| Would fail if modules were removed from tree | Handles dynamically removed modules (like `CudaGraphManager`) | + +This fix ensures that all registered hook handles are properly removed regardless of whether the modules are still part of the module tree. You can now run your application again to test it. + +```python + def disable_forward_pre_hook(self, param_sync: bool = True): + """ + Disable forward pre-hooks needed for param all-gather overlap with forward compute. + Skip synchronous param all-gather if `param_sync` is False. + """ + assert self.use_forward_hook + # De-register forward pre-hook for all registered modules. + # We iterate over the dict keys directly (instead of self.module.modules()) to handle + # cases where modules may have been dynamically removed from the module tree + # (e.g., CudaGraphManager objects removed by toggle_cuda_graphs). + for module, handle in list(self.remove_forward_pre_hook_handles.items()): + handle.remove() + self.remove_forward_pre_hook_handles.clear() + + # Force synchronize parameters. + if param_sync: + self.start_param_sync(force_sync=True) +``` \ No newline at end of file diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index 144683c95c..b234640789 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -564,7 +564,13 @@ def train( def generate( self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False ) -> BatchedDataDict[GenerationOutputSpec]: - """Generate a batch of data using the policy.""" + """Generate a batch of data using the policy. + + For coordinator-based inference (Megatron backend), all data is sent to DP rank 0 + only, which submits requests to the coordinator. The coordinator then distributes + work across all DP engines. Other DP ranks participate in the engine loop but + don't receive input data directly. + """ # Verify input data is right-padded assert isinstance(data, BatchedDataDict), ( f"data must be a BatchedDataDict, got type: {type(data)}" @@ -573,14 +579,24 @@ def generate( "Missing required input fields" ) - dp_size = self.sharding_annotations.get_axis_size("data_parallel") - sharded_data = data.shard_by_batch_size(dp_size, batch_size=None) + # For coordinator-based inference: send ALL data to DP rank 0 only. + # Other DP ranks are called with data=None but still participate in the + # inference engine loop. 
The coordinator handles load balancing across DP ranks. + # + # With in_sharded_axes=[] and data_parallel not in replicate_on_axes, + # data_parallel becomes a "free axis". Only workers at DP coord 0 receive data, + # while workers at other DP coords get None (via make_dummy_calls_to_free_axes). futures = self.worker_group.run_all_workers_sharded_data( "generate", - data=sharded_data, - in_sharded_axes=["data_parallel"], + data=data, # Full data goes to DP=0 only (free axis behavior) + in_sharded_axes=[], # No sharding - data_parallel is a "free axis" replicate_on_axes=["tensor_parallel", "pipeline_parallel"], - output_is_replicated=["tensor_parallel", "pipeline_parallel"], + output_is_replicated=[ + "data_parallel", # Only DP rank 0 returns results + "tensor_parallel", + "pipeline_parallel", + ], + make_dummy_calls_to_free_axes=True, # Call all DP ranks, but only DP=0 gets data common_kwargs={"greedy": greedy}, ) assert self.cfg["generation"] is not None, "Generation config is not set" diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index 1d175f35b2..bcbb392a1e 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import gc import os import re @@ -55,6 +56,7 @@ ) from megatron.bridge.training.state import GlobalState from megatron.bridge.training.tokenizers.tokenizer import build_tokenizer +from megatron.bridge.training.utils.pg_utils import get_pg_collection from megatron.bridge.training.utils.train_utils import ( logical_and_across_model_parallel_group, reduce_max_stat_across_model_parallel_group, @@ -86,10 +88,13 @@ is_pipeline_last_stage, ) from megatron.core.pipeline_parallel import get_forward_backward_func +from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.rerun_state_machine import get_rerun_state_machine from megatron.core.transformer import MegatronModule +# Note: delete_cuda_graphs is called internally by toggle_cuda_graphs when reset_cuda_graphs=True from megatron.core.transformer.module import Float16Module from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.utils import toggle_cuda_graphs from megatron.training.utils import get_ltor_masks_and_position_ids from ray.util.queue import Queue from transformers import PreTrainedTokenizerBase @@ -156,9 +161,6 @@ class MegatronGenerationConfig(TypedDict): # Enable CUDA graphs for prefill/context processing use_cuda_graphs_for_non_decode_steps: bool # Split long prefills into chunks for better memory management - enable_chunked_prefill: bool - # Unified memory usage level (0=disabled, higher values enable more aggressive paging) - unified_memory_level: int # Maximum number of tokens to use in a single step. Analogous to vllm's max_num_batched_tokens. # Can cause OOM if set too high so should be tuned with buffer_size_gb if OOMing. If set too # low, then will only do 512 tokens at a time, which can be slow. 
@@ -339,6 +341,7 @@ def composed_peft_hook(model: list[MegatronModule]) -> list[MegatronModule]: data_parallel_random_init=cfg.rng.data_parallel_random_init, pre_wrap_hook=pre_wrap_hook, mixed_precision_wrapper=mixed_precision_wrapper, + pg_collection=ProcessGroupCollection.use_mpu_process_groups(), ) if load_optimizer: @@ -523,6 +526,15 @@ def __init__( if not self.is_generation_colocated: os.environ["NCCL_CUMEM_ENABLE"] = "1" + self.dynamic_inference_engine = None + self.inference_client = None + self.inference_context = None + self.inference_wrapped_model = None + self._inference_engine_initialized = False + self._inference_engine_paused = True # Start paused since we begin with training + self._inference_loop = None # Event loop for inference operations + self._inference_thread = None # Thread running the event loop + self.cfg = config dtype_map = { "float32": torch.float32, @@ -615,6 +627,12 @@ def __init__( model_cfg = cfg_from_pretrained.model cfg_from_pretrained.logger = LoggerConfig() + # Ensure make_vocab_size_divisible_by has a reasonable default (128 is standard) + if not hasattr(model_cfg, 'make_vocab_size_divisible_by') or model_cfg.make_vocab_size_divisible_by is None: + model_cfg.make_vocab_size_divisible_by = 128 + if get_rank_safe() == 0: + print(f"[WARNING] make_vocab_size_divisible_by not found in config, defaulting to 128") + model_cfg.tensor_model_parallel_size = self.cfg["megatron_cfg"][ "tensor_model_parallel_size" ] @@ -645,6 +663,7 @@ def __init__( # Setting moe_router_dtype to higher precision (e.g. fp64) can improve numerical stability, # especially when using many experts. model_cfg.moe_router_dtype = self.cfg["megatron_cfg"]["moe_router_dtype"] + model_cfg.moe_token_dispatcher_type = "alltoall" # The below two configs (and "freeze_moe_router") are used to stabilize moe training # by preventing updates to the moe router. We found that this is helpful in reducing @@ -761,6 +780,14 @@ def __init__( model_cfg.calculate_per_token_loss = True model_cfg.perform_initialization = True + # CUDA graphs: Set to "local" so CudaGraphManager is created on TransformerLayers. + # The toggle_cuda_graphs() function will switch between "local" (inference) and "none" (training). + # This is required because cudagraph_manager is only created if cuda_graph_impl=="local" at model init. 
+ model_cfg.cuda_graph_impl = "local" + model_cfg.cuda_graph_scope = None + model_cfg.use_te_rng_tracker = True + model_cfg.inference_rng_tracker = True + assert ( "aux_loss" not in model_cfg.moe_router_load_balancing_type or model_cfg.moe_aux_loss_coeff == 0 @@ -876,6 +903,9 @@ def __init__( ref_ckpt_context = init_checkpointing_context(ref_checkpoint_config) # Create a separate megatron config for the reference model with the correct checkpoint config + self.megatron_cfg.model.cuda_graph_impl = "none" + self.megatron_cfg.model.use_te_rng_tracker = False + self.megatron_cfg.model.inference_rng_tracker = False ref_megatron_cfg = ConfigContainer( model=self.megatron_cfg.model, checkpoint=ref_checkpoint_config, # Use the reference checkpoint config @@ -904,6 +934,7 @@ def __init__( overlap_param_gather_with_optimizer_step=self.megatron_cfg.optimizer.overlap_param_gather_with_optimizer_step, pre_wrap_hook=self.megatron_cfg.rng.data_parallel_random_init, mixed_precision_wrapper=ref_mixed_precision_wrapper, + pg_collection=ProcessGroupCollection.use_mpu_process_groups(), ) print("Loading the Reference Model") if ( @@ -964,6 +995,7 @@ def __init__( self.megatron_cfg.model.make_vocab_size_divisible_by, self.cfg["megatron_cfg"]["tensor_model_parallel_size"], ) + self.dp_size = worker_sharding_annotations.get_axis_size("data_parallel") self.megatron_bridge = AutoBridge.from_hf_pretrained( hf_model_name, trust_remote_code=True @@ -999,6 +1031,362 @@ def disable_forward_pre_hook(self, param_sync=True): assert isinstance(self.model, DistributedDataParallel) self.model.disable_forward_pre_hook(param_sync=param_sync) + def _get_lang_module(self): + """Get the underlying language module from the wrapped model.""" + return ( + self.model.module.module + if hasattr(self.model.module, "module") + else self.model.module + ) + + def _initialize_inference_engine(self, mcore_generation_config: dict): + """Initialize the persistent inference engine and client. + + This method sets up the DynamicInferenceEngine, DynamicInferenceContext, + and InferenceClient for coordinator-based inference. The engine is created + once and reused across multiple generate() calls. 
+ """ + if self._inference_engine_initialized: + return + + from megatron.core.inference.contexts.dynamic_context import ( + DynamicInferenceContext, + ) + from megatron.core.inference.engines.dynamic_engine import ( + DynamicInferenceEngine, + ) + from megatron.core.inference.inference_client import InferenceClient + from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import ( + GPTInferenceWrapper, + ) + from megatron.core.inference.text_generation_controllers.text_generation_controller import ( + TextGenerationController, + ) + + model_cfg = self.megatron_cfg.model + + + from megatron.core.utils import get_attr_wrapped_model, get_pg_size + pg_collection = get_attr_wrapped_model(self.model, "pg_collection") + tp_group = getattr(pg_collection, 'tp', None) if pg_collection is not None else None + pp_group = getattr(pg_collection, 'pp', None) if pg_collection is not None else None + ep_group = getattr(pg_collection, 'ep', None) if pg_collection is not None else None + + # Set defaults in case groups are None + inference_tp_size = get_pg_size(tp_group) if tp_group is not None else 1 + inference_pp_size = get_pg_size(pp_group) if pp_group is not None else 1 + inference_ep_size = get_pg_size(ep_group) if ep_group is not None else 1 + + print(f'inference_tp_size: {inference_tp_size}, inference_pp_size: {inference_pp_size}, inference_ep_size: {inference_ep_size}') + + # Determine if we need MoE expert padding for CUDA graphs + # This is required when using CUDA graphs with expert parallelism (EP > 1) for MoE models + # as some MoE routers have D2H sync that breaks CUDA graph capture + is_moe_model = getattr(model_cfg, 'num_moe_experts', None) is not None and model_cfg.num_moe_experts > 1 + has_expert_parallelism = inference_ep_size > 1 if ep_group is not None else False + moe_pad_for_cuda_graphs = is_moe_model and has_expert_parallelism + + if is_moe_model and self.rank == 0: + print(f"[INFO] MoE model detected: num_experts={model_cfg.num_moe_experts}, " + f"EP_size={inference_ep_size}, moe_pad_for_cuda_graphs={moe_pad_for_cuda_graphs}") + + inference_wrapper_config = InferenceWrapperConfig( + hidden_size=model_cfg.hidden_size, + inference_batch_times_seqlen_threshold=1000000, + fp32_residual_connection=model_cfg.fp32_residual_connection, + params_dtype=model_cfg.params_dtype, + padded_vocab_size=self.final_padded_vocab_size, + inference_max_seq_length=self.cfg["generation"]["max_new_tokens"], + inference_max_requests=self.cfg["generation_batch_size"], + moe_pad_experts_for_cuda_graph_inference=moe_pad_for_cuda_graphs or is_moe_model + ) + + buffer_size_gb = mcore_generation_config["buffer_size_gb"] + num_cuda_graphs = mcore_generation_config["num_cuda_graphs"] + block_size_tokens = mcore_generation_config["block_size_tokens"] + enable_cuda_graph = mcore_generation_config["enable_cuda_graph"] + enable_chunked_prefill = mcore_generation_config["enable_chunked_prefill"] + use_cuda_graphs_for_non_decode_steps = mcore_generation_config[ + "use_cuda_graphs_for_non_decode_steps" + ] + max_tokens = mcore_generation_config["max_tokens"] + # Read unified_memory_level from config (default to 0 for compatibility) + # Level 0: No unified memory, CUDA graphs are deleted/recreated on pause/resume + # Level 1+: Unified memory enabled, tensors maintain static addresses + unified_memory_level = mcore_generation_config["unified_memory_level"] + model_config = self.model.config + # Enable CUDA graphs for inference + model_config.cuda_graph_impl = "local" + + # Create inference context + 
self.inference_context = DynamicInferenceContext( + params_dtype=inference_wrapper_config.params_dtype, + num_layers=model_config.num_layers, + kv_channels=model_config.kv_channels, + num_attention_heads=model_config.num_query_groups, + max_sequence_length=self.cfg["generation"]["max_new_tokens"], + buffer_size_gb=buffer_size_gb, + materialize_only_last_token_logits=False, + num_cuda_graphs=num_cuda_graphs, + block_size_tokens=block_size_tokens, + pg_collection=pg_collection, + use_cuda_graphs_for_non_decode_steps=use_cuda_graphs_for_non_decode_steps, + use_flashinfer_fused_rope=False, + unified_memory_level=unified_memory_level, + max_tokens=max_tokens, + persist_cuda_graphs=False, + ) + + # Create inference wrapper + self.inference_wrapped_model = GPTInferenceWrapper( + self.model, inference_wrapper_config, self.inference_context, pg_collection=pg_collection + ) + self.inference_wrapped_model.prep_model_for_inference() + self.inference_wrapped_model.model_is_pipeline_parallel = ( + self.cfg["megatron_cfg"]["pipeline_model_parallel_size"] > 1 + ) + + # Create text generation controller + pp_group = getattr(pg_collection, "pp", None) + text_generation_controller = TextGenerationController( + inference_wrapped_model=self.inference_wrapped_model, + tokenizer=self.megatron_tokenizer, + pp_group=pp_group, + ) + + # Calculate seed based on node and rank + local_rank = torch.cuda.current_device() + num_gpus_per_node = torch.cuda.device_count() + node_idx = self.rank // num_gpus_per_node if num_gpus_per_node > 0 else 0 + seed = (node_idx * 1024) + local_rank + + # Create the inference engine + self.dynamic_inference_engine = DynamicInferenceEngine( + text_generation_controller, + self.inference_context, + random_seed=seed, + enable_cuda_graph=enable_cuda_graph, + track_paused_request_events=False, + enable_chunked_prefill=enable_chunked_prefill, + inference_logging_step_interval=0, + pg_collection=pg_collection, + ) + + self._inference_engine_initialized = True + self._inference_engine_paused = True # Engine starts in paused state + print(f"[Rank {self.rank}] Initialized persistent inference engine") + + async def _start_inference_coordinator(self, coordinator_port: int): + """Start the inference coordinator and engine loop. + + This is called once when the inference infrastructure is first needed. + """ + await self.dynamic_inference_engine.start_listening_to_data_parallel_coordinator( + inference_coordinator_port=coordinator_port, + launch_inference_coordinator=True, + ) + + dist_rank = torch.distributed.get_rank() + if dist_rank == 0: + from megatron.core.inference.inference_client import InferenceClient + self.inference_client = InferenceClient(coordinator_port) + await self.inference_client.start() + + self._inference_engine_paused = False + + def pause_inference_engine(self): + """pause the inference engine to free GPU memory for training. + + This method should be called before training to: + 1. Deallocate KV cache and other inference-specific GPU memory + 2. Disable CUDA graphs for inference + 3. Toggle model configuration for training mode + + Uses the coordinator's pause mechanism to properly pause the engine loop + and then pause the engine (deallocate tensors, etc.). 
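+
+        Typical call pattern (illustrative; inference_mode() drives this
+        automatically, so direct calls are only needed for manual control):
+
+            self.resume_inference_engine()   # before generation: bring the engine back
+            ...run generation...
+            self.pause_inference_engine()    # before training: free inference memory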
+ + For coordinator-based inference: + - Only rank 0 sends pause signals via the coordinator + - The coordinator broadcasts to all DP engines + - Non-rank-0 workers wait for their engine to be paused via the event loop + """ + + future = asyncio.run_coroutine_threadsafe( + self.pause_engine(), + self._inference_loop + ) + future.result() + # Synchronize all ranks + torch.distributed.barrier() + + self._inference_engine_paused = True + print(f"[Rank {self.rank}] Paused inference engine") + + async def pause_engine(self): + """Send pause signals via the coordinator and wait for acknowledgment.""" + if torch.distributed.get_rank() == 0: + # Send PAUSE signals + self.inference_client.pause_engines() + # Wait for the engine to acknowledge the pause + await self.dynamic_inference_engine.paused.wait() + + def resume_inference_engine(self): + """Resume the inference engine after training. + + This method should be called before generation to: + 1. Reallocate KV cache and inference-specific GPU memory + 2. Enable CUDA graphs for inference + 3. Toggle model configuration for inference mode + + Uses the coordinator's resume mechanism to properly resume the engine loop. + + For coordinator-based inference: + - Only rank 0 sends resume signals via the coordinator + - The coordinator broadcasts to all DP engines + - Non-rank-0 workers wait for their engine to be running via the event loop + """ + + # Use the coordinator-based resume mechanism + # Only rank 0 sends the signal - coordinator broadcasts to all DP engines + future = asyncio.run_coroutine_threadsafe( + self.resume_engine(), + self._inference_loop + ) + future.result() + # Synchronize all ranks + torch.distributed.barrier() + + self._inference_engine_paused = False + print(f"[Rank {self.rank}] Resumed inference engine") + + async def resume_engine(self): + """Send resume signals via the coordinator and wait for acknowledgment.""" + if torch.distributed.get_rank() == 0: + # Send RESUME then UNPAUSE signals + self.inference_client.unpause_engines() + # Wait for the engine to acknowledge it's running + await self.dynamic_inference_engine.running.wait() + + + @contextmanager + def inference_mode(self, mcore_generation_config: dict): + """Context manager for inference mode, following Megatron RL's pattern. + + This mirrors megatron_rl_inference_mode from megatron/rl/rl_utils.py. + + ENTER order: + 1. Put model in eval mode + 2. Clear rotary cache + 3. Toggle CUDA graphs ON + 4. Initialize/get inference engine + 5. Restore KV cache from CPU (if offloaded) + 6. Resume engine (engine handles CUDA graph creation internally if needed) + + EXIT order (matching Megatron RL): + 1. Pause engine + 2. Offload KV cache to CPU + 3. Toggle CUDA graphs OFF + 4. Clear rotary cache + 5. Put model back in train mode + + Args: + mcore_generation_config: Generation settings dict; its reset_cuda_graphs entry controls whether CUDA graphs are recreated on each entry (True) or reused (False). + + Yields: + The dynamic inference engine for use during inference.
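+
+        Example (mirrors the usage in generate(); shown here for illustration only):
+
+            mcore_generation_config = self.cfg["generation"]["mcore_generation_config"]
+            with torch.no_grad(), self.inference_mode(mcore_generation_config) as engine:
+                ...  # submit requests / drive generation with `engine`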
+ """ + # Get the language module (unwrap from precision wrappers if needed) + lang_module = ( + self.model.module.module + if hasattr(self.model.module, "module") + else self.model.module + ) + + # Get config settings + offload_kv_cache = mcore_generation_config.get("offload_kv_cache_during_training", False) + reset_cuda_graphs = mcore_generation_config.get("reset_cuda_graphs", False) + + # Critical assertion from Megatron RL: + # If offloading KV cache, MUST reset CUDA graphs (addresses change on CPU->GPU) + #if offload_kv_cache: + # assert reset_cuda_graphs, ( + # "reset_cuda_graphs must be True when offloading kv cache during training. " + # "Memory addresses change when moving CPU->GPU, invalidating captured CUDA graphs." + # ) + + # Save training state + was_training = lang_module.training + + # === ENTER INFERENCE MODE === + + # 1. Put model in eval mode + lang_module.eval() + + # 2. Clear rotary position embedding caches (Megatron RL does this) + rotary_module = getattr(lang_module, "rotary_pos_emb", None) + has_lru_cache = rotary_module is not None and hasattr(rotary_module.forward, "cache_parameters") + if has_lru_cache: + rotary_module.forward.cache_clear() + + toggle_cuda_graphs(lang_module, set_to="local", reset_cuda_graphs=reset_cuda_graphs) + + # 4. Initialize inference engine if not already done + if not self._inference_engine_initialized: + self._initialize_inference_engine(mcore_generation_config) + # Start the coordinator and engine loop (first time only) + coordinator_port = self.cfg["generation"].get( + "inference_coordinator_port", 5995 + ) + self._run_async_coordinator_start(coordinator_port) + + # 5. Handle KV cache restoration (before CUDA graph creation) + if offload_kv_cache and hasattr(self, 'inference_context'): + if self.inference_context.memory_buffer is not None: + kv_cache = self.inference_context.memory_buffer + if not kv_cache.is_cuda: + if self.rank == 0: + cache_size_gb = kv_cache.numel() * kv_cache.element_size() / (1024**3) + print(f"[INFO] Restoring KV cache ({cache_size_gb:.2f} GB) to GPU") + self.inference_context.memory_buffer = kv_cache.cuda() + + if self._inference_engine_paused: + self.resume_inference_engine() + + try: + # Yield the inference engine for use + yield self.dynamic_inference_engine + + finally: + + # 1. pause the inference engine + if self._inference_engine_initialized and not self._inference_engine_paused: + self.pause_inference_engine() + + # 2. Handle KV cache offloading AFTER pause (matching Megatron RL) + if offload_kv_cache and hasattr(self, 'inference_context'): + if self.inference_context.memory_buffer is not None: + kv_cache = self.inference_context.memory_buffer + if kv_cache.is_cuda: + if self.rank == 0: + cache_size_gb = kv_cache.numel() * kv_cache.element_size() / (1024**3) + print(f"[INFO] Offloading KV cache ({cache_size_gb:.2f} GB) to CPU") + self.inference_context.memory_buffer = kv_cache.cpu() + + # 3. Toggle CUDA graphs OFF + toggle_cuda_graphs(lang_module, set_to="none", reset_cuda_graphs=reset_cuda_graphs) + + # 4. Clear rotary embedding cache again (Megatron RL does this on exit too) + if has_lru_cache: + rotary_module.forward.cache_clear() + + # 5. Restore training state + if was_training: + lang_module.train() + + # 6. 
Force garbage collection and CUDA memory cleanup + gc.collect() + torch.cuda.empty_cache() + @wrap_with_nvtx_name("megatron_policy_worker/train") def train( self, @@ -1009,6 +1397,16 @@ def train( mbs: Optional[int] = None, ) -> dict[str, Any]: """Train the policy on a batch of data with a given loss function.""" + """ + lang_module = ( + self.model.module.module + if hasattr(self.model.module, "module") + else self.model.module + ) + + toggle_cuda_graphs(lang_module, set_to="none", reset_cuda_graphs=False) + """ + self.model.zero_grad_buffer() if hasattr(self.model, "inference_params"): self.model.inference_params = None @@ -1171,18 +1569,20 @@ def train( else: update_successful, grad_norm, num_zeros_in_grad = (True, 0.0, 0.0) + pg_collection = get_pg_collection(self.model) + # when freezing sub-models we may have a mixture of successful and unsucessful ranks, # so we must gather across mp ranks update_successful = logical_and_across_model_parallel_group( - update_successful + update_successful, mp_group=pg_collection.mp ) # grad_norm and num_zeros_in_grad will be None on ranks without trainable params, # so we must gather across mp ranks grad_norm: float = reduce_max_stat_across_model_parallel_group( - grad_norm + grad_norm, mp_group=pg_collection.mp ) num_zeros_in_grad: float = reduce_max_stat_across_model_parallel_group( - num_zeros_in_grad + num_zeros_in_grad, mp_group=pg_collection.mp ) if update_successful: @@ -1829,192 +2229,126 @@ def collection_fn(_): def generate( self, *, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False ) -> BatchedDataDict[GenerationOutputSpec]: - """Generate a batch of data using huggingface framework generation. + """Generate a batch of data using Megatron Core inference with coordinator. + + This method uses the coordinator-based inference pattern from Megatron Core, + which enables better parallelism across data-parallel ranks through a central + coordinator that routes requests to available engines. + + The inference engine is created once and reused across generate() calls. + The engine is paused between generate() calls to free GPU memory for training. 
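+
+        Illustrative output layout (the variable name `out` is for illustration;
+        see Returns below for the authoritative field list):
+
+            out["output_ids"]                 # [batch, prompt_len + gen_len], right padded
+            out["logprobs"]                   # same shape; values are filled starting at position 1
+            out["generation_lengths"]         # [batch], number of generated tokens per sample
+            out["unpadded_sequence_lengths"]  # [batch], prompt + generated length per sample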
+ + For coordinator-based inference: + - Only DP rank 0 receives actual data and submits requests to the coordinator + - Other DP ranks receive data=None but still participate in the inference engine loop + - The coordinator distributes work across all DP engines + - Results are broadcast from rank 0 to all ranks Args: - data: BatchedDataDict containing input_ids and input_lengths tensors + data: BatchedDataDict containing input_ids and input_lengths tensors, + or None for non-DP-0 workers (they participate in engine loop only) Returns: BatchedDataDict conforming to GenerationOutputSpec: - output_ids: input + generated token IDs - logprobs: Log probabilities for each token - generation_lengths: Lengths of each response """ - # 512 bATCH SIZE (200 tokens) - no_grad = torch.no_grad() - no_grad.__enter__() + from megatron.core.inference.sampling_params import SamplingParams + self.model.config.flash_decode = False if self.should_disable_forward_pre_hook: self.model = self.move_model( self.model, "cuda", move_params=True, move_grads=False ) - # Verify input is right padded - assert isinstance(data, BatchedDataDict), ( - f"data must be a BatchedDataDict, got type: {type(data)}" - ) - assert "input_ids" in data and "input_lengths" in data, ( - f"input_ids and input_lengths must be present in the BatchedDataDict, got keys: {data.keys()}" - ) - is_right_padded, error_msg = verify_right_padding( - data, pad_value=self.tokenizer.pad_token_id - ) - if not is_right_padded: - warnings.warn( - f"Input to Megatron Generation worker is not properly right-padded: {error_msg}" + + dist_rank = torch.distributed.get_rank() + is_request_submitter = (dist_rank == 0) + + # For non-rank-0 workers, data may be None (they participate in engine loop only) + if data is not None: + # Verify input is right padded + assert isinstance(data, BatchedDataDict), ( + f"data must be a BatchedDataDict, got type: {type(data)}" + ) + assert "input_ids" in data and "input_lengths" in data, ( + f"input_ids and input_lengths must be present in the BatchedDataDict, got keys: {data.keys()}" + ) + is_right_padded, error_msg = verify_right_padding( + data, pad_value=self.tokenizer.pad_token_id + ) + if not is_right_padded: + warnings.warn( + f"Input to Megatron Generation worker is not properly right-padded: {error_msg}" + ) + + + mcore_generation_config = self.cfg["generation"]["mcore_generation_config"] + # Use inference_mode context manager (mirrors megatron_rl_inference_mode from Megatron RL) + # This handles: eval mode, CUDA graph toggle, engine init/resume, and cleanup + with torch.no_grad(), self.inference_mode(mcore_generation_config) as inference_engine: + # Handle None values for top_k - convert to integer as required by Megatron + top_k_cfg = self.cfg["generation"]["top_k"] + top_k_val = 1 if greedy else (int(top_k_cfg) if top_k_cfg is not None else 0) + + top_p_cfg = self.cfg["generation"]["top_p"] + top_p_val = ( + 0.0 if greedy else (float(top_p_cfg) if top_p_cfg is not None else 0.0) ) - model_cfg = self.megatron_cfg.model - inference_wrapper_config = InferenceWrapperConfig( - hidden_size=model_cfg.hidden_size, - inference_batch_times_seqlen_threshold=1000000, - fp32_residual_connection=model_cfg.fp32_residual_connection, - params_dtype=model_cfg.params_dtype, - padded_vocab_size=self.final_padded_vocab_size, # Use the potentially updated value - inference_max_seq_length=self.cfg["generation"]["max_new_tokens"], # type: ignore - inference_max_requests=self.cfg["generation_batch_size"], - ) - - from 
megatron.core.inference.contexts.dynamic_context import ( - DynamicInferenceContext, - ) - from megatron.core.inference.engines.dynamic_engine import ( - DynamicInferenceEngine, - ) - from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import ( - GPTInferenceWrapper, - ) - from megatron.core.inference.sampling_params import SamplingParams - - mcore_generation_config = cast( - MegatronGenerationConfig, self.cfg["generation"]["mcore_generation_config"] - ) - buffer_size_gb = mcore_generation_config["buffer_size_gb"] - - num_cuda_graphs = mcore_generation_config["num_cuda_graphs"] - block_size_tokens = mcore_generation_config["block_size_tokens"] - use_cuda_graphs_for_non_decode_steps = mcore_generation_config[ - "use_cuda_graphs_for_non_decode_steps" - ] - enable_chunked_prefill = mcore_generation_config["enable_chunked_prefill"] - unified_memory_level = mcore_generation_config["unified_memory_level"] - buffer_guaranteed_fraction = mcore_generation_config[ - "buffer_guaranteed_fraction" - ] - max_tokens = mcore_generation_config["max_tokens"] - - model_config = self.model.config - model_config.cuda_graph_impl = "local" - - dynamic_context = DynamicInferenceContext( - params_dtype=inference_wrapper_config.params_dtype, - num_layers=model_config.num_layers, - kv_channels=model_config.kv_channels, - num_attention_heads=model_config.num_query_groups, - max_sequence_length=self.cfg["generation"]["max_new_tokens"], - buffer_guaranteed_fraction=buffer_guaranteed_fraction, - buffer_size_gb=buffer_size_gb, - materialize_only_last_token_logits=False, - num_cuda_graphs=num_cuda_graphs, - block_size_tokens=block_size_tokens, - tensor_model_parallel_size=self.cfg["megatron_cfg"][ - "tensor_model_parallel_size" - ], - use_cuda_graphs_for_non_decode_steps=use_cuda_graphs_for_non_decode_steps, - use_flashinfer_fused_rope=False, - unified_memory_level=unified_memory_level, - max_tokens_override=max_tokens, - ) - inference_wrapped_model = GPTInferenceWrapper( - self.model, inference_wrapper_config, dynamic_context - ) - - inference_wrapped_model.prep_model_for_inference() - # Set pipeline parallel flag - inference_wrapped_model.model_is_pipeline_parallel = ( - self.cfg["megatron_cfg"]["pipeline_model_parallel_size"] > 1 - ) - - text_generation_controller = TextGenerationController( - inference_wrapped_model=inference_wrapped_model, - tokenizer=self.megatron_tokenizer, - ) - - # Calculate seed based on node and rank to ensure reproducibility across workers - local_rank = torch.cuda.current_device() # Local GPU index on the node - num_gpus_per_node = torch.cuda.device_count() - node_idx = self.rank // num_gpus_per_node if num_gpus_per_node > 0 else 0 - seed = (node_idx * 1024) + local_rank - - # New API: DynamicInferenceEngine has additional parameters - dynamic_engine = DynamicInferenceEngine( - text_generation_controller, - dynamic_context, - enable_cuda_graph=True, - random_seed=seed, - track_paused_request_events=False, - enable_chunked_prefill=enable_chunked_prefill, - inference_logging_step_interval=0, - ) - - # Handle None values for top_k - convert to integer as required by Megatron - top_k_cfg = self.cfg["generation"]["top_k"] - top_k_val = 1 if greedy else (int(top_k_cfg) if top_k_cfg is not None else 0) - - top_p_cfg = self.cfg["generation"]["top_p"] - top_p_val = ( - 0.0 if greedy else (float(top_p_cfg) if top_p_cfg is not None else 0.0) - ) - - # New API: SamplingParams now includes termination_id and uses num_tokens_total - sampling_params = SamplingParams( - 
temperature=self.cfg["generation"]["temperature"] if not greedy else 0, - top_k=top_k_val, - top_p=top_p_val, - skip_prompt_log_probs=False, - return_log_probs=True, - num_tokens_total=self.cfg["generation"]["max_new_tokens"], - num_tokens_to_generate=None, - termination_id=self.megatron_tokenizer.eod, - ) - - input_ids = data["input_ids"] - prompt_tokens_tensor = input_ids.cuda() - prompt_lengths_tensor = data["input_lengths"] - request_id = 0 + sampling_params = SamplingParams( + temperature=self.cfg["generation"]["temperature"] if not greedy else 0, + top_k=top_k_val, + top_p=top_p_val, + skip_prompt_log_probs=False, + return_log_probs=True, + num_tokens_total=self.cfg["generation"]["max_new_tokens"], + num_tokens_to_generate=None, + termination_id=self.megatron_tokenizer.eod, + ) - # New API: add_request now takes sampling_params as a parameter - for p, prompt_len in zip( - prompt_tokens_tensor, prompt_lengths_tensor, strict=True - ): - dynamic_engine.add_request( - request_id, - p[:prompt_len], - sampling_params=sampling_params, + # Only rank 0 has actual data to submit + if is_request_submitter: + input_ids = data["input_ids"] + print(f"GPU : {torch.cuda.current_device()}, input_ids: {input_ids.shape}") + prompt_tokens_tensor = input_ids.cuda() + prompt_lengths_tensor = data["input_lengths"] + else: + # Non-submitter ranks: create empty tensors (will not be used for submission) + print(f"GPU : {torch.cuda.current_device()}, participating in engine loop (no data to submit)") + prompt_tokens_tensor = torch.empty(0, dtype=torch.long, device="cuda") + prompt_lengths_tensor = torch.empty(0, dtype=torch.long, device="cuda") + + # Run the coordinator-based generation using the persistent engine + # Rank 0 submits requests, other ranks participate in engine loop + # Results are broadcast to all ranks inside this method + result = self._run_async_generation_with_persistent_engine( + prompt_tokens_tensor, + prompt_lengths_tensor, + sampling_params, ) - request_id += 1 - result = [] - while dynamic_engine.has_unfinished_requests(): - result_step = dynamic_engine.step_modern(verbose=False) - finished_requests = result_step.get("finished_requests", []) - for finished_request in finished_requests: - result.append(finished_request) + self.model.config.flash_decode = False - # Sort results by request_id to maintain original batch order - result.sort(key=lambda x: x.request_id) + # Context manager has exited - CUDA graphs are now disabled, model is back in train mode - out = { - "tokens": [x.prompt_tokens.tolist() + x.generated_tokens for x in result], - "logprobs": [x.prompt_log_probs + x.generated_log_probs for x in result], - } + # Only rank 0 needs to format and return results + # Other ranks return None (their results are ignored due to output_is_replicated) + if not is_request_submitter: + # Return empty result for non-submitter ranks + # Use BatchedDataDict directly instead of from_batches to avoid padding issues with empty tensors + return BatchedDataDict({ + "output_ids": torch.empty(0, 0, dtype=torch.long), + "logprobs": torch.empty(0, 0, dtype=torch.float), + "generation_lengths": torch.empty(0, dtype=torch.long), + "unpadded_sequence_lengths": torch.empty(0, dtype=torch.long), + }).to("cpu") input_lengths = data["input_lengths"] - # pad the out "tokens" and "logprobs" and make them into tensors from lists batch_size = data["input_ids"].size(0) max_gen_seq_len = max([len(x.generated_tokens) for x in result]) padded_input_length = input_ids.size(1) max_seq_len = padded_input_length + 
max_gen_seq_len - # Create padded tensors for tokens and logprobs output_ids_padded = torch.full( (batch_size, max_seq_len), self.tokenizer.pad_token_id, @@ -2028,7 +2362,6 @@ def generate( device=data["input_ids"].device, ) - # Fill in the padded tensors with actual values generation_lengths = torch.zeros( batch_size, dtype=torch.long, device=data["input_ids"].device ) @@ -2036,15 +2369,17 @@ def generate( batch_size, dtype=torch.long, device=data["input_ids"].device ) for i in range(batch_size): - seq_len = len(out["tokens"][i]) + tokens = result[i].prompt_tokens.tolist() + result[i].generated_tokens + logprobs = result[i].prompt_log_probs + result[i].generated_log_probs + seq_len = len(tokens) output_ids_padded[i, :seq_len] = torch.tensor( - out["tokens"][i], dtype=torch.long, device=data["input_ids"].device + tokens, dtype=torch.long, device=data["input_ids"].device ) generation_lengths[i] = seq_len - input_lengths[i].item() unpadded_sequence_lengths[i] = seq_len - logprob_len = len(out["logprobs"][i]) + logprob_len = len(logprobs) logprobs_padded[i, 1 : logprob_len + 1] = torch.tensor( - out["logprobs"][i], + logprobs, dtype=torch.float, device=data["input_ids"].device, ) @@ -2056,11 +2391,149 @@ def generate( "unpadded_sequence_lengths": unpadded_sequence_lengths, } - self.model.config.flash_decode = False - no_grad.__exit__(None, None, None) - return BatchedDataDict.from_batches([out_dict]).to("cpu") + def _start_inference_loop_thread(self): + """Start a background thread with a persistent event loop for inference. + + This thread runs the event loop that hosts the engine loop task. + The loop runs forever until explicitly stopped. + """ + import threading + + def run_loop(): + asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy()) + self._inference_loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._inference_loop) + # Run forever - the engine loop task will run in this loop + self._inference_loop.run_forever() + + self._inference_thread = threading.Thread(target=run_loop, daemon=True) + self._inference_thread.start() + + # Wait for the loop to be created + while self._inference_loop is None: + time.sleep(0.001) + + def _run_async_coordinator_start(self, coordinator_port: int): + """Start the coordinator and engine loop in the background thread. + + This is called once during the first generate() call to initialize + the persistent inference infrastructure. + """ + import concurrent.futures + + # Start the background thread with the event loop if not already running + if self._inference_loop is None: + self._start_inference_loop_thread() + + # Schedule the coordinator start in the inference loop + future = asyncio.run_coroutine_threadsafe( + self._start_inference_coordinator(coordinator_port), + self._inference_loop + ) + # Wait for completion + return future.result() + + def _run_async_generation_with_persistent_engine( + self, + prompt_tokens_tensor: torch.Tensor, + prompt_lengths_tensor: torch.Tensor, + sampling_params: "SamplingParams", + ) -> list: + """Run generation using the persistent inference engine. + + This method uses the pre-initialized engine and client to run generation. + Unlike the original method, it doesn't start/stop the coordinator each time. + The async operation runs in the persistent inference loop. + """ + if self._inference_loop is None: + raise RuntimeError("Inference loop not initialized. 
Call generate() first.") + + # Schedule the generation in the inference loop + future = asyncio.run_coroutine_threadsafe( + self._generate_with_persistent_engine( + prompt_tokens_tensor, + prompt_lengths_tensor, + sampling_params, + ), + self._inference_loop + ) + # Wait for completion and return the result + return future.result() + + async def _generate_with_persistent_engine( + self, + prompt_tokens_tensor: torch.Tensor, + prompt_lengths_tensor: torch.Tensor, + sampling_params: "SamplingParams", + ) -> list: + """Run generation using the persistent coordinator-based inference. + + This method uses the already-running engine and submits requests through + the persistent client. The engine loop continues running between calls. + + For coordinator-based inference with centralized request submission: + - Only rank 0 (the request submitter) submits requests and collects results + - Other ranks return early but their engine loops continue running in the + background, processing requests distributed by the coordinator + - No broadcast is needed since only rank 0's results are used by the caller + + Args: + prompt_tokens_tensor: Tensor of prompt token IDs [batch_size, seq_len] + prompt_lengths_tensor: Tensor of prompt lengths [batch_size] + sampling_params: Sampling parameters for generation + + Returns: + List of completed request records sorted by request_id (rank 0), + or empty list (other ranks) + """ + from megatron.core.inference.inference_request import ( + DynamicInferenceRequest, + DynamicInferenceRequestRecord, + ) + + dist_rank = torch.distributed.get_rank() + + if dist_rank == 0: + assert self.inference_client is not None, "Inference client not initialized" + + # Non-rank-0 workers: return immediately with empty results + # Their engine loops will continue processing requests from the coordinator + # in the background (the engine loop runs as a separate task in _inference_loop) + if dist_rank != 0: + print(f"[Rank {dist_rank}] Participating in engine loop only (not submitting requests)") + # Return empty results - the caller only uses rank 0's results + return [] + + # Rank 0: submit ALL requests and collect results + print(f"[Rank {dist_rank}] Submitting {prompt_tokens_tensor.size(0)} requests to coordinator") + + futures = [] + for request_id, (prompt_tokens, prompt_len) in enumerate( + zip(prompt_tokens_tensor, prompt_lengths_tensor, strict=True) + ): + # Extract the actual prompt tokens (without padding) and convert to list + prompt = prompt_tokens[: prompt_len.item()].tolist() + future = self.inference_client.add_request(prompt, sampling_params) + futures.append(future) + + # Wait for all requests to complete + # The coordinator distributes work to all DP engines, including this one + completed_records: list[DynamicInferenceRequestRecord] = await asyncio.gather( + *futures + ) + + # Extract the merged request from each record + results = [record.merge() for record in completed_records] + + # Sort by request_id to maintain original batch order + results.sort(key=lambda x: x.request_id) + + print(f"[Rank {dist_rank}] Completed {len(results)} requests") + + return results + @torch.no_grad() @wrap_with_nvtx_name("megatron_policy_worker/prepare_refit_info") def prepare_refit_info(self) -> None: @@ -2218,6 +2691,7 @@ def broadcast_weights_for_collective( ) def prepare_for_lp_inference(self): + self.model = self.move_model(self.model, "cuda", move_grads=False) self.model.eval() @@ -2751,4 +3225,4 @@ def re_enable_float32_expert_bias(self) -> None: if router is not None and hasattr( 
router, "_maintain_float32_expert_bias" ): - router._maintain_float32_expert_bias() + router._maintain_float32_expert_bias() \ No newline at end of file diff --git a/uv.lock b/uv.lock index 5818765dad..567875dedc 100644 --- a/uv.lock +++ b/uv.lock @@ -2974,6 +2974,7 @@ wheels = [ name = "megatron-bridge" source = { editable = "3rdparty/Megatron-Bridge-workspace" } dependencies = [ + { name = "accelerate" }, { name = "causal-conv1d" }, { name = "datasets" }, { name = "flash-linear-attention" }, @@ -2999,6 +3000,7 @@ dependencies = [ [package.metadata] requires-dist = [ + { name = "accelerate" }, { name = "causal-conv1d", git = "https://github.com/Dao-AILab/causal-conv1d?rev=67e0a9dfe1518fc0036444e9ab5fe06ab78299e0" }, { name = "datasets" }, { name = "flash-linear-attention" }, @@ -3016,8 +3018,8 @@ requires-dist = [ { name = "tensorboard", specifier = ">=2.19.0" }, { name = "timm" }, { name = "tqdm", specifier = ">=4.67.1" }, - { name = "transformer-engine", extras = ["pytorch"], specifier = ">=2.9.0a0,<2.10.0" }, - { name = "transformers", specifier = ">=4.57.1" }, + { name = "transformer-engine", extras = ["pytorch"], specifier = ">=2.10.0a0,<2.12.0" }, + { name = "transformers", specifier = "==4.57.1" }, { name = "typing-extensions" }, { name = "wandb", specifier = ">=0.19.10" }, ] @@ -3028,9 +3030,10 @@ source = { editable = "3rdparty/Megatron-LM-workspace" } dependencies = [ { name = "av" }, { name = "causal-conv1d" }, + { name = "datasets" }, { name = "einops" }, { name = "emerging-optimizers" }, - { name = "flash-linear-attention" }, + { name = "fastapi" }, { name = "flashinfer-python" }, { name = "mamba-ssm" }, { name = "megatron-energon", extra = ["av-decode"] }, @@ -3043,7 +3046,6 @@ dependencies = [ { name = "onnxscript" }, { name = "opentelemetry-api" }, { name = "packaging" }, - { name = "setuptools" }, { name = "tensorstore", version = "0.1.74", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.13'" }, { name = "tensorstore", version = "0.1.76", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.13'" }, { name = "torch", version = "2.9.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform == 'darwin'" }, @@ -3055,29 +3057,29 @@ dependencies = [ [package.metadata] requires-dist = [ - { name = "av", specifier = "<16.0.0" }, + { name = "av" }, { name = "causal-conv1d", git = "https://github.com/Dao-AILab/causal-conv1d?rev=67e0a9dfe1518fc0036444e9ab5fe06ab78299e0" }, + { name = "datasets" }, { name = "einops", specifier = "~=0.8" }, { name = "emerging-optimizers", git = "https://github.com/NVIDIA-NeMo/Emerging-Optimizers.git?rev=v0.1.0" }, - { name = "flash-linear-attention", specifier = "~=0.3.2" }, - { name = "flashinfer-python" }, + { name = "fastapi", specifier = "~=0.50" }, + { name = "flashinfer-python", specifier = "~=0.5.0" }, { name = "mamba-ssm", git = "https://github.com/state-spaces/mamba.git?rev=d68d16ed7d5d5164eb5a57c0285f3b7eb8394ec1" }, { name = "megatron-energon", extras = ["av-decode"], specifier = "~=6.0" }, { name = "multi-storage-client", specifier = "~=0.27" }, - { name = "numpy", specifier = "<2.0.0" }, + { name = "numpy" }, { name = "nv-grouped-gemm", git = "https://github.com/fanshiqing/grouped_gemm?tag=v1.1.4.post7" }, - { name = "nvidia-modelopt", extras = ["torch"], marker = "sys_platform != 'darwin'", specifier = ">=0.33.0a0,<0.34.0" }, - { name = "nvidia-resiliency-ext", specifier = ">=0.4.0a0,<0.5.0" }, + { name = "nvidia-modelopt", extras = ["torch"], 
marker = "sys_platform != 'darwin'" }, + { name = "nvidia-resiliency-ext" }, { name = "nvtx", specifier = "~=0.2" }, { name = "onnxscript" }, { name = "opentelemetry-api", specifier = "~=1.33.1" }, { name = "packaging", specifier = ">=24.2" }, - { name = "setuptools", specifier = "<80.0.0" }, { name = "tensorstore", specifier = "~=0.1,!=0.1.46,!=0.1.72" }, { name = "torch", marker = "sys_platform != 'darwin'", index = "https://download.pytorch.org/whl/cu129" }, { name = "torch", marker = "sys_platform == 'darwin'", index = "https://pypi.org/simple" }, { name = "tqdm" }, - { name = "transformer-engine", extras = ["pytorch"], specifier = ">=2.9.0a0,<2.10.0" }, + { name = "transformer-engine", extras = ["core-cu13", "pytorch"], specifier = ">=2.9.0a0,<2.12.0" }, { name = "wget" }, ]