diff --git a/recipe/fully_async_policy/README.md b/recipe/fully_async_policy/README.md
index 8f460670c6f..d73708f9e5e 100644
--- a/recipe/fully_async_policy/README.md
+++ b/recipe/fully_async_policy/README.md
@@ -108,6 +108,7 @@ https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_a
 | `async_training.checkpoint_engine.enable`| Whether to use checkpoint_engine for accelerating, default `True`|
 | `async_training.checkpoint_engine.overlap_broadcast_and_consume` | When use checkpoint_engine, whether to overlap broadcast and load_weights, default `False`|
 | `async_training.checkpoint_engine.device_buffer_size_M` | When use checkpoint_engine, the user-specific bucket size (MB), default `4096`|
+| `async_training.checkpoint_engine.bypass_cpu` | Whether to bypass CPU memory when synchronizing parameters, default `False`|

 **Further Explanation:**

@@ -196,6 +197,9 @@ https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_a

 * When disable `overlap_broadcast_and_consume`, the additional device memory overhead of trainer rank is `2 * bucket_size`and rollout rank is `1 * bucket_size`。

+* `async_training.checkpoint_engine.bypass_cpu`
+
+  When enabled, parameters are staged in device memory instead of pinned CPU memory during synchronization, skipping the host-to-device copy at the cost of extra device memory.

 ### Supported Modes

diff --git a/recipe/fully_async_policy/README_zh.md b/recipe/fully_async_policy/README_zh.md
index 71fb68f3ec6..d07c82d3144 100644
--- a/recipe/fully_async_policy/README_zh.md
+++ b/recipe/fully_async_policy/README_zh.md
@@ -85,6 +85,7 @@ https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_a
 | `async_training.checkpoint_engine.enable`| 是否开启checkpoint_engine模式的加速,默认值True |
 | `async_training.checkpoint_engine.overlap_broadcast_and_consume` | 启动checkpoint_engine时,是否在参数同步时在broadcast和加载之间使用流水,默认值False|
 | `async_training.checkpoint_engine.device_buffer_size_M` | 启动checkpoint_engine时,组装的bucket的大小(MB),默认为4096 |
+| `async_training.checkpoint_engine.bypass_cpu` | 是否在参数同步时跳过作为中转的cpu内存,默认值False|

 **进一步的解释:**

@@ -157,6 +158,10 @@ https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_a
 * 在开启`overlap_broadcast_and_consume`时,trainer节点的临时额外显存开销为 `3 * bucket_size`, rollout节点的临时额外显存开销为`2 * bucket_size`。
 * 在关闭`overlap_broadcast_and_consume`时,trainer节点的临时额外显存开销为 `2 * bucket_size`, rollout节点的临时额外显存开销为`1 * bucket_size`。
+* `async_training.checkpoint_engine.bypass_cpu`
+
+  参数同步时跳过作为中转的CPU内存。
+

 ### 模式支持

 1. on policy pipeline:
diff --git a/recipe/fully_async_policy/checkpoint_engine.py b/recipe/fully_async_policy/checkpoint_engine.py
index 3128bed8d19..9f8531095bb 100644
--- a/recipe/fully_async_policy/checkpoint_engine.py
+++ b/recipe/fully_async_policy/checkpoint_engine.py
@@ -31,10 +31,7 @@
 from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
 from ray.util.collective import collective

-from verl.utils.device import (
-    get_device_name,
-    get_torch_device,
-)
+from verl.utils.device import get_device_name, get_torch_device

 if TYPE_CHECKING:
     from typing import TypeVar
@@ -263,7 +260,12 @@ class CheckpointEngine:
     """

     def __init__(
-        self, current_rank: int, actor_ranks: list[int], rollout_ranks: list[int], device_buffer_size_M: int
+        self,
+        current_rank: int,
+        actor_ranks: list[int],
+        rollout_ranks: list[int],
+        device_buffer_size_M: int,
+        use_cpu_buffer: bool = True,
     ) -> None:
         self.current_rank = current_rank
         self.actor_ranks = actor_ranks
@@ -273,6 +275,7 @@ def __init__(
         self.global_buckets: dict[int, list[MemoryBufferMeta]] = None
         # min device_buffer_size for h2d and broadcast
         self.device_buffer_size_M = device_buffer_size_M
+        self.use_cpu_buffer = use_cpu_buffer

         # ipc config for broadcast in pipeline mode
         self._zmq_ctx = zmq.Context()
@@ -342,6 +345,11 @@ def register_pin_memory(idx: int, size: int) -> tuple[int, torch.Tensor]:
             buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
             return idx, buffer

+        def register_gpu_memory(idx: int, size: int) -> tuple[int, torch.Tensor]:
+            """Allocate gpu memory for a bucket."""
+            buffer = torch.empty(size, dtype=torch.uint8, device=get_torch_device().current_device())
+            return idx, buffer
+
         def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
             """Copy a tensor into a pinned memory buffer."""
             buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)
@@ -355,9 +363,16 @@ def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):

         # Use thread pool to accelerate organize parameters into buckets
         with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
-            futures = [
-                executor.submit(register_pin_memory, idx, bucket.size) for idx, bucket in enumerate(local_buckets)
-            ]
+            if self.use_cpu_buffer:
+                futures = [
+                    executor.submit(register_pin_memory, idx, bucket.size)
+                    for idx, bucket in enumerate(local_buckets)
+                ]
+            else:
+                futures = [
+                    executor.submit(register_gpu_memory, idx, bucket.size)
+                    for idx, bucket in enumerate(local_buckets)
+                ]
             new_futures = []
             for future in concurrent.futures.as_completed(futures):
                 idx, buffer = future.result()
@@ -424,11 +439,14 @@ def update_checkpoint(self, inference_model, group_name: str, overlap_broadcast_
                 for broadcasting and loading weights.
""" try: - h2d_buffer: torch.Tensor | None = ( - None - if self.current_rank in self.rollout_ranks - else torch.empty(self.bucket_size, dtype=torch.uint8, device=get_torch_device().current_device()) - ) + if self.use_cpu_buffer: + h2d_buffer: torch.Tensor | None = ( + None + if self.current_rank in self.rollout_ranks + else torch.empty(self.bucket_size, dtype=torch.uint8, device=get_torch_device().current_device()) + ) + else: + h2d_buffer = None # for pipeline mode, we need to allocate 2x buffer size broadcast_load_buffer = torch.empty( self.bucket_size * (2 if overlap_broadcast_and_consume else 1), @@ -482,7 +500,7 @@ def update_weights_from_ipc_(socket_path): for i in range(max_h2d_iter): # Step 1: Each actor rank copy the parameter tensor into device memory - if i < len(self.memory_buffers): + if self.use_cpu_buffer and i < len(self.memory_buffers): h2d_buffer[: local_buckets[i].size].data.copy_(self.memory_buffers[i].buffer) # Step 2: Broadcast the device data in turn @@ -495,7 +513,10 @@ def update_weights_from_ipc_(socket_path): start = gidx % 2 * self.bucket_size if overlap_broadcast_and_consume else 0 buffer_b: torch.Tensor = broadcast_load_buffer[start : start + bucket.size] if broadcast_rank == self.current_rank: - buffer_b.data.copy_(h2d_buffer[: bucket.size]) + if self.use_cpu_buffer: + buffer_b.data.copy_(h2d_buffer[: bucket.size]) + else: + buffer_b.data.copy_(self.memory_buffers[i].buffer) # Broadcast the buffer to all ranks collective.broadcast(buffer_b, src_rank=broadcast_rank, group_name=group_name) diff --git a/recipe/fully_async_policy/config/fully_async_ppo_megatron_trainer.yaml b/recipe/fully_async_policy/config/fully_async_ppo_megatron_trainer.yaml index c0252390e54..5803c1aaad9 100644 --- a/recipe/fully_async_policy/config/fully_async_ppo_megatron_trainer.yaml +++ b/recipe/fully_async_policy/config/fully_async_ppo_megatron_trainer.yaml @@ -38,6 +38,9 @@ async_training: # Enable the pipeline for broadcasting and updating parameters, but it requires more device memory overlap_broadcast_and_consume: False + # Whether bypass cpu memory when synchronizing parameters + bypass_cpu: False + # Rollout config rollout: diff --git a/recipe/fully_async_policy/config/fully_async_ppo_trainer.yaml b/recipe/fully_async_policy/config/fully_async_ppo_trainer.yaml index 21562de54e7..0a5d2480d12 100644 --- a/recipe/fully_async_policy/config/fully_async_ppo_trainer.yaml +++ b/recipe/fully_async_policy/config/fully_async_ppo_trainer.yaml @@ -38,6 +38,9 @@ async_training: # Enable the pipeline for broadcasting and updating parameters, but it requires more device memory overlap_broadcast_and_consume: False + # Whether bypass cpu memory when synchronizing parameters + bypass_cpu: False + # Rollout config rollout: diff --git a/recipe/fully_async_policy/fsdp_workers.py b/recipe/fully_async_policy/fsdp_workers.py index 4d172fb657e..fe0674ec4d6 100644 --- a/recipe/fully_async_policy/fsdp_workers.py +++ b/recipe/fully_async_policy/fsdp_workers.py @@ -75,7 +75,11 @@ def init_checkpoint_engine(self, rank_offset: int, actor_num: int, rollout_num: assert rank_offset == 0 or rank_offset == actor_num self.checkpoint_engine = CheckpointEngine( - current_rank, actor_ranks, rollout_ranks, self.config.checkpoint_engine.device_buffer_size_M + current_rank, + actor_ranks, + rollout_ranks, + self.config.checkpoint_engine.device_buffer_size_M, + use_cpu_buffer=not self.config.checkpoint_engine.get("bypass_cpu", False), ) def _get_actor_params(self): @@ -120,6 +124,7 @@ def cache_actor_weights_to_cpu(self): 
         params = self._get_actor_params()
         local_rank = torch.distributed.get_rank()
         world_size = torch.distributed.get_world_size()
+        bypass_cpu = self.config.checkpoint_engine.get("bypass_cpu", False)

         for tensor_idx, (key, _, _) in enumerate(self._weights_info):
             origin_data = params[key]
@@ -127,8 +132,12 @@ def cache_actor_weights_to_cpu(self):
                 origin_data = origin_data.full_tensor()

             if tensor_idx % world_size == local_rank:
-                self.cpu_named_params[key] = origin_data.to("cpu", non_blocking=True)
-        get_torch_device().synchronize()
+                if bypass_cpu:
+                    self.cpu_named_params[key] = origin_data
+                else:
+                    self.cpu_named_params[key] = origin_data.to("cpu", non_blocking=True)
+        if not bypass_cpu:
+            get_torch_device().synchronize()

     @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
     def sync_rollout_weights_by_checkpoint(self, sync_group_name="actor_rollout"):
diff --git a/recipe/fully_async_policy/megatron_worker.py b/recipe/fully_async_policy/megatron_worker.py
index f9f2c932a4f..c3ebb37dccc 100644
--- a/recipe/fully_async_policy/megatron_worker.py
+++ b/recipe/fully_async_policy/megatron_worker.py
@@ -129,11 +129,17 @@ def cache_actor_weights_to_cpu(self):
         params_generator = self._get_actor_params_generator()
         local_rank = torch.distributed.get_rank()
         world_size = torch.distributed.get_world_size()
+        bypass_cpu = self.config.checkpoint_engine.get("bypass_cpu", False)
+        print(f"cache_actor_weights_to_cpu, local_rank:{local_rank}, world_size:{world_size}")

         for tensor_idx, (key, tensor) in enumerate(params_generator):
             if tensor_idx % world_size == local_rank:
-                self.cpu_named_params[key] = tensor.to("cpu", non_blocking=True)
-        get_torch_device().synchronize()
+                if bypass_cpu:
+                    self.cpu_named_params[key] = tensor
+                else:
+                    self.cpu_named_params[key] = tensor.to("cpu", non_blocking=True)
+        if not bypass_cpu:
+            get_torch_device().synchronize()

     @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
     def sync_rollout_weights_by_checkpoint(self, sync_group_name="actor_rollout"):
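For context, here is a minimal standalone sketch of the staging choice that `bypass_cpu` toggles, mirroring the `register_pin_memory` / `register_gpu_memory` paths in the diff above. It is an illustration only, not the verl implementation; the function name `allocate_bucket_buffers`, the bucket sizes, and the `device` argument are assumptions for this example.

```python
# Illustrative sketch only: the buffer-staging choice toggled by `bypass_cpu`
# (use_cpu_buffer = not bypass_cpu in CheckpointEngine above).
# `allocate_bucket_buffers`, the bucket sizes, and `device` are assumptions,
# not part of the verl API.
import torch


def allocate_bucket_buffers(
    bucket_sizes_bytes: list[int], bypass_cpu: bool, device: str = "cuda"
) -> list[torch.Tensor]:
    """Allocate one flat uint8 staging buffer per parameter bucket."""
    buffers = []
    for size in bucket_sizes_bytes:
        if bypass_cpu:
            # Stage directly in device memory: the broadcast can read the bucket
            # without a host-to-device copy, but the buckets stay resident on GPU.
            buffers.append(torch.empty(size, dtype=torch.uint8, device=device))
        else:
            # Stage in pinned host memory: cheaper on device memory, but each
            # bucket needs an extra host-to-device copy before broadcasting.
            # (pin_memory=True requires an available accelerator backend.)
            buffers.append(torch.empty(size, dtype=torch.uint8, pin_memory=True))
    return buffers
```

With the configs above, the behavior is switched by setting `bypass_cpu: True` under `async_training.checkpoint_engine` in the trainer yaml; the trade-off is the extra device memory noted in the comments.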