Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 49 additions & 16 deletions slime/ray/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,18 +449,57 @@ def _split_train_data_by_dp(self, data, dp_size):
return rollout_data_refs


def _compute_prefill_engine_indices(num_engines: int, num_engines_per_node: int, prefill_count: int) -> set[int]:
"""
Compute which engine indices should be prefill (P) nodes.

Strategy: Spread P nodes across physical nodes as evenly as possible.
- First, try to place one P node per physical node (round-robin)
- If more P nodes are needed, continue with second round, etc.

Args:
num_engines: Total number of engines
num_engines_per_node: Number of engines per physical node
prefill_count: Number of prefill engines needed

Returns:
Set of engine indices that should be prefill nodes
"""
if prefill_count <= 0:
return set()

num_nodes = (num_engines + num_engines_per_node - 1) // num_engines_per_node
prefill_indices = set()

# Round-robin across nodes: slot 0 of each node, then slot 1, etc.
for slot in range(num_engines_per_node):
for node in range(num_nodes):
engine_idx = node * num_engines_per_node + slot
if engine_idx < num_engines and len(prefill_indices) < prefill_count:
prefill_indices.add(engine_idx)
if len(prefill_indices) >= prefill_count:
break

return prefill_indices


def init_rollout_engines(args, pg, all_rollout_engines):
if args.debug_train_only:
return 0

num_gpu_per_engine = min(args.rollout_num_gpus_per_engine, args.num_gpus_per_node)
num_engines = args.rollout_num_gpus // num_gpu_per_engine
assert len(all_rollout_engines) == num_engines

prefill_engine_indices = set()
if args.prefill_num_servers is not None:
prefill_num_servers = args.prefill_num_servers * args.rollout_num_gpus_per_engine // num_gpu_per_engine
prefill_count = args.prefill_num_servers * args.rollout_num_gpus_per_engine // num_gpu_per_engine
assert (
num_engines > prefill_num_servers
), f"num_engines {num_engines} should be larger than prefill_num_servers {prefill_num_servers}"
num_engines > prefill_count
), f"num_engines {num_engines} should be larger than prefill_count {prefill_count}"
num_engines_per_node = max(1, args.num_gpus_per_node // num_gpu_per_engine)
prefill_engine_indices = _compute_prefill_engine_indices(num_engines, num_engines_per_node, prefill_count)
logger.info(f"PD separation enabled: prefill engines at indices {sorted(prefill_engine_indices)}")

pg, reordered_bundle_indices, reordered_gpu_ids = pg

Expand Down Expand Up @@ -496,10 +535,7 @@ def init_rollout_engines(args, pg, all_rollout_engines):

worker_type = "regular"
if args.prefill_num_servers is not None:
if i < prefill_num_servers:
worker_type = "prefill"
else:
worker_type = "decode"
worker_type = "prefill" if i in prefill_engine_indices else "decode"

rollout_engine = RolloutRayActor.options(
num_cpus=num_cpus,
Expand All @@ -522,7 +558,10 @@ def init_rollout_engines(args, pg, all_rollout_engines):
addr_and_ports = _allocate_rollout_engine_addr_and_ports_external(args=args, rollout_engines=rollout_engines)
else:
addr_and_ports = _allocate_rollout_engine_addr_and_ports_normal(
args=args, num_engines=num_engines, rollout_engines=rollout_engines
args=args,
num_engines=num_engines,
rollout_engines=rollout_engines,
prefill_engine_indices=prefill_engine_indices,
)

# TODO: don't ray.get here to overlap train actor init with rollout engine init.
Expand All @@ -549,7 +588,7 @@ def _allocate_rollout_engine_addr_and_ports_external(args, rollout_engines):
return addr_and_ports


def _allocate_rollout_engine_addr_and_ports_normal(*, args, num_engines, rollout_engines):
def _allocate_rollout_engine_addr_and_ports_normal(*, args, num_engines, rollout_engines, prefill_engine_indices):
# get ports
# there are 4 ports we need to allocate
# 1. server port
Expand All @@ -561,12 +600,6 @@ def _allocate_rollout_engine_addr_and_ports_normal(*, args, num_engines, rollout
)
addr_and_ports = [{} for _ in range(num_engines)]

# Calculate prefill limit to identify prefill engines
prefill_limit = 0
if args.prefill_num_servers is not None:
num_gpu_per_engine = min(args.rollout_num_gpus_per_engine, args.num_gpus_per_node)
prefill_limit = args.prefill_num_servers * args.rollout_num_gpus_per_engine // num_gpu_per_engine

visited_nodes = set()
for rank, engine in rollout_engines:
if rank // num_engines_per_node in visited_nodes:
Expand Down Expand Up @@ -606,7 +639,7 @@ def addr():
addr_and_ports[current_rank]["port"] = get_port()
addr_and_ports[current_rank]["nccl_port"] = get_port()

if args.prefill_num_servers is not None and current_rank < prefill_limit:
if args.prefill_num_servers is not None and current_rank in prefill_engine_indices:
addr_and_ports[current_rank]["disaggregation_bootstrap_port"] = get_port()

if args.rollout_num_gpus_per_engine > args.num_gpus_per_node:
Expand Down
Loading