Skip to content

Commit 3deeea1

Browse files
wangxiaochao6wangxiaochao
andauthored
[bugfix] bugfix for PD disaggregate (#4319)
This PR is used to fix mooncake_connector in pcp/dcp case. When executing function update_done_task_count, it is necessary to ensure that both pcp/dcp and TP ranks have finished transferring KV cache. - vLLM version: v0.11.0 - vLLM main: vllm-project/vllm@2918c1b --------- Signed-off-by: wangxiaochao <[email protected]> Co-authored-by: wangxiaochao <[email protected]>
1 parent e332e27 commit 3deeea1

File tree

2 files changed

+28
-20
lines changed

2 files changed

+28
-20
lines changed

tests/ut/kv_connector/test_mooncake_connector.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def setUp(self):
8989
kv_caches: Dict[str, Any] = {}
9090
self.common_args = {
9191
'tp_rank': 1,
92-
'decode_tp_size': 4,
92+
'prefill_tp_size': 4,
9393
'local_engine_id': 'engine_1',
9494
'side_channel_host': 'localhost',
9595
'side_channel_port': 5555,
@@ -133,7 +133,7 @@ def setUp(self):
133133
kv_caches: Dict[str, Any] = {}
134134
self.common_args = {
135135
'tp_rank': 1,
136-
'decode_tp_size': 4,
136+
'prefill_tp_size': 4,
137137
'local_engine_id': 'engine_1',
138138
'side_channel_host': 'localhost',
139139
'side_channel_port': 5555,
@@ -171,7 +171,7 @@ def test_run_handles_get_meta_and_done_recv_msgs(self):
171171
free_port = s.getsockname()[1]
172172

173173
thread = KVCacheSendingThread(tp_rank=0,
174-
decode_tp_size=1,
174+
prefill_tp_size=1,
175175
local_engine_id="engine1",
176176
side_channel_host=host,
177177
side_channel_port=free_port,
@@ -237,7 +237,8 @@ def test_add_request(self):
237237
"remote_host": "localhost",
238238
"remote_handshake_port": 6666,
239239
"offset": 0,
240-
"num_need_pulls": 2
240+
"num_need_pulls": 2,
241+
"all_task_done": False
241242
}
242243
self.thread.add_request(
243244
request_id=test_req["request_id"],
@@ -247,7 +248,8 @@ def test_add_request(self):
247248
remote_host=test_req["remote_host"],
248249
remote_handshake_port=test_req["remote_handshake_port"],
249250
offset=test_req["offset"],
250-
num_need_pulls=test_req["num_need_pulls"])
251+
num_need_pulls=test_req["num_need_pulls"],
252+
all_task_done=test_req["all_task_done"])
251253
queued = self.thread.request_queue.get_nowait()
252254
self.assertEqual(queued["request_id"], "req1")
253255
self.assertEqual(queued["remote_host"], "localhost")
@@ -341,7 +343,8 @@ def setUp(self):
341343
"remote_handshake_port": 6666,
342344
"remote_transfer_port": 7777,
343345
"offset": 0,
344-
"num_need_pulls": 2
346+
"num_need_pulls": 2,
347+
"all_task_done": False
345348
}
346349
self.thread.task_tracker = MagicMock()
347350
self.engine.batch_transfer_sync_read.return_value = 0
@@ -485,7 +488,8 @@ def test_run_loop_normal(self, mock_handle):
485488
"remote_handshake_port": 6666,
486489
"remote_transfer_port": 7777,
487490
"offset": 0,
488-
"num_need_pulls": 2
491+
"num_need_pulls": 2,
492+
"all_task_done": False
489493
}
490494

491495
self.thread.request_queue.put(test_request)

vllm_ascend/distributed/mooncake_connector.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,14 @@ def _remove_delayed_requests(self, request_id: str):
150150

151151
class KVCacheSendingThread(threading.Thread):
152152

