@@ -36,6 +36,7 @@ def __init__(
3636 frontend_port : int ,
3737 namespace : str ,
3838 store_backend : str = "etcd" ,
39+ enforce_disagg : bool = False ,
3940 ):
4041 command = [
4142 "python3" ,
@@ -53,6 +54,9 @@ def __init__(
5354 namespace ,
5455 ]
5556
57+ if enforce_disagg :
58+ command .append ("--enforce-disagg" )
59+
5660 super ().__init__ (
5761 command = command ,
5862 timeout = 60 ,
@@ -1490,6 +1494,196 @@ def sort_key(event):
14901494 logger .info ("Indexers sync test completed successfully" )
14911495
14921496
def _test_router_disagg_decisions(
    prefill_workers,
    decode_workers,
    block_size: int,
    request,
    frontend_port: int,
    test_payload: dict,
    store_backend: str = "etcd",
):
    """Validate KV cache prefix reuse in disaggregated prefill-decode setup via HTTP frontend.

    Assumes prefill_workers and decode_workers are already initialized. This function manages
    router lifecycle and sends progressive requests with overlapping prefixes.

    This test:
    1. Starts the KV router frontend with disagg support
    2. Sends 4 progressive requests where each extends the previous tokens by block_size
    3. Extracts prefill_worker_id and decode_worker_id from response nvext
    4. Verifies all prefill_worker_ids are the same (due to prefix reuse routing)
    5. Verifies prefill_worker_id is NOT in the set of decode_worker_ids (true disagg)

    Args:
        prefill_workers: Prefill workers already initialized with __enter__().
            NOTE(review): not referenced directly here — presumably kept in the
            signature so the caller's worker lifetime spans this test; confirm.
        decode_workers: Decode workers already initialized with __enter__()
        block_size: Block size for KV cache
        request: Pytest request fixture for managing resources
        frontend_port: Port for the frontend HTTP server
        test_payload: Base test payload to send to /v1/chat/completions
        store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".

    Raises:
        AssertionError: If prefill_worker_ids differ across requests (prefix reuse failure)
        AssertionError: If no decode_worker_ids were collected (disagg check would be vacuous)
        AssertionError: If prefill_worker_id is in decode_worker_ids (not true disagg)
    """
    # Sentinel instead of `"kv_router" in locals()` so the finally block is
    # unambiguous: cleanup runs iff construction succeeded.
    kv_router = None
    try:
        # Start KV router frontend - uses decode_workers namespace for discovery
        # The frontend will auto-discover both prefill and decode workers
        logger.info(
            f"Starting KV router frontend on port {frontend_port} for disagg test"
        )
        kv_router = KVRouterProcess(
            request,
            block_size,
            frontend_port,
            decode_workers.namespace,
            store_backend,
            enforce_disagg=True,
        )
        kv_router.__enter__()

        frontend_url = f"http://localhost:{frontend_port}"
        chat_url = f"{frontend_url}/v1/chat/completions"

        # Wait for workers to register with frontend
        logger.info(
            "Waiting for prefill and decode workers to register with frontend..."
        )
        asyncio.run(
            wait_for_frontend_ready(
                frontend_url=frontend_url,
                expected_num_workers=decode_workers.num_workers,
                timeout=120,
            )
        )

        async def send_progressive_requests():
            """Send 4 progressive requests with overlapping prefixes and collect worker IDs."""
            prefill_worker_ids = []
            decode_worker_ids = []

            # Generate base tokens for progressive prefix extension
            base_content = test_payload["messages"][0]["content"]

            async with aiohttp.ClientSession() as session:
                for i in range(4):
                    # Build progressive content by repeating base content
                    # Each iteration adds more content to extend the prefix
                    progressive_content = " ".join([base_content] * (i + 1))

                    # Create payload with worker_id in extra_fields to get prefill/decode worker IDs
                    payload = {
                        **test_payload,
                        "messages": [
                            {
                                "role": "user",
                                "content": progressive_content,
                            }
                        ],
                        "nvext": {"extra_fields": ["worker_id"]},
                        "stream": True,
                    }

                    logger.info(
                        f"Sending request {i + 1}/4 with progressive prefix "
                        f"(~{len(progressive_content)} chars)"
                    )

                    async with session.post(chat_url, json=payload) as response:
                        assert (
                            response.status == 200
                        ), f"Request {i + 1} failed with status {response.status}"

                        # Collect all chunks and look for nvext with worker_id
                        prefill_wid = None
                        decode_wid = None

                        async for line in response.content:
                            if not line:
                                continue

                            line_str = line.decode("utf-8", errors="replace").strip()
                            if not line_str.startswith("data:"):
                                continue

                            data_str = line_str[5:].strip()
                            if data_str == "[DONE]":
                                break

                            try:
                                data = json.loads(data_str)
                                # Check for nvext.worker_id in the response
                                nvext = data.get("nvext", {})
                                worker_id_info = nvext.get("worker_id", {})

                                if worker_id_info:
                                    if "prefill_worker_id" in worker_id_info:
                                        prefill_wid = worker_id_info[
                                            "prefill_worker_id"
                                        ]
                                    if "decode_worker_id" in worker_id_info:
                                        decode_wid = worker_id_info["decode_worker_id"]

                            except json.JSONDecodeError:
                                # Non-JSON keepalive/comment lines are expected in SSE streams
                                continue

                        logger.info(
                            f"Request {i + 1}: prefill_worker_id={prefill_wid}, "
                            f"decode_worker_id={decode_wid}"
                        )

                        if prefill_wid is not None:
                            prefill_worker_ids.append(prefill_wid)
                        if decode_wid is not None:
                            decode_worker_ids.append(decode_wid)

                    # Small delay between requests
                    await asyncio.sleep(0.5)

            return prefill_worker_ids, decode_worker_ids

        # Run the progressive requests
        prefill_ids, decode_ids = asyncio.run(send_progressive_requests())

        logger.info(f"Collected prefill_worker_ids: {prefill_ids}")
        logger.info(f"Collected decode_worker_ids: {decode_ids}")

        # Verify we got worker IDs from all requests
        assert len(prefill_ids) == 4, (
            f"Expected 4 prefill_worker_ids, got {len(prefill_ids)}. "
            f"Make sure nvext.extra_fields=['worker_id'] is being processed."
        )

        # Verify all prefill_worker_ids are the same (prefix reuse)
        unique_prefill_ids = set(prefill_ids)
        assert len(unique_prefill_ids) == 1, (
            f"Expected all prefill requests to route to the same worker due to prefix reuse, "
            f"but found {len(unique_prefill_ids)} unique prefill_worker_ids: {unique_prefill_ids}. "
            f"Full list: {prefill_ids}"
        )

        # Guard against a vacuous disagg check: `x not in set()` is trivially
        # true, so an empty decode_ids would make the next assertion meaningless.
        assert decode_ids, (
            "No decode_worker_ids were collected from responses; cannot verify "
            "disaggregated routing. Make sure nvext.extra_fields=['worker_id'] "
            "returns decode_worker_id in the stream."
        )

        # Verify prefill_worker_id is NOT in decode_worker_ids (true disagg)
        unique_decode_ids = set(decode_ids)
        prefill_id = prefill_ids[0]
        assert prefill_id not in unique_decode_ids, (
            f"Prefill worker {prefill_id} should NOT be in decode workers {unique_decode_ids}. "
            f"This suggests disaggregated mode is not working correctly - "
            f"prefill and decode should use separate worker pools."
        )

        logger.info(
            f"Successfully verified disaggregated routing:\n"
            f"  - All 4 requests routed to same prefill_worker_id={prefill_id} (prefix reuse)\n"
            f"  - Prefill worker is NOT in decode worker set {unique_decode_ids} (true disagg)"
        )

    finally:
        if kv_router is not None:
            kv_router.__exit__(None, None, None)
1686+
14931687def _test_router_decisions (
14941688 engine_workers ,
14951689 endpoint ,
0 commit comments