Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions tests/checkpoint_engine/test_correctness_on_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
RayResourcePool,
split_resource_pool,
)
from verl.utils.device import get_device_name
from verl.workers.config import CheckpointEngineConfig, HFModelConfig, RolloutConfig


Expand Down Expand Up @@ -127,6 +128,53 @@ async def test_nixl_checkpoint_engine(
ray.shutdown()


@pytest.mark.skip(reason="temporary skip since our ci environment is not ready")
@pytest.mark.asyncio
@pytest.mark.parametrize("rebuild_group", [False])
@pytest.mark.parametrize("num_trainer, num_rollout", [(2, 6)])
async def test_kimi_checkpoint_engine(
    rebuild_group,
    num_trainer,
    num_rollout,
    num_nodes=1,
    num_gpus_per_node=8,
    check_allclose=True,
    model_path="~/models/Qwen/Qwen3-8B-Base",
):
    """End-to-end correctness test for the ``kimi_ckpt_engine`` weight-sync backend.

    Starts a trainer worker group and a vLLM rollout worker group on a fresh Ray
    cluster, then runs three weight-update rounds through
    ``CheckpointEngineManager``, checking the rollout weights after each round.

    Args:
        rebuild_group: forwarded to the kimi_ckpt_engine backend; controls
            whether its communication group is rebuilt per update.
        num_trainer: number of trainer workers split off the resource pool.
        num_rollout: number of rollout workers split off the resource pool.
        num_nodes: number of Ray nodes to request.
        num_gpus_per_node: GPUs requested per node.
        check_allclose: if True, rollout workers verify received weights
            against the trainer's.
        model_path: HF checkpoint path; ``~`` is expanded.
    """
    model_path = os.path.expanduser(model_path)
    ray.init(
        runtime_env={
            "env_vars": {
                "NCCL_IB_HCA": "mlx5",
                "VERL_LOGGING_LEVEL": "DEBUG",
            }
        }
    )
    try:
        # initialize config
        checkpoint_engine_config = CheckpointEngineConfig(
            backend="kimi_ckpt_engine", engine_kwargs={"kimi_ckpt_engine": {"rebuild_group": rebuild_group}}
        )
        model_config = HFModelConfig(path=model_path, use_remove_padding=True)
        rollout_config = RolloutConfig(name="vllm", checkpoint_engine=checkpoint_engine_config)

        # create trainer and rollout worker group
        resource_pool = RayResourcePool(process_on_nodes=[num_gpus_per_node] * num_nodes, max_colocate_count=3)
        resource_pool.get_placement_groups(device_name=get_device_name())
        trainer_pool, rollout_pool = split_resource_pool(resource_pool, [num_trainer, num_rollout])
        trainer = create_trainer_worker_group(trainer_pool, model_config, checkpoint_engine_config)
        trainer.reset()
        rollout, replicas = await create_rollout_worker_group(rollout_pool, model_config, rollout_config, check_allclose)

        # create checkpoint engine manager and run several update/verify rounds
        checkpoint_manager = CheckpointEngineManager(backend="kimi_ckpt_engine", trainer=trainer, replicas=replicas)
        for _ in range(3):
            await checkpoint_manager.update_weights()
            rollout.check_weights()
    finally:
        # Always tear down Ray, even on failure, so a raised assertion does not
        # leave an initialized cluster behind and break subsequent tests.
        ray.shutdown()


if __name__ == "__main__":
test_nccl_checkpoint_engine(
rebuild_group=False,
Expand Down
47 changes: 47 additions & 0 deletions tests/checkpoint_engine/test_correctness_on_npu.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,53 @@ async def test_hccl_checkpoint_engine(
ray.shutdown()


@pytest.mark.skip(reason="temporary skip since our ci environment is not ready")
@pytest.mark.asyncio
@pytest.mark.parametrize("rebuild_group", [False])
@pytest.mark.parametrize("num_trainer, num_rollout", [(4, 28)])
async def test_kimi_checkpoint_engine(
    rebuild_group,
    num_trainer,
    num_rollout,
    num_nodes=2,
    num_gpus_per_node=16,
    check_allclose=True,
    model_path="~/models/Qwen/Qwen3-32B",
):
    """End-to-end correctness test for the ``kimi_ckpt_engine`` backend on Ascend NPU.

    Starts a trainer worker group and a vLLM rollout worker group on a fresh Ray
    cluster, then runs three weight-update rounds through
    ``CheckpointEngineManager``, checking the rollout weights after each round.

    Args:
        rebuild_group: forwarded to the kimi_ckpt_engine backend; controls
            whether its communication group is rebuilt per update.
        num_trainer: number of trainer workers split off the resource pool.
        num_rollout: number of rollout workers split off the resource pool.
        num_nodes: number of Ray nodes to request.
        num_gpus_per_node: devices requested per node.
        check_allclose: if True, rollout workers verify received weights
            against the trainer's.
        model_path: HF checkpoint path; ``~`` is expanded.
    """
    model_path = os.path.expanduser(model_path)
    ray.init(
        runtime_env={
            "env_vars": {
                "HCCL_CONNECT_TIMEOUT": "1500",
                "VERL_LOGGING_LEVEL": "DEBUG",
            }
        }
    )
    try:
        # initialize config
        checkpoint_engine_config = CheckpointEngineConfig(
            backend="kimi_ckpt_engine", engine_kwargs={"kimi_ckpt_engine": {"rebuild_group": rebuild_group}}
        )
        model_config = HFModelConfig(path=model_path, use_remove_padding=True)
        rollout_config = RolloutConfig(name="vllm", checkpoint_engine=checkpoint_engine_config)

        # create trainer and rollout worker group
        resource_pool = RayResourcePool(process_on_nodes=[num_gpus_per_node] * num_nodes, max_colocate_count=3)
        resource_pool.get_placement_groups(device_name=get_device_name())
        trainer_pool, rollout_pool = split_resource_pool(resource_pool, [num_trainer, num_rollout])
        trainer = create_trainer_worker_group(trainer_pool, model_config, checkpoint_engine_config)
        trainer.reset()
        rollout, replicas = await create_rollout_worker_group(rollout_pool, model_config, rollout_config, check_allclose)

        # create checkpoint engine manager and run several update/verify rounds
        checkpoint_manager = CheckpointEngineManager(backend="kimi_ckpt_engine", trainer=trainer, replicas=replicas)
        for _ in range(3):
            await checkpoint_manager.update_weights()
            rollout.check_weights()
    finally:
        # Always tear down Ray, even on failure, so a raised assertion does not
        # leave an initialized cluster behind and break subsequent tests.
        ray.shutdown()


if __name__ == "__main__":
test_hccl_checkpoint_engine(
rebuild_group=False,
Expand Down
22 changes: 17 additions & 5 deletions verl/checkpoint_engine/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,27 @@ Checkpoint Engine is a unified abstraction layer to synchronize weights between va
|nccl|NCCL|all_gather+broadcast|NVIDIA GPU & NCCL|Very High|Low: rebuild nccl group|Off-policy training<br>- Trainer/rollout disaggregated<br>- Fixed clusters
|hccl|HCCL|all_gather+broadcast|Ascend NPU & HCCL| High|Low: rebuild hccl group|Off-policy training<br>- Trainer/rollout disaggregated<br>- Fixed clusters
|nixl|NIXL|all_gather+ring p2p|Various transport backends (D2D, H2H, H2D, etc)<br>- UCX<br>- UCCL<br>- Mooncacke|Medium/High|High: dynamic adjust ring topology|Off-policy training<br>- Trainer/rollout disaggregated<br>- Elastic rollout<br>- Rollout fault tolerance<br>- Heterogeneous hardware rollout
|kimi_ckpt_engine|MOONCAKE+NCCL/HCCL|p2p+broadcast|NVIDIA/Ascend|High|Low: rebuild communication group|Off-policy training<br>- Trainer/rollout disaggregated<br>- Save checkpoint each time

##### kimi_ckpt_engine detail:

In the kimi_ckpt_engine workflow, the trainer first offloads the weights to the CPU, and the rollout side creates a sub-communication group that includes all of its devices. The weights are then transmitted via the Mooncake transfer engine over P2P to one designated rollout worker, which broadcasts them to all the other rollout workers.

<img src="https://github.com/kip-cxj/verl/blob/cxj/doc_imgs/docs/_static/kimi_ckpt_engine.png?raw=true" alt="kimi-ckpt-engine" width="50%">

This mode requires the P2P feature of checkpoint_engine. Please ensure you have installed it via pip install 'checkpoint-engine[p2p]' and that your version is 0.4.0 or higher.

In addition, installing checkpoint-engine[p2p] also installs the transfer engine. However, this library has no prebuilt packages for Ascend devices and must be compiled from source. For detailed compilation instructions, see: [transfer-engine: ascend direct](https://github.com/kvcache-ai/Mooncake/blob/main/docs/source/design/transfer-engine/ascend_direct_transport.md)

### Benchmark
1. benchmark setup
- model: Qwen/Qwen3-30B-A3B-Base
- trainer: fsdp world_size=2
- trainer: fsdp world_size=2 on GPU (on Ascend 910C, which has 64GB of HBM, we use world_size=4)
- rollout: num_rollout=30 (only receive weight without cuda ipc to vllm/sglang)
```bash
python3 tests/checkpoint_engine/test_nixl_checkpoint_engine.py
python3 tests/checkpoint_engine/test_nccl_checkpoint_engine.py
python3 tests/checkpoint_engine/test_hccl_checkpoint_engine.py
pytest tests/checkpoint_engine/test_correctness_on_gpu.py
pytest tests/checkpoint_engine/test_correctness_on_npu.py
pytest tests/checkpoint_engine/test_special_server_adapter.py
```

2. benchmark result
Expand All @@ -36,4 +47,5 @@ python3 tests/checkpoint_engine/test_hccl_checkpoint_engine.py
|----|----|----|----|
|4*8 H100, ConnectX-7 400 Gbps (InfiniBand)| NCCL | ~7 | 8.25|
|4*8 H100, ConnectX-7 400 Gbps (InfiniBand)| NIXL | ~7 | 8.25|
|2*16 Ascend 910C, inner supernode| HCCL | ~11 | 5.3|
|2*16 Ascend 910C, inner supernode| HCCL | ~11 | 5.3|
|2*16 Ascend 910C, inner supernode| kimi_ckpt_engine | offload: 7 update: 3.5 | 16.5|
8 changes: 7 additions & 1 deletion verl/checkpoint_engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,16 @@
except ImportError:
HCCLCheckpointEngine = None


try:
from .nixl_checkpoint_engine import NIXLCheckpointEngine

__all__ += ["NIXLCheckpointEngine"]
except ImportError:
NIXLCheckpointEngine = None

try:
from .kimi_checkpoint_engine import KIMICheckpointEngine

__all__ += ["KIMICheckpointEngine"]
except ImportError:
KIMICheckpointEngine = None
Loading
Loading