
Commit 57c701f

fix: expose prefill worker id in disagg (#4563)
Signed-off-by: PeaBrane <[email protected]>
1 parent 550bf98 commit 57c701f
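
In practice, this change means a client can opt in to seeing which workers served a request by passing "worker_id" in nvext.extra_fields; the first streamed chunk then carries an nvext.worker_id object with prefill_worker_id and decode_worker_id (identical in aggregated mode). Below is a minimal client sketch mirroring the new test in tests/router/common.py; the frontend URL, port, and model name are placeholders, not taken from this commit.

import asyncio
import json

import aiohttp


async def fetch_worker_ids(chat_url: str) -> dict:
    """Stream one chat completion and pull nvext.worker_id out of the SSE chunks."""
    payload = {
        "model": "example-model",  # placeholder model name
        "messages": [{"role": "user", "content": "hello"}],
        "nvext": {"extra_fields": ["worker_id"]},  # opt in to worker id reporting
        "stream": True,
    }
    worker_ids = {}
    async with aiohttp.ClientSession() as session:
        async with session.post(chat_url, json=payload) as response:
            async for line in response.content:
                line_str = line.decode("utf-8", errors="replace").strip()
                if not line_str.startswith("data:"):
                    continue
                data_str = line_str[5:].strip()
                if data_str == "[DONE]":
                    break
                try:
                    chunk = json.loads(data_str)
                except json.JSONDecodeError:
                    continue
                # Only the first chunk carries worker_id; keep whatever we find.
                wid = (chunk.get("nvext") or {}).get("worker_id")
                if wid:
                    worker_ids = wid
    return worker_ids


# Assumes a frontend listening on localhost:8000 (placeholder port):
# asyncio.run(fetch_worker_ids("http://localhost:8000/v1/chat/completions"))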

3 files changed: +458 -100 lines changed

lib/llm/src/kv_router.rs

Lines changed: 23 additions & 25 deletions
@@ -623,14 +623,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
         backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
         backend_input.dp_rank = Some(dp_rank);
 
-        // Check if worker_id is requested in extra_fields
-        let should_populate_worker_id = backend_input
-            .extra_fields
-            .as_deref()
-            .unwrap_or(&[])
-            .iter()
-            .any(|s| s == "worker_id");
-
         // Get prefill worker ID if available (stored by PrefillRouter)
         // In aggregated mode, prefill_worker_id is None, so we use decode_worker_id for both
         let decode_worker_id = instance_id;
@@ -672,24 +664,30 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
                 prefill_marked = true;
             }
 
-            // Inject worker_id in first item's disaggregated_params if requested
-            if first_item && should_populate_worker_id {
-                if let Some(ref mut data) = item.data {
-                    // Add worker_id to disaggregated_params
-                    let worker_id_json = json!({
-                        "prefill_worker_id": prefill_worker_id,
-                        "decode_worker_id": decode_worker_id,
-                    });
-
-                    if let Some(ref mut params) = data.disaggregated_params {
-                        if let Some(obj) = params.as_object_mut() {
-                            obj.insert("worker_id".to_string(), worker_id_json);
-                        }
-                    } else {
-                        data.disaggregated_params = Some(json!({"worker_id": worker_id_json}));
-                    }
-                }
+            // Always inject worker_id in first item's disaggregated_params
+            // This is needed for:
+            // 1. PrefillRouter to know which prefill worker was chosen
+            // 2. Client response when extra_fields contains "worker_id"
+            if first_item {
                 first_item = false;
+
+                let Some(ref mut data) = item.data else {
+                    yield item;
+                    continue;
+                };
+
+                // prefill_worker_id comes from context (set by PrefillRouter) or falls back to instance_id
+                // decode_worker_id is always the current instance_id
+                let worker_id_json = json!({
+                    "prefill_worker_id": prefill_worker_id,
+                    "decode_worker_id": decode_worker_id,
+                });
+
+                if let Some(obj) = data.disaggregated_params.as_mut().and_then(|p| p.as_object_mut()) {
+                    obj.insert("worker_id".to_string(), worker_id_json);
+                } else {
+                    data.disaggregated_params = Some(json!({"worker_id": worker_id_json}));
+                }
             }
 
             yield item;

tests/router/common.py

Lines changed: 194 additions & 0 deletions
@@ -36,6 +36,7 @@ def __init__(
         frontend_port: int,
         namespace: str,
         store_backend: str = "etcd",
+        enforce_disagg: bool = False,
     ):
         command = [
             "python3",
@@ -53,6 +54,9 @@
             namespace,
         ]
 
+        if enforce_disagg:
+            command.append("--enforce-disagg")
+
         super().__init__(
             command=command,
             timeout=60,
@@ -1490,6 +1494,196 @@ def sort_key(event):
     logger.info("Indexers sync test completed successfully")
 
 
+def _test_router_disagg_decisions(
+    prefill_workers,
+    decode_workers,
+    block_size: int,
+    request,
+    frontend_port: int,
+    test_payload: dict,
+    store_backend: str = "etcd",
+):
+    """Validate KV cache prefix reuse in disaggregated prefill-decode setup via HTTP frontend.
+
+    Assumes prefill_workers and decode_workers are already initialized. This function manages
+    router lifecycle and sends progressive requests with overlapping prefixes.
+
+    This test:
+    1. Starts the KV router frontend with disagg support
+    2. Sends 4 progressive requests where each extends the previous tokens by block_size
+    3. Extracts prefill_worker_id and decode_worker_id from response nvext
+    4. Verifies all prefill_worker_ids are the same (due to prefix reuse routing)
+    5. Verifies prefill_worker_id is NOT in the set of decode_worker_ids (true disagg)
+
+    Args:
+        prefill_workers: Prefill workers already initialized with __enter__()
+        decode_workers: Decode workers already initialized with __enter__()
+        block_size: Block size for KV cache
+        request: Pytest request fixture for managing resources
+        frontend_port: Port for the frontend HTTP server
+        test_payload: Base test payload to send to /v1/chat/completions
+        store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
+
+    Raises:
+        AssertionError: If prefill_worker_ids differ across requests (prefix reuse failure)
+        AssertionError: If prefill_worker_id is in decode_worker_ids (not true disagg)
+    """
+    try:
+        # Start KV router frontend - uses decode_workers namespace for discovery
+        # The frontend will auto-discover both prefill and decode workers
+        logger.info(
+            f"Starting KV router frontend on port {frontend_port} for disagg test"
+        )
+        kv_router = KVRouterProcess(
+            request,
+            block_size,
+            frontend_port,
+            decode_workers.namespace,
+            store_backend,
+            enforce_disagg=True,
+        )
+        kv_router.__enter__()
+
+        frontend_url = f"http://localhost:{frontend_port}"
+        chat_url = f"{frontend_url}/v1/chat/completions"
+
+        # Wait for workers to register with frontend
+        logger.info(
+            "Waiting for prefill and decode workers to register with frontend..."
+        )
+        asyncio.run(
+            wait_for_frontend_ready(
+                frontend_url=frontend_url,
+                expected_num_workers=decode_workers.num_workers,
+                timeout=120,
+            )
+        )
+
+        async def send_progressive_requests():
+            """Send 4 progressive requests with overlapping prefixes and collect worker IDs."""
+            prefill_worker_ids = []
+            decode_worker_ids = []
+
+            # Generate base tokens for progressive prefix extension
+            base_content = test_payload["messages"][0]["content"]
+
+            async with aiohttp.ClientSession() as session:
+                for i in range(4):
+                    # Build progressive content by repeating base content
+                    # Each iteration adds more content to extend the prefix
+                    progressive_content = " ".join([base_content] * (i + 1))
+
+                    # Create payload with worker_id in extra_fields to get prefill/decode worker IDs
+                    payload = {
+                        **test_payload,
+                        "messages": [
+                            {
+                                "role": "user",
+                                "content": progressive_content,
+                            }
+                        ],
+                        "nvext": {"extra_fields": ["worker_id"]},
+                        "stream": True,
+                    }
+
+                    logger.info(
+                        f"Sending request {i + 1}/4 with progressive prefix "
+                        f"(~{len(progressive_content)} chars)"
+                    )
+
+                    async with session.post(chat_url, json=payload) as response:
+                        assert (
+                            response.status == 200
+                        ), f"Request {i + 1} failed with status {response.status}"
+
+                        # Collect all chunks and look for nvext with worker_id
+                        prefill_wid = None
+                        decode_wid = None
+
+                        async for line in response.content:
+                            if not line:
+                                continue
+
+                            line_str = line.decode("utf-8", errors="replace").strip()
+                            if not line_str.startswith("data:"):
+                                continue
+
+                            data_str = line_str[5:].strip()
+                            if data_str == "[DONE]":
+                                break
+
+                            try:
+                                data = json.loads(data_str)
+                                # Check for nvext.worker_id in the response
+                                nvext = data.get("nvext", {})
+                                worker_id_info = nvext.get("worker_id", {})
+
+                                if worker_id_info:
+                                    if "prefill_worker_id" in worker_id_info:
+                                        prefill_wid = worker_id_info[
+                                            "prefill_worker_id"
+                                        ]
+                                    if "decode_worker_id" in worker_id_info:
+                                        decode_wid = worker_id_info["decode_worker_id"]
+
+                            except json.JSONDecodeError:
+                                continue
+
+                        logger.info(
+                            f"Request {i + 1}: prefill_worker_id={prefill_wid}, "
+                            f"decode_worker_id={decode_wid}"
+                        )
+
+                        if prefill_wid is not None:
+                            prefill_worker_ids.append(prefill_wid)
+                        if decode_wid is not None:
+                            decode_worker_ids.append(decode_wid)
+
+                    # Small delay between requests
+                    await asyncio.sleep(0.5)
+
+            return prefill_worker_ids, decode_worker_ids
+
+        # Run the progressive requests
+        prefill_ids, decode_ids = asyncio.run(send_progressive_requests())
+
+        logger.info(f"Collected prefill_worker_ids: {prefill_ids}")
+        logger.info(f"Collected decode_worker_ids: {decode_ids}")
+
+        # Verify we got worker IDs from all requests
+        assert len(prefill_ids) == 4, (
+            f"Expected 4 prefill_worker_ids, got {len(prefill_ids)}. "
+            f"Make sure nvext.extra_fields=['worker_id'] is being processed."
+        )
+
+        # Verify all prefill_worker_ids are the same (prefix reuse)
+        unique_prefill_ids = set(prefill_ids)
+        assert len(unique_prefill_ids) == 1, (
+            f"Expected all prefill requests to route to the same worker due to prefix reuse, "
+            f"but found {len(unique_prefill_ids)} unique prefill_worker_ids: {unique_prefill_ids}. "
+            f"Full list: {prefill_ids}"
+        )
+
+        # Verify prefill_worker_id is NOT in decode_worker_ids (true disagg)
+        unique_decode_ids = set(decode_ids)
+        prefill_id = prefill_ids[0]
+        assert prefill_id not in unique_decode_ids, (
+            f"Prefill worker {prefill_id} should NOT be in decode workers {unique_decode_ids}. "
+            f"This suggests disaggregated mode is not working correctly - "
+            f"prefill and decode should use separate worker pools."
+        )
+
+        logger.info(
+            f"Successfully verified disaggregated routing:\n"
+            f" - All 4 requests routed to same prefill_worker_id={prefill_id} (prefix reuse)\n"
+            f" - Prefill worker is NOT in decode worker set {unique_decode_ids} (true disagg)"
+        )
+
+    finally:
+        if "kv_router" in locals():
+            kv_router.__exit__(None, None, None)
+
+
 def _test_router_decisions(
     engine_workers,
     endpoint,
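
For orientation, a concrete pytest case could hand already-started prefill and decode worker groups to the new helper roughly as sketched below; PrefillWorkers, DecodeWorkers, the port, block size, and payload are hypothetical stand-ins, since the real worker fixtures live elsewhere in the test suite.

# Hypothetical wiring only: PrefillWorkers / DecodeWorkers stand in for the real
# worker-group helpers defined elsewhere in tests/router.
def test_disagg_prefill_worker_id(request):
    prefill_workers = PrefillWorkers(request, num_workers=2, block_size=16)  # placeholder
    decode_workers = DecodeWorkers(request, num_workers=2, block_size=16)  # placeholder
    prefill_workers.__enter__()
    decode_workers.__enter__()
    try:
        _test_router_disagg_decisions(
            prefill_workers,
            decode_workers,
            block_size=16,
            request=request,
            frontend_port=8091,  # placeholder port
            test_payload={
                "model": "example-model",  # placeholder model name
                "messages": [{"role": "user", "content": "The quick brown fox"}],
                "max_tokens": 8,
            },
        )
    finally:
        decode_workers.__exit__(None, None, None)
        prefill_workers.__exit__(None, None, None)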
