diff --git a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
index a117c0f332f..02aee6fc551 100644
--- a/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
+++ b/verl/trainer/config/_generated_ppo_megatron_trainer.yaml
@@ -272,6 +272,7 @@ actor_rollout_ref:
     skip_dump_dir: /tmp/rollout_dump
     skip_tokenizer_init: true
     enable_rollout_routing_replay: false
+    enable_checkpoint_engine: false
     profiler:
       _target_: verl.utils.profiler.ProfilerConfig
       tool: ${oc.select:global_profiler.tool,null}
diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml
index 833ebb70d5b..14428e7bb41 100644
--- a/verl/trainer/config/_generated_ppo_trainer.yaml
+++ b/verl/trainer/config/_generated_ppo_trainer.yaml
@@ -261,6 +261,7 @@ actor_rollout_ref:
     skip_dump_dir: /tmp/rollout_dump
     skip_tokenizer_init: true
    enable_rollout_routing_replay: false
+    enable_checkpoint_engine: false
     profiler:
       _target_: verl.utils.profiler.ProfilerConfig
       tool: ${oc.select:global_profiler.tool,null}
diff --git a/verl/trainer/config/rollout/rollout.yaml b/verl/trainer/config/rollout/rollout.yaml
index 968d9e11277..b2a01111705 100644
--- a/verl/trainer/config/rollout/rollout.yaml
+++ b/verl/trainer/config/rollout/rollout.yaml
@@ -281,6 +281,9 @@ skip_tokenizer_init: True
 # When enabled (True), the rollout will record the routing decisions.
 enable_rollout_routing_replay: False
 
+# Whether to use checkpoint_engine for weight updates.
+# When enabled (True), parameters are synced between trainer and rollout through checkpoint_engine.
+enable_checkpoint_engine: False
 
 # profile the rollout model in `generate_sequence`
 profiler:
diff --git a/verl/workers/config/rollout.py b/verl/workers/config/rollout.py
index bea1bd4520d..b0367a8cbe9 100644
--- a/verl/workers/config/rollout.py
+++ b/verl/workers/config/rollout.py
@@ -207,6 +207,8 @@ class RolloutConfig(BaseConfig):
 
     enable_rollout_routing_replay: bool = False
 
+    enable_checkpoint_engine: bool = False
+
    def __post_init__(self):
         """Validate the rollout config"""
         if self.expert_parallel_size > 1:
diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index b7d89134d72..58648463e0d 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -20,8 +20,9 @@
 import logging
 import os
 import warnings
+from collections.abc import Callable
 from dataclasses import asdict
-from typing import Any, Optional
+from typing import Any, Generator, Optional
 
 import numpy as np
 import psutil
@@ -577,6 +578,24 @@
 
         return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config
 
+    def update_weights_by_checkpoint_engine(
+        self,
+        weights: Generator[tuple[str, torch.Tensor], None, None],
+        req_func: Callable[[list[tuple[str, str]]], None],
+    ):
+        named_tensors = {}
+        for tensor_idx, (name, tensor) in enumerate(weights):
+            if tensor_idx % self.world_size == self.rank:
+                named_tensors[name] = tensor
+
+        checkpoint_name = "checkpoint_engine"
+        self.parameter_server.register_checkpoint(checkpoint_name, named_tensors=named_tensors)
+        named_tensors = {}
+        dist.barrier()
+        self.parameter_server.gather_metas(checkpoint_name)
+        self.parameter_server.update(checkpoint_name, req_func)
+        self.parameter_server.unregister_checkpoint(checkpoint_name)
+
     def _build_rollout(self, trust_remote_code=False):
         from torch.distributed.device_mesh import init_device_mesh
 
@@ -588,10 +607,10 @@ def _build_rollout(self, trust_remote_code=False):
         # 2. build rollout device mesh
         infer_tp = self.config.rollout.tensor_model_parallel_size * self.config.rollout.data_parallel_size
         infer_pp = self.config.rollout.pipeline_model_parallel_size
-        infer_world_size = infer_tp * infer_pp
-        dp = self.world_size // infer_world_size
-        assert self.world_size % infer_world_size == 0, (
-            f"rollout world_size: {self.world_size} is not divisible by infer_world_size: {infer_world_size}"
+        self.infer_world_size = infer_tp * infer_pp
+        dp = self.world_size // self.infer_world_size
+        assert self.world_size % self.infer_world_size == 0, (
+            f"rollout world_size: {self.world_size} is not divisible by infer_world_size: {self.infer_world_size}"
         )
         rollout_device_mesh = init_device_mesh(
             device_name, mesh_shape=(dp, infer_tp, infer_pp), mesh_dim_names=["dp", "infer_tp", "infer_pp"]
@@ -700,10 +719,14 @@ async def rollout_mode(self):
 
         set_expandable_segments(False)
 
+        if self.config.rollout.enable_checkpoint_engine:
+            device = "cpu"
+        else:
+            device = get_device_id()  # used when fsdp2 set cpu_offload_policy
+
         if peft_config is not None and self.base_sync_done:
             per_tensor_param = params.items() if isinstance(params, dict) else params  # Fixed: handle dict case
         else:
-            device = get_device_id()  # used when fsdp2 set cpu_offload_policy
             per_tensor_param = (
                 (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param)
                 for name, param in params.items()
@@ -718,10 +741,20 @@
                 (name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param)
                 for name, param in base_model_params.items()
             )
-            await self.rollout.update_weights(per_tensor_base_params, base_sync_done=False)
+            if self.config.rollout.enable_checkpoint_engine:
+                req_func = await self.rollout.checkpoint_engine_req_func(self.infer_world_size)
+                self.update_weights_by_checkpoint_engine(per_tensor_base_params, req_func)
+            else:
+                await self.rollout.update_weights(per_tensor_base_params, base_sync_done=False)
             del base_model_params, per_tensor_base_params
 
-        await self.rollout.update_weights(per_tensor_param, peft_config=peft_config, base_sync_done=self.base_sync_done)
+        if self.config.rollout.enable_checkpoint_engine:
+            req_func = await self.rollout.checkpoint_engine_req_func(self.infer_world_size)
+            self.update_weights_by_checkpoint_engine(per_tensor_param, req_func)
+        else:
+            await self.rollout.update_weights(
+                per_tensor_param, peft_config=peft_config, base_sync_done=self.base_sync_done
+            )
         log_gpu_memory_usage("After update_weights", logger=logger)
         del params, per_tensor_param
         aggressive_empty_cache(force_sync=True)
@@ -863,6 +896,11 @@ def init_model(self):
             checkpoint_config=checkpoint_contents,
         )
 
+        if self.config.rollout.enable_checkpoint_engine:
+            from checkpoint_engine.ps import ParameterServer
+
+            self.parameter_server = ParameterServer(auto_pg=False)
+
     @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
     @DistProfiler.annotate(color="red", role="actor_update")
     def update_actor(self, data: DataProto):
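Both worker backends share the same update flow, so it is worth spelling out the sharding rule the new method relies on: the weight stream is partitioned round-robin by enumeration index, so each trainer rank registers a disjoint subset of tensors and the ranks collectively cover the full checkpoint. A standalone sketch of just that partitioning (`rank` and `world_size` stand in for the worker's distributed attributes):

```python
# Round-robin sharding as in update_weights_by_checkpoint_engine:
# rank r keeps every world_size-th (name, tensor) pair, so the
# registered subsets are disjoint and jointly exhaustive.
def shard_weights(weights, rank: int, world_size: int) -> dict:
    named_tensors = {}
    for tensor_idx, (name, tensor) in enumerate(weights):
        if tensor_idx % world_size == rank:
            named_tensors[name] = tensor
    return named_tensors

# With world_size=2: tensors 0, 2, 4, ... land on rank 0; 1, 3, 5, ... on rank 1.
```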
diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py
index db2e3fb1b97..233d77d8963 100644
--- a/verl/workers/megatron_workers.py
+++ b/verl/workers/megatron_workers.py
@@ -19,7 +19,8 @@
 import logging
 import os
 import time
-from typing import Any, Optional
+from collections.abc import Callable
+from typing import Any, Generator, Optional
 
 import psutil
 import torch
@@ -483,6 +484,24 @@ def _build_model_optimizer(
 
         return actor_module, actor_optimizer, actor_optimizer_scheduler, self.hf_config, optim_config
 
+    def update_weights_by_checkpoint_engine(
+        self,
+        weights: Generator[tuple[str, torch.Tensor], None, None],
+        req_func: Callable[[list[tuple[str, str]]], None],
+    ):
+        named_tensors = {}
+        for tensor_idx, (name, tensor) in enumerate(weights):
+            if tensor_idx % self.world_size == self.rank:
+                named_tensors[name] = tensor.to("cpu", non_blocking=True)
+
+        checkpoint_name = "checkpoint_engine"
+        self.parameter_server.register_checkpoint(checkpoint_name, named_tensors=named_tensors)
+        named_tensors = {}
+        torch.distributed.barrier()
+        self.parameter_server.gather_metas(checkpoint_name)
+        self.parameter_server.update(checkpoint_name, req_func)
+        self.parameter_server.unregister_checkpoint(checkpoint_name)
+
     def _build_rollout(self, trust_remote_code=False):
         from torch.distributed.device_mesh import init_device_mesh
 
@@ -500,10 +519,10 @@ def _build_rollout(self, trust_remote_code=False):
         # 2. build rollout device mesh
         infer_tp = self.config.rollout.tensor_model_parallel_size * self.config.rollout.data_parallel_size
         infer_pp = self.config.rollout.pipeline_model_parallel_size
-        infer_world_size = infer_tp * infer_pp
-        dp = self.world_size // infer_world_size
-        assert self.world_size % infer_world_size == 0, (
-            f"rollout world_size: {self.world_size} is not divisible by infer_world_size: {infer_world_size}"
+        self.infer_world_size = infer_tp * infer_pp
+        dp = self.world_size // self.infer_world_size
+        assert self.world_size % self.infer_world_size == 0, (
+            f"rollout world_size: {self.world_size} is not divisible by infer_world_size: {self.infer_world_size}"
         )
         rollout_device_mesh = init_device_mesh(
             get_device_name(), mesh_shape=(dp, infer_tp, infer_pp), mesh_dim_names=["dp", "infer_tp", "infer_pp"]
@@ -661,6 +680,11 @@ def init_model(self):
         if not self.config.actor.megatron.use_mbridge:
             self.weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype)
 
+        if self.config.rollout.enable_checkpoint_engine:
+            from checkpoint_engine.ps import ParameterServer
+
+            self.parameter_server = ParameterServer(auto_pg=False)
+
         get_torch_device().empty_cache()
         log_gpu_memory_usage("After init_model finish", logger=logger)
@@ -689,7 +713,12 @@ async def rollout_mode(self):
 
         if self.config.rollout.free_cache_engine:
             await self.rollout.resume(tags=["weights"])
-        await self.rollout.update_weights(per_tensor_param)
+
+        if self.config.rollout.enable_checkpoint_engine:
+            req_func = await self.rollout.checkpoint_engine_req_func(self.infer_world_size)
+            self.update_weights_by_checkpoint_engine(per_tensor_param, req_func)
+        else:
+            await self.rollout.update_weights(per_tensor_param)
         if self._is_offload_param:
             offload_megatron_model_to_cpu(self.actor.actor_module)
         aggressive_empty_cache(force_sync=True)
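The Megatron variant differs from the FSDP one only in staging each kept tensor to host memory with `non_blocking=True`. One observation, not a change in the PR: `torch.distributed.barrier()` orders ranks, not CUDA streams, so if the device-to-host copies run asynchronously (which requires a pinned destination), a conservative implementation would synchronize the stream before `register_checkpoint` reads the host tensors. A hypothetical pinned-memory variant making that explicit:

```python
import torch

# Hypothetical pinned-memory staging (not part of the PR): asynchronous
# D2H copies need a pinned destination plus an explicit sync before the
# host data can safely be read.
def stage_to_cpu(named: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    staged = {}
    for name, t in named.items():
        host = torch.empty(t.shape, dtype=t.dtype, pin_memory=True)
        host.copy_(t, non_blocking=True)
        staged[name] = host
    torch.cuda.synchronize()  # make sure every copy has landed
    return staged
```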
diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_server.py b/verl/workers/rollout/sglang_rollout/async_sglang_server.py
index e78700d9f7a..dd00280f696 100644
--- a/verl/workers/rollout/sglang_rollout/async_sglang_server.py
+++ b/verl/workers/rollout/sglang_rollout/async_sglang_server.py
@@ -77,8 +77,10 @@ def __init__(
         cuda_visible_devices: str,
     ):
         print(f"SGLang http server: {rollout_mode=}, {replica_rank=}, {node_rank=}, {nnodes=}, {cuda_visible_devices=}")
-        os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
-        assert torch.cuda.is_available(), "SGLang http server should run on GPU node"
+        os.environ["CUDA_VISIBLE_DEVICES" if torch.cuda.is_available() else "ASCEND_RT_VISIBLE_DEVICES"] = (
+            cuda_visible_devices
+        )
+        assert torch.cuda.is_available() or torch.npu.is_available(), "SGLang http server should run on GPU/NPU node"
 
         self.config: RolloutConfig = omega_conf_to_dataclass(config)
         self.model_config: HFModelConfig = omega_conf_to_dataclass(model_config, dataclass_type=HFModelConfig)
@@ -337,7 +339,12 @@ async def launch_servers(self):
                     node_id=node_id,
                     soft=False,
                 ),
-                runtime_env={"env_vars": {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"}},
+                runtime_env={
+                    "env_vars": {
+                        "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1",
+                        "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1",
+                    }
+                },
                 name=name,
             ).remote(
                 config=self.config,
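The Ascend changes follow Ray's device-visibility conventions: Ray normally rewrites each actor's visibility variable to match the devices it scheduled, and the `RAY_EXPERIMENTAL_NOSET_*` variables opt out of that rewrite so the server process can manage visibility itself, for NPUs as was already done for GPUs. A small sketch of the backend-dependent selection used above (illustrative only; note that `torch.npu` exists only when the `torch_npu` extension is installed, so the attribute is guarded here):

```python
import os

import torch

# Pick the device-visibility variable for the current backend: CUDA
# nodes use CUDA_VISIBLE_DEVICES, Ascend NPU nodes use
# ASCEND_RT_VISIBLE_DEVICES (torch.npu is provided by torch_npu).
def set_visible_devices(devices: str) -> None:
    if torch.cuda.is_available():
        os.environ["CUDA_VISIBLE_DEVICES"] = devices
    elif hasattr(torch, "npu") and torch.npu.is_available():
        os.environ["ASCEND_RT_VISIBLE_DEVICES"] = devices
    else:
        raise RuntimeError("expected a CUDA GPU or an Ascend NPU")
```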
diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py
index 63d3b0c36af..34fe5b274cb 100644
--- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py
+++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py
@@ -18,11 +18,13 @@
 import logging
 import multiprocessing as mp
 import os
+from collections.abc import Callable
 from typing import Generator
 
 import ray
 import sglang.srt.entrypoints.engine
 import torch
+from sglang.srt.checkpoint_engine.update import req_inference
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     assert_pkg_version,
@@ -191,3 +193,14 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None
 
         if self.device_mesh["infer_tp"].get_local_rank() == 0:
             await self._engine.flush_cache()
+
+    async def checkpoint_engine_req_func(self, inference_parallel_size: int) -> Callable[[list[tuple[str, str]]], None]:
+        if self.device_mesh["infer_tp"].get_local_rank() == 0:
+            await self._init_server_adapter()
+            endpoint = f"http://{self._engine.server_args.host}:{self._engine.server_args.port}"
+        else:
+            endpoint = ""
+
+        req_func = req_inference(endpoint=endpoint, inference_parallel_size=inference_parallel_size)
+
+        return req_func
diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py
index e0c292b8397..8c2fa30b11f 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py
@@ -290,6 +290,9 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
             "override_generation_config": json.dumps(override_generation_config),
             "quantization": quantization,
             "hf_overrides": {"quantization_config": fp8_block_quant_kwargs} if quantization == "fp8" else None,
+            "worker_extension_cls": "checkpoint_engine.worker.VllmColocateWorkerExtension"
+            if self.config.enable_checkpoint_engine
+            else None,
             **engine_kwargs,
         }
@@ -690,6 +693,9 @@ async def launch_servers(self):
                     soft=False,
                 ),
                 name=name,
+                runtime_env={"env_vars": {"VLLM_SERVER_DEV_MODE": "1"}}
+                if self.config.enable_checkpoint_engine
+                else None,
             ).remote(
                 config=self.config,
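Taken together, the two halves of the sync meet through `req_func`: the rollout builds it (here via SGLang's `req_inference`, bound to the engine's HTTP endpoint on infer_tp local rank 0 and an empty endpoint elsewhere), and the trainer worker hands it to the ParameterServer, which calls it while distributing the registered shards. A condensed view of the handshake as wired in `rollout_mode` above, with `worker` and `rollout` as stand-ins for the surrounding objects:

```python
# Condensed handshake between trainer worker and rollout backend:
# 1) build a request function bound to the inference engine;
# 2) stream the weight shards through the ParameterServer, which
#    invokes req_func so the engine pulls the new checkpoint.
async def sync_weights_via_checkpoint_engine(worker, rollout, per_tensor_param):
    req_func = await rollout.checkpoint_engine_req_func(worker.infer_world_size)
    worker.update_weights_by_checkpoint_engine(per_tensor_param, req_func)
```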
diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout.py b/verl/workers/rollout/vllm_rollout/vllm_rollout.py
index 42a1cd96885..8161880a83d 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_rollout.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_rollout.py
@@ -30,6 +30,7 @@
 import getpass
 import logging
 import os
+from collections.abc import Callable
 from dataclasses import asdict
 from types import MethodType
 from typing import Any, Generator
@@ -133,6 +134,13 @@ def __init__(
         else:
             self.sleep_level = VLLM_SLEEP_LEVEL
 
+        rank = int(os.environ["RANK"])
+        local_world_size = int(os.environ["RAY_LOCAL_WORLD_SIZE"])
+        rollout_world_size = self.config.tensor_model_parallel_size * self.config.data_parallel_size
+        self.replica_rank = rank // rollout_world_size
+        self.rollout_rank = rank % rollout_world_size
+        self.node_rank = self.rollout_rank // local_world_size
+
     def _init_zeromq(self) -> str:
         tensor_parallel_size = self.config.tensor_model_parallel_size
 
@@ -262,6 +270,9 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None
             logger.info("Loading standard weights (non-FP8, async)")
             model.load_weights(weights)
 
+    async def checkpoint_engine_req_func(self, inference_parallel_size: int) -> Callable[[list[tuple[str, str]]], None]:
+        raise NotImplementedError
+
     def generate_sequences(self, prompts: DataProto) -> DataProto:
         """Batch generate sequences in sync mode."""
         raise NotImplementedError
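The whole path is opt-in: nothing changes unless the new flag is set, e.g. with a Hydra-style override such as `actor_rollout_ref.rollout.enable_checkpoint_engine=True`. A minimal sketch at the dataclass level (assuming the remaining `RolloutConfig` fields keep their defaults):

```python
from verl.workers.config.rollout import RolloutConfig

# Off by default. When enabled, rollout_mode() takes the
# checkpoint-engine path, which requires a backend that implements
# checkpoint_engine_req_func (SGLang above; the vLLM rollout still
# raises NotImplementedError).
config = RolloutConfig(enable_checkpoint_engine=True)
assert config.enable_checkpoint_engine
```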