1 change: 1 addition & 0 deletions verl/trainer/config/_generated_ppo_megatron_trainer.yaml
@@ -272,6 +272,7 @@ actor_rollout_ref:
skip_dump_dir: /tmp/rollout_dump
skip_tokenizer_init: true
enable_rollout_routing_replay: false
enable_checkpoint_engine: false
profiler:
_target_: verl.utils.profiler.ProfilerConfig
tool: ${oc.select:global_profiler.tool,null}
1 change: 1 addition & 0 deletions verl/trainer/config/_generated_ppo_trainer.yaml
@@ -261,6 +261,7 @@ actor_rollout_ref:
skip_dump_dir: /tmp/rollout_dump
skip_tokenizer_init: true
enable_rollout_routing_replay: false
enable_checkpoint_engine: false
profiler:
_target_: verl.utils.profiler.ProfilerConfig
tool: ${oc.select:global_profiler.tool,null}
3 changes: 3 additions & 0 deletions verl/trainer/config/rollout/rollout.yaml
@@ -281,6 +281,9 @@ skip_tokenizer_init: True
# When enabled (True), the rollout will record the routing decisions.
enable_rollout_routing_replay: False

# Whether to use checkpoint_engine for weight updates.
# When enabled (True), parameters are synced between the trainer and the rollout engine through checkpoint_engine.
enable_checkpoint_engine: False

# profile the rollout model in `generate_sequence`
profiler:
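For reference, turning the new behavior on is a one-line config change; a minimal sketch, assuming the actor_rollout_ref.rollout path shown in the generated trainer configs above:

    actor_rollout_ref:
      rollout:
        # route trainer-to-rollout weight sync through checkpoint_engine instead of the direct update_weights path
        enable_checkpoint_engine: True
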
2 changes: 2 additions & 0 deletions verl/workers/config/rollout.py
@@ -207,6 +207,8 @@ class RolloutConfig(BaseConfig):

enable_rollout_routing_replay: bool = False

enable_checkpoint_engine: bool = False

def __post_init__(self):
"""Validate the rollout config"""
if self.expert_parallel_size > 1:
54 changes: 46 additions & 8 deletions verl/workers/fsdp_workers.py
@@ -20,8 +20,9 @@
import logging
import os
import warnings
from collections.abc import Callable
from dataclasses import asdict
from typing import Any, Optional
from typing import Any, Generator, Optional

import numpy as np
import psutil
@@ -577,6 +578,24 @@ def _build_model_optimizer(

return actor_module_fsdp, actor_optimizer, actor_lr_scheduler, actor_model_config

def update_weights_by_checkpoint_engine(
self,
weights: Generator[tuple[str, torch.Tensor], None, None],
req_func: Callable[[list[tuple[str, str]]], None],
):
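"""Round-robin shard the streamed weights across trainer ranks, register the local shard
with the checkpoint-engine ParameterServer, and push the checkpoint to the rollout engines via req_func."""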
named_tensors = {}
for tensor_idx, (name, tensor) in enumerate(weights):
if tensor_idx % self.world_size == self.rank:
named_tensors[name] = tensor

checkpoint_name = "checkpoint_engine"
self.parameter_server.register_checkpoint(checkpoint_name, named_tensors=named_tensors)
named_tensors = {}
dist.barrier()
self.parameter_server.gather_metas(checkpoint_name)
self.parameter_server.update(checkpoint_name, req_func)
self.parameter_server.unregister_checkpoint(checkpoint_name)

def _build_rollout(self, trust_remote_code=False):
from torch.distributed.device_mesh import init_device_mesh

@@ -588,10 +607,10 @@ def _build_rollout(self, trust_remote_code=False):
# 2. build rollout device mesh
infer_tp = self.config.rollout.tensor_model_parallel_size * self.config.rollout.data_parallel_size
infer_pp = self.config.rollout.pipeline_model_parallel_size
infer_world_size = infer_tp * infer_pp
dp = self.world_size // infer_world_size
assert self.world_size % infer_world_size == 0, (
f"rollout world_size: {self.world_size} is not divisible by infer_world_size: {infer_world_size}"
self.infer_world_size = infer_tp * infer_pp
dp = self.world_size // self.infer_world_size
assert self.world_size % self.infer_world_size == 0, (
f"rollout world_size: {self.world_size} is not divisible by infer_world_size: {self.infer_world_size}"
)
rollout_device_mesh = init_device_mesh(
device_name, mesh_shape=(dp, infer_tp, infer_pp), mesh_dim_names=["dp", "infer_tp", "infer_pp"]
@@ -700,10 +719,14 @@ async def rollout_mode(self):

set_expandable_segments(False)

if self.config.rollout.enable_checkpoint_engine:
device = "cpu"
else:
device = get_device_id() # used when fsdp2 set cpu_offload_policy

if peft_config is not None and self.base_sync_done:
per_tensor_param = params.items() if isinstance(params, dict) else params # Fixed: handle dict case
else:
device = get_device_id() # used when fsdp2 set cpu_offload_policy
per_tensor_param = (
(name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param)
for name, param in params.items()
@@ -718,10 +741,20 @@ async def rollout_mode(self):
(name, param.to(device, non_blocking=True).full_tensor() if isinstance(param, DTensor) else param)
for name, param in base_model_params.items()
)
await self.rollout.update_weights(per_tensor_base_params, base_sync_done=False)
if self.config.rollout.enable_checkpoint_engine:
req_func = await self.rollout.checkpoint_engine_req_func(self.infer_world_size)
self.update_weights_by_checkpoint_engine(per_tensor_base_params, req_func)
else:
await self.rollout.update_weights(per_tensor_base_params, base_sync_done=False)
del base_model_params, per_tensor_base_params

await self.rollout.update_weights(per_tensor_param, peft_config=peft_config, base_sync_done=self.base_sync_done)
if self.config.rollout.enable_checkpoint_engine:
req_func = await self.rollout.checkpoint_engine_req_func(self.infer_world_size)
self.update_weights_by_checkpoint_engine(per_tensor_param, req_func)
else:
await self.rollout.update_weights(
per_tensor_param, peft_config=peft_config, base_sync_done=self.base_sync_done
)
log_gpu_memory_usage("After update_weights", logger=logger)
del params, per_tensor_param
aggressive_empty_cache(force_sync=True)
@@ -863,6 +896,11 @@ def init_model(self):
checkpoint_config=checkpoint_contents,
)

if self.config.rollout.enable_checkpoint_engine:
from checkpoint_engine.ps import ParameterServer

self.parameter_server = ParameterServer(auto_pg=False)

@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="actor"))
@DistProfiler.annotate(color="red", role="actor_update")
def update_actor(self, data: DataProto):
41 changes: 35 additions & 6 deletions verl/workers/megatron_workers.py
@@ -19,7 +19,8 @@
import logging
import os
import time
from typing import Any, Optional
from collections.abc import Callable
from typing import Any, Generator, Optional

import psutil
import torch
@@ -483,6 +484,24 @@ def _build_model_optimizer(

return actor_module, actor_optimizer, actor_optimizer_scheduler, self.hf_config, optim_config

def update_weights_by_checkpoint_engine(
self,
weights: Generator[tuple[str, torch.Tensor], None, None],
req_func: Callable[[list[tuple[str, str]]], None],
):
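"""Round-robin shard the streamed weights across trainer ranks (copying each local shard to CPU),
register it with the checkpoint-engine ParameterServer, and push the checkpoint to the rollout engines via req_func."""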
named_tensors = {}
for tensor_idx, (name, tensor) in enumerate(weights):
if tensor_idx % self.world_size == self.rank:
named_tensors[name] = tensor.to("cpu", non_blocking=True)

checkpoint_name = "checkpoint_engine"
self.parameter_server.register_checkpoint(checkpoint_name, named_tensors=named_tensors)
named_tensors = {}
torch.distributed.barrier()
self.parameter_server.gather_metas(checkpoint_name)
self.parameter_server.update(checkpoint_name, req_func)
self.parameter_server.unregister_checkpoint(checkpoint_name)

def _build_rollout(self, trust_remote_code=False):
from torch.distributed.device_mesh import init_device_mesh

@@ -500,10 +519,10 @@ def _build_rollout(self, trust_remote_code=False):
# 2. build rollout device mesh
infer_tp = self.config.rollout.tensor_model_parallel_size * self.config.rollout.data_parallel_size
infer_pp = self.config.rollout.pipeline_model_parallel_size
infer_world_size = infer_tp * infer_pp
dp = self.world_size // infer_world_size
assert self.world_size % infer_world_size == 0, (
f"rollout world_size: {self.world_size} is not divisible by infer_world_size: {infer_world_size}"
self.infer_world_size = infer_tp * infer_pp
dp = self.world_size // self.infer_world_size
assert self.world_size % self.infer_world_size == 0, (
f"rollout world_size: {self.world_size} is not divisible by infer_world_size: {self.infer_world_size}"
)
rollout_device_mesh = init_device_mesh(
get_device_name(), mesh_shape=(dp, infer_tp, infer_pp), mesh_dim_names=["dp", "infer_tp", "infer_pp"]
@@ -661,6 +680,11 @@ def init_model(self):
if not self.config.actor.megatron.use_mbridge:
self.weight_converter = get_mcore_weight_converter(self.actor_model_config, self.dtype)

if self.config.rollout.enable_checkpoint_engine:
from checkpoint_engine.ps import ParameterServer

self.parameter_server = ParameterServer(auto_pg=False)

get_torch_device().empty_cache()
log_gpu_memory_usage("After init_model finish", logger=logger)

@@ -689,7 +713,12 @@ async def rollout_mode(self):

if self.config.rollout.free_cache_engine:
await self.rollout.resume(tags=["weights"])
await self.rollout.update_weights(per_tensor_param)

if self.config.rollout.enable_checkpoint_engine:
req_func = await self.rollout.checkpoint_engine_req_func(self.infer_world_size)
self.update_weights_by_checkpoint_engine(per_tensor_param, req_func)
else:
await self.rollout.update_weights(per_tensor_param)
if self._is_offload_param:
offload_megatron_model_to_cpu(self.actor.actor_module)
aggressive_empty_cache(force_sync=True)
13 changes: 10 additions & 3 deletions verl/workers/rollout/sglang_rollout/async_sglang_server.py
@@ -77,8 +77,10 @@ def __init__(
cuda_visible_devices: str,
):
print(f"SGLang http server: {rollout_mode=}, {replica_rank=}, {node_rank=}, {nnodes=}, {cuda_visible_devices=}")
os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
assert torch.cuda.is_available(), "SGLang http server should run on GPU node"
os.environ["CUDA_VISIBLE_DEVICES" if torch.cuda.is_available() else "ASCEND_RT_VISIBLE_DEVICES"] = (
cuda_visible_devices
)
assert torch.cuda.is_available() or torch.npu.is_available(), "SGLang http server should run on GPU/NPU node"

self.config: RolloutConfig = omega_conf_to_dataclass(config)
self.model_config: HFModelConfig = omega_conf_to_dataclass(model_config, dataclass_type=HFModelConfig)
@@ -337,7 +339,12 @@ async def launch_servers(self):
node_id=node_id,
soft=False,
),
runtime_env={"env_vars": {"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1"}},
runtime_env={
"env_vars": {
"RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1",
"RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1",
}
},
name=name,
).remote(
config=self.config,
13 changes: 13 additions & 0 deletions verl/workers/rollout/sglang_rollout/sglang_rollout.py
@@ -18,11 +18,13 @@
import logging
import multiprocessing as mp
import os
from collections.abc import Callable
from typing import Generator

import ray
import sglang.srt.entrypoints.engine
import torch
from sglang.srt.checkpoint_engine.update import req_inference
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
assert_pkg_version,
@@ -191,3 +193,14 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None

if self.device_mesh["infer_tp"].get_local_rank() == 0:
await self._engine.flush_cache()

async def checkpoint_engine_req_func(self, inference_parallel_size: int) -> Callable[[list[tuple[str, str]]], None]:
if self.device_mesh["infer_tp"].get_local_rank() == 0:
await self._init_server_adapter()
endpoint = f"http://{self._engine.server_args.host}:{self._engine.server_args.port}"
else:
endpoint = ""

req_func = req_inference(endpoint=endpoint, inference_parallel_size=inference_parallel_size)

return req_func
6 changes: 6 additions & 0 deletions verl/workers/rollout/vllm_rollout/vllm_async_server.py
@@ -290,6 +290,9 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
"override_generation_config": json.dumps(override_generation_config),
"quantization": quantization,
"hf_overrides": {"quantization_config": fp8_block_quant_kwargs} if quantization == "fp8" else None,
"worker_extension_cls": "checkpoint_engine.worker.VllmColocateWorkerExtension"
if self.config.enable_checkpoint_engine
else None,
**engine_kwargs,
}

@@ -690,6 +693,9 @@ async def launch_servers(self):
soft=False,
),
name=name,
runtime_env={"env_vars": {"VLLM_SERVER_DEV_MODE": "1"}}
if self.config.enable_checkpoint_engine
else None,
).remote(
config=self.config,
model_config=self.model_config,
11 changes: 11 additions & 0 deletions verl/workers/rollout/vllm_rollout/vllm_rollout.py
@@ -30,6 +30,7 @@
import getpass
import logging
import os
from collections.abc import Callable
from dataclasses import asdict
from types import MethodType
from typing import Any, Generator
@@ -133,6 +134,13 @@ def __init__(
else:
self.sleep_level = VLLM_SLEEP_LEVEL

rank = int(os.environ["RANK"])
local_world_size = int(os.environ["RAY_LOCAL_WORLD_SIZE"])
Comment on lines +137 to +138 (Contributor review, critical):
Accessing environment variables with os.environ["KEY"] is unsafe: it raises a KeyError if the variable is not set, crashing the worker. Use os.getenv("KEY", default_value) instead; otherwise a misconfigured environment leads to a runtime crash.
Suggested change:
rank = int(os.environ["RANK"])
local_world_size = int(os.environ["RAY_LOCAL_WORLD_SIZE"])
rank = int(os.getenv("RANK", "0"))
local_world_size = int(os.getenv("RAY_LOCAL_WORLD_SIZE", "1"))

rollout_world_size = self.config.tensor_model_parallel_size * self.config.data_parallel_size
self.replica_rank = rank // rollout_world_size
self.rollout_rank = rank % rollout_world_size
self.node_rank = self.rollout_rank // local_world_size

def _init_zeromq(self) -> str:
tensor_parallel_size = self.config.tensor_model_parallel_size

@@ -262,6 +270,9 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None
logger.info("Loading standard weights (non-FP8, async)")
model.load_weights(weights)

async def checkpoint_engine_req_func(self, inference_parallel_size: int) -> Callable[[list[tuple[str, str]]], None]:
raise NotImplementedError

def generate_sequences(self, prompts: DataProto) -> DataProto:
"""Batch generate sequences in sync mode."""
raise NotImplementedError