4 changes: 4 additions & 0 deletions recipe/fully_async_policy/README.md
@@ -108,6 +108,7 @@ https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_a
| `async_training.checkpoint_engine.enable` | Whether to use checkpoint_engine for acceleration, default `True` |
| `async_training.checkpoint_engine.overlap_broadcast_and_consume` | When using checkpoint_engine, whether to overlap broadcast and load_weights, default `False` |
| `async_training.checkpoint_engine.device_buffer_size_M` | When using checkpoint_engine, the user-specified bucket size (MB), default `4096` |
| `async_training.checkpoint_engine.bypass_cpu` | Whether to bypass CPU memory when synchronizing parameters, default `False` |

**Further Explanation:**

@@ -196,6 +197,9 @@ https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_a
* When `overlap_broadcast_and_consume` is disabled, the temporary additional device memory overhead is `2 * bucket_size` on trainer ranks and `1 * bucket_size` on rollout ranks.

* `async_training.checkpoint_engine.bypass_cpu`

When enabled, actor parameters are bucketed and broadcast directly from device memory instead of being staged in pinned CPU memory first. This skips the device-to-host and host-to-device copies during weight synchronization, at the cost of keeping the bucket buffers resident in device memory on the trainer ranks (see the sketch below).
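For a rough illustration only (not the recipe's actual API: `stage_bucket`, the bare `"cuda"` device string, and the explicit synchronize are assumptions of this sketch), the two staging paths for a single flattened bucket look like this:

```python
import torch


def stage_bucket(flat_bucket: torch.Tensor, bypass_cpu: bool, device: str = "cuda") -> torch.Tensor:
    """Return the device-resident buffer that would be handed to the collective broadcast."""
    if bypass_cpu:
        # bypass_cpu=True: keep the bucket in (or move it to) device memory and broadcast it directly
        return flat_bucket.to(device)
    # bypass_cpu=False: stage the bucket in pinned host memory, then copy host -> device before broadcasting
    pinned = torch.empty(flat_bucket.shape, dtype=flat_bucket.dtype, pin_memory=True)
    pinned.copy_(flat_bucket)
    staging = torch.empty(flat_bucket.shape, dtype=flat_bucket.dtype, device=device)
    staging.copy_(pinned, non_blocking=True)
    torch.cuda.synchronize()  # wait for the async host-to-device copy before the buffer is used
    return staging


if __name__ == "__main__":
    if torch.cuda.is_available():
        bucket = torch.randn(1 << 20)
        print(stage_bucket(bucket, bypass_cpu=False).device, stage_bucket(bucket, bypass_cpu=True).device)
```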

### Supported Modes

5 changes: 5 additions & 0 deletions recipe/fully_async_policy/README_zh.md
@@ -85,6 +85,7 @@ https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_a
| `async_training.checkpoint_engine.enable` | Whether to enable checkpoint_engine-based acceleration, default True |
| `async_training.checkpoint_engine.overlap_broadcast_and_consume` | When checkpoint_engine is enabled, whether to pipeline broadcast and weight loading during parameter synchronization, default False |
| `async_training.checkpoint_engine.device_buffer_size_M` | When checkpoint_engine is enabled, the size (MB) of the assembled buckets, default 4096 |
| `async_training.checkpoint_engine.bypass_cpu` | Whether to skip the intermediate CPU memory during parameter synchronization, default False |

**Further Explanation:**

@@ -157,6 +158,10 @@ https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_a
* When `overlap_broadcast_and_consume` is enabled, the temporary additional device memory overhead is `3 * bucket_size` on trainer ranks and `2 * bucket_size` on rollout ranks.
* When `overlap_broadcast_and_consume` is disabled, the temporary additional device memory overhead is `2 * bucket_size` on trainer ranks and `1 * bucket_size` on rollout ranks.

* `async_training.checkpoint_engine.bypass_cpu`

When enabled, the intermediate CPU (pinned host) staging memory is skipped during parameter synchronization and parameters are broadcast directly from device memory.

### Supported Modes

1. on policy pipeline:
51 changes: 36 additions & 15 deletions recipe/fully_async_policy/checkpoint_engine.py
@@ -31,10 +31,7 @@
from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
from ray.util.collective import collective

- from verl.utils.device import (
- get_device_name,
- get_torch_device,
- )
+ from verl.utils.device import get_device_name, get_torch_device

if TYPE_CHECKING:
from typing import TypeVar
@@ -263,7 +260,12 @@ class CheckpointEngine:
"""

def __init__(
- self, current_rank: int, actor_ranks: list[int], rollout_ranks: list[int], device_buffer_size_M: int
+ self,
+ current_rank: int,
+ actor_ranks: list[int],
+ rollout_ranks: list[int],
+ device_buffer_size_M: int,
+ use_cpu_buffer: bool = True,
) -> None:
self.current_rank = current_rank
self.actor_ranks = actor_ranks
@@ -273,6 +275,7 @@ def __init__(
self.global_buckets: dict[int, list[MemoryBufferMeta]] = None
# min device_buffer_size for h2d and broadcast
self.device_buffer_size_M = device_buffer_size_M
self.use_cpu_buffer = use_cpu_buffer

# ipc config for broadcast in pipeline mode
self._zmq_ctx = zmq.Context()
@@ -342,6 +345,11 @@ def register_pin_memory(idx: int, size: int) -> tuple[int, torch.Tensor]:
buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
return idx, buffer

def register_gpu_memory(idx: int, size: int) -> tuple[int, torch.Tensor]:
"""Allocate gpu memory for a bucket."""
buffer = torch.empty(size, dtype=torch.uint8, device=get_torch_device().current_device())
return idx, buffer

def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
"""Copy a tensor into a pinned memory buffer."""
buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)
@@ -355,9 +363,16 @@ def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):

# Use thread pool to accelerate organize parameters into buckets
with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
- futures = [
- executor.submit(register_pin_memory, idx, bucket.size) for idx, bucket in enumerate(local_buckets)
- ]
+ if self.use_cpu_buffer:
+ futures = [
+ executor.submit(register_pin_memory, idx, bucket.size)
+ for idx, bucket in enumerate(local_buckets)
+ ]
+ else:
+ futures = [
+ executor.submit(register_gpu_memory, idx, bucket.size)
+ for idx, bucket in enumerate(local_buckets)
+ ]
new_futures = []
for future in concurrent.futures.as_completed(futures):
idx, buffer = future.result()
@@ -424,11 +439,14 @@ def update_checkpoint(self, inference_model, group_name: str, overlap_broadcast_
for broadcasting and loading weights.
"""
try:
- h2d_buffer: torch.Tensor | None = (
- None
- if self.current_rank in self.rollout_ranks
- else torch.empty(self.bucket_size, dtype=torch.uint8, device=get_torch_device().current_device())
- )
+ if self.use_cpu_buffer:
+ h2d_buffer: torch.Tensor | None = (
+ None
+ if self.current_rank in self.rollout_ranks
+ else torch.empty(self.bucket_size, dtype=torch.uint8, device=get_torch_device().current_device())
+ )
+ else:
+ h2d_buffer = None
# for pipeline mode, we need to allocate 2x buffer size
broadcast_load_buffer = torch.empty(
self.bucket_size * (2 if overlap_broadcast_and_consume else 1),
Expand Down Expand Up @@ -482,7 +500,7 @@ def update_weights_from_ipc_(socket_path):

for i in range(max_h2d_iter):
# Step 1: Each actor rank copies its parameter tensors into device memory
- if i < len(self.memory_buffers):
+ if self.use_cpu_buffer and i < len(self.memory_buffers):
h2d_buffer[: local_buckets[i].size].data.copy_(self.memory_buffers[i].buffer)

# Step 2: Broadcast the device data in turn
@@ -495,7 +513,10 @@
start = gidx % 2 * self.bucket_size if overlap_broadcast_and_consume else 0
buffer_b: torch.Tensor = broadcast_load_buffer[start : start + bucket.size]
if broadcast_rank == self.current_rank:
- buffer_b.data.copy_(h2d_buffer[: bucket.size])
+ if self.use_cpu_buffer:
+ buffer_b.data.copy_(h2d_buffer[: bucket.size])
+ else:
+ buffer_b.data.copy_(self.memory_buffers[i].buffer)

# Broadcast the buffer to all ranks
collective.broadcast(buffer_b, src_rank=broadcast_rank, group_name=group_name)
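For reference, the flat uint8 bucket idea that `register_tensor` implements can be sketched in isolation. The helper names below are illustrative, not part of `CheckpointEngine`, and the unpacking side is only meant to show what a receiving rank conceptually does with a broadcast bucket:

```python
import math

import torch


def pack_bucket(named_tensors: dict[str, torch.Tensor]):
    """Pack named tensors into one flat uint8 buffer; return the buffer plus per-tensor (offset, shape, dtype)."""
    total = sum(t.nbytes for t in named_tensors.values())
    buffer = torch.empty(total, dtype=torch.uint8)
    meta, offset = {}, 0
    for name, t in named_tensors.items():
        # reinterpret the tensor's bytes and copy them into the shared byte buffer
        buffer[offset : offset + t.nbytes] = t.contiguous().view(-1).view(dtype=torch.uint8)
        meta[name] = (offset, tuple(t.shape), t.dtype)
        offset += t.nbytes
    return buffer, meta


def unpack_bucket(buffer: torch.Tensor, meta) -> dict[str, torch.Tensor]:
    """Rebuild the tensors from the flat buffer, e.g. after it has been broadcast to another rank."""
    out = {}
    for name, (offset, shape, dtype) in meta.items():
        nbytes = math.prod(shape) * torch.empty((), dtype=dtype).element_size()
        dst = torch.empty(shape, dtype=dtype)
        dst.view(-1).view(dtype=torch.uint8).copy_(buffer[offset : offset + nbytes])
        out[name] = dst
    return out


if __name__ == "__main__":
    params = {"w": torch.randn(4, 3), "b": torch.arange(5)}
    buf, meta = pack_bucket(params)
    restored = unpack_bucket(buf, meta)
    assert all(torch.equal(params[k], restored[k]) for k in params)
```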
(another config file; name not shown in this diff view)
@@ -38,6 +38,9 @@ async_training:
# Enable the pipeline for broadcasting and updating parameters, but it requires more device memory
overlap_broadcast_and_consume: False

# Whether to bypass CPU memory when synchronizing parameters
bypass_cpu: False

# Rollout config
rollout:

3 changes: 3 additions & 0 deletions recipe/fully_async_policy/config/fully_async_ppo_trainer.yaml
@@ -38,6 +38,9 @@ async_training:
# Enable the pipeline for broadcasting and updating parameters, but it requires more device memory
overlap_broadcast_and_consume: False

# Whether to bypass CPU memory when synchronizing parameters
bypass_cpu: False

# Rollout config
rollout:

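Because `bypass_cpu` is a newly introduced key, the workers below read it with a default via `.get("bypass_cpu", False)`, so configs written before this change still load. A minimal sketch of that access pattern, assuming the `checkpoint_engine` block is an OmegaConf `DictConfig` (an assumption of this sketch; `omegaconf` must be installed):

```python
from omegaconf import OmegaConf

# Hypothetical stand-in for async_training.checkpoint_engine as an older config might define it,
# i.e. without the new bypass_cpu key.
cfg = OmegaConf.create(
    {
        "enable": True,
        "overlap_broadcast_and_consume": False,
        "device_buffer_size_M": 4096,
    }
)

# Reading with a default keeps older configs valid; new configs can set bypass_cpu: True explicitly.
bypass_cpu = cfg.get("bypass_cpu", False)
use_cpu_buffer = not bypass_cpu
print(bypass_cpu, use_cpu_buffer)  # False True
```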
15 changes: 12 additions & 3 deletions recipe/fully_async_policy/fsdp_workers.py
@@ -75,7 +75,11 @@ def init_checkpoint_engine(self, rank_offset: int, actor_num: int, rollout_num:
assert rank_offset == 0 or rank_offset == actor_num

self.checkpoint_engine = CheckpointEngine(
- current_rank, actor_ranks, rollout_ranks, self.config.checkpoint_engine.device_buffer_size_M
+ current_rank,
+ actor_ranks,
+ rollout_ranks,
+ self.config.checkpoint_engine.device_buffer_size_M,
+ use_cpu_buffer=not self.config.checkpoint_engine.get("bypass_cpu", False),
)

def _get_actor_params(self):
@@ -120,15 +124,20 @@ def cache_actor_weights_to_cpu(self):
params = self._get_actor_params()
local_rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
bypass_cpu = self.config.checkpoint_engine.get("bypass_cpu", False)

for tensor_idx, (key, _, _) in enumerate(self._weights_info):
origin_data = params[key]
if hasattr(origin_data, "full_tensor"):
origin_data = origin_data.full_tensor()

if tensor_idx % world_size == local_rank:
- self.cpu_named_params[key] = origin_data.to("cpu", non_blocking=True)
- get_torch_device().synchronize()
+ if bypass_cpu:
+ self.cpu_named_params[key] = origin_data
+ else:
+ self.cpu_named_params[key] = origin_data.to("cpu", non_blocking=True)
+ if not bypass_cpu:
+ get_torch_device().synchronize()

@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
def sync_rollout_weights_by_checkpoint(self, sync_group_name="actor_rollout"):
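To make the caching step above concrete, here is a standalone sketch of the round-robin ownership plus the optional CPU offload that `bypass_cpu` removes; `cache_owned_weights` is a hypothetical helper, and plain `torch.cuda` stands in for `get_torch_device()`:

```python
import torch


def cache_owned_weights(
    params: dict[str, torch.Tensor], local_rank: int, world_size: int, bypass_cpu: bool
) -> dict[str, torch.Tensor]:
    """Each rank keeps only the tensors it owns (round-robin by index); optionally offload them to CPU."""
    cached = {}
    for idx, (name, tensor) in enumerate(params.items()):
        if idx % world_size != local_rank:
            continue  # not owned by this rank
        # bypass_cpu=True keeps the tensor on its current device; otherwise queue an async D2H copy
        cached[name] = tensor if bypass_cpu else tensor.to("cpu", non_blocking=True)
    if not bypass_cpu and torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for all queued device-to-host copies to finish
    return cached


if __name__ == "__main__":
    weights = {f"layer{i}.weight": torch.randn(8, 8) for i in range(4)}
    # with 2 ranks, rank 1 owns tensors 1 and 3
    print(sorted(cache_owned_weights(weights, local_rank=1, world_size=2, bypass_cpu=True)))
```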
10 changes: 8 additions & 2 deletions recipe/fully_async_policy/megatron_worker.py
@@ -129,11 +129,17 @@ def cache_actor_weights_to_cpu(self):
params_generator = self._get_actor_params_generator()
local_rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
bypass_cpu = self.config.checkpoint_engine.get("bypass_cpu", False)

print(f"cache_actor_weights_to_cpu, local_rank:{local_rank}, world_size:{world_size}")
for tensor_idx, (key, tensor) in enumerate(params_generator):
if tensor_idx % world_size == local_rank:
- self.cpu_named_params[key] = tensor.to("cpu", non_blocking=True)
- get_torch_device().synchronize()
+ if bypass_cpu:
+ self.cpu_named_params[key] = tensor
+ else:
+ self.cpu_named_params[key] = tensor.to("cpu", non_blocking=True)
+ if not bypass_cpu:
+ get_torch_device().synchronize()

@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
def sync_rollout_weights_by_checkpoint(self, sync_group_name="actor_rollout"):