153-
def __init__(self, tp_rank: int, decode_tp_size: int, local_engine_id: str,
154-
side_channel_host: str, side_channel_port: int,
155-
metadata: MooncakeAgentMetadata, ready_event: threading.Event,
156-
kv_caches: dict[str, Any], pcp_rank: int):
153+
def __init__(self, tp_rank: int, prefill_tp_size: int,
154+
local_engine_id: str, side_channel_host: str,
155+
side_channel_port: int, metadata: MooncakeAgentMetadata,
156+
ready_event: threading.Event, kv_caches: dict[str, Any],
157+
pcp_rank: int):
157158
super().__init__(daemon=True, name="KVCacheSendingThread")
158159
self.tp_rank = tp_rank
159-
self.decode_tp_size = decode_tp_size
160+
self.prefill_tp_size = prefill_tp_size
160161
self.local_engine_id = local_engine_id
161162
self.side_channel_host = side_channel_host
162163
self.side_channel_port = side_channel_port
@@ -195,7 +196,7 @@ def run(self):
195196
# NOTE(rob): we need each rank to have a unique port. This hack to keeps
196197
# us moving. We will switch when moving to etcd or where we have a
197198
# single ZMQ socket in the scheduler.
198-
handshake_port = self.side_channel_port + self.pcp_rank * self.decode_tp_size \
199+
handshake_port = self.side_channel_port + self.pcp_rank * self.prefill_tp_size \
199200
+ self.tp_rank
200201
path = make_zmq_path("tcp", self.side_channel_host, handshake_port)
201202
logger.info("Starting listening on path: %s", path)
@@ -295,7 +296,7 @@ def __init__(self, tp_rank: int, tp_size: int, engine: TransferEngine,
295296
def add_request(self, request_id: str, local_block_ids: list[int],
296297
remote_block_ids: list[int], remote_engine_id: str,
297298
remote_host: str, remote_handshake_port: int, offset: int,
298-
num_need_pulls: int):
299+
num_need_pulls: int, all_task_done: bool):
299300
"""Add a new request to the queue for processing."""
300301
logger.debug(f"Adding request {request_id} to the queue.")
301302
self.request_queue.put({
@@ -306,7 +307,8 @@ def add_request(self, request_id: str, local_block_ids: list[int],
306307
"remote_host": remote_host,
307308
"remote_handshake_port": remote_handshake_port,
308309
"offset": offset,
309-
"num_need_pulls": num_need_pulls
310+
"num_need_pulls": num_need_pulls,
311+
"all_task_done": all_task_done
310312
})
311313

312314
def get_and_clear_finished_requests(self) -> set[str]:
@@ -335,8 +337,7 @@ def _handle_request(self, req_meta: dict[str, Any]):
335337
request_id = req_meta["request_id"]
336338
remote_host = req_meta["remote_host"]
337339
remote_handshake_port = req_meta["remote_handshake_port"]
338-
offset = req_meta["offset"]
339-
num_need_pulls = req_meta["num_need_pulls"]
340+
all_task_done = req_meta["all_task_done"]
340341

341342
try:
342343
logger.debug(
@@ -353,7 +354,7 @@ def _handle_request(self, req_meta: dict[str, Any]):
353354
# remote host.
354355
self._send_done_recv_signal(request_id, remote_host,
355356
remote_handshake_port)
356-
if offset == num_need_pulls - 1:
357+
if all_task_done:
357358
self.task_tracker.update_done_task_count(request_id)
358359
self.request_queue.task_done()
359360

@@ -1091,7 +1092,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
10911092
ready_event = threading.Event()
10921093
if self.kv_role == 'kv_producer':
10931094
self.kv_send_thread = KVCacheSendingThread(
1094-
self.tp_rank, self._decode_tp_size, self.engine_id,
1095+
self.tp_rank, self._prefill_tp_size, self.engine_id,
10951096
self.side_channel_host, self.side_channel_port, metadata,
10961097
ready_event, self.kv_caches, self.pcp_rank)
10971098
self.kv_send_thread.start()
@@ -1239,7 +1240,10 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata):
12391240
remote_handshake_port=remote_handshake_port_list[
12401241
pcp_dcp_rank][i],
12411242
offset=i,
1242-
num_need_pulls=self.num_need_pulls)
1243+
num_need_pulls=self.num_need_pulls,
1244+
all_task_done=(pcp_dcp_rank
1245+
== len(remote_handshake_port_list) - 1
1246+
and i == self.num_need_pulls - 1))
12431247

12441248
if self.kv_send_thread is not None:
12451249
for req_id, delay_start_time in metadata.requests_to_send.items():

0 commit comments

Comments
 (0)