@@ -150,13 +150,14 @@ def _remove_delayed_requests(self, request_id: str):
150150
151151class KVCacheSendingThread (threading .Thread ):
152152
153- def __init__ (self , tp_rank : int , decode_tp_size : int , local_engine_id : str ,
154- side_channel_host : str , side_channel_port : int ,
155- metadata : MooncakeAgentMetadata , ready_event : threading .Event ,
156- kv_caches : dict [str , Any ], pcp_rank : int ):
153+ def __init__ (self , tp_rank : int , prefill_tp_size : int ,
154+ local_engine_id : str , side_channel_host : str ,
155+ side_channel_port : int , metadata : MooncakeAgentMetadata ,
156+ ready_event : threading .Event , kv_caches : dict [str , Any ],
157+ pcp_rank : int ):
157158 super ().__init__ (daemon = True , name = "KVCacheSendingThread" )
158159 self .tp_rank = tp_rank
159- self .decode_tp_size = decode_tp_size
160+ self .prefill_tp_size = prefill_tp_size
160161 self .local_engine_id = local_engine_id
161162 self .side_channel_host = side_channel_host
162163 self .side_channel_port = side_channel_port
@@ -195,7 +196,7 @@ def run(self):
195196 # NOTE(rob): we need each rank to have a unique port. This hack keeps
196197 # us moving. We will switch when moving to etcd or when we have a
197198 # single ZMQ socket in the scheduler.
198- handshake_port = self .side_channel_port + self .pcp_rank * self .decode_tp_size \
199+ handshake_port = self .side_channel_port + self .pcp_rank * self .prefill_tp_size \
199200 + self .tp_rank
200201 path = make_zmq_path ("tcp" , self .side_channel_host , handshake_port )
201202 logger .info ("Starting listening on path: %s" , path )
@@ -295,7 +296,7 @@ def __init__(self, tp_rank: int, tp_size: int, engine: TransferEngine,
295296 def add_request (self , request_id : str , local_block_ids : list [int ],
296297 remote_block_ids : list [int ], remote_engine_id : str ,
297298 remote_host : str , remote_handshake_port : int , offset : int ,
298- num_need_pulls : int ):
299+ num_need_pulls : int , all_task_done : bool ):
299300 """Add a new request to the queue for processing."""
300301 logger .debug (f"Adding request { request_id } to the queue." )
301302 self .request_queue .put ({
@@ -306,7 +307,8 @@ def add_request(self, request_id: str, local_block_ids: list[int],
306307 "remote_host" : remote_host ,
307308 "remote_handshake_port" : remote_handshake_port ,
308309 "offset" : offset ,
309- "num_need_pulls" : num_need_pulls
310+ "num_need_pulls" : num_need_pulls ,
311+ "all_task_done" : all_task_done
310312 })
311313
312314 def get_and_clear_finished_requests (self ) -> set [str ]:
@@ -335,8 +337,7 @@ def _handle_request(self, req_meta: dict[str, Any]):
335337 request_id = req_meta ["request_id" ]
336338 remote_host = req_meta ["remote_host" ]
337339 remote_handshake_port = req_meta ["remote_handshake_port" ]
338- offset = req_meta ["offset" ]
339- num_need_pulls = req_meta ["num_need_pulls" ]
340+ all_task_done = req_meta ["all_task_done" ]
340341
341342 try :
342343 logger .debug (
@@ -353,7 +354,7 @@ def _handle_request(self, req_meta: dict[str, Any]):
353354 # remote host.
354355 self ._send_done_recv_signal (request_id , remote_host ,
355356 remote_handshake_port )
356- if offset == num_need_pulls - 1 :
357+ if all_task_done :
357358 self .task_tracker .update_done_task_count (request_id )
358359 self .request_queue .task_done ()
359360
@@ -1091,7 +1092,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
10911092 ready_event = threading .Event ()
10921093 if self .kv_role == 'kv_producer' :
10931094 self .kv_send_thread = KVCacheSendingThread (
1094- self .tp_rank , self ._decode_tp_size , self .engine_id ,
1095+ self .tp_rank , self ._prefill_tp_size , self .engine_id ,
10951096 self .side_channel_host , self .side_channel_port , metadata ,
10961097 ready_event , self .kv_caches , self .pcp_rank )
10971098 self .kv_send_thread .start ()
@@ -1239,7 +1240,10 @@ def start_load_kv(self, metadata: MooncakeConnectorMetadata):
12391240 remote_handshake_port = remote_handshake_port_list [
12401241 pcp_dcp_rank ][i ],
12411242 offset = i ,
1242- num_need_pulls = self .num_need_pulls )
1243+ num_need_pulls = self .num_need_pulls ,
1244+ all_task_done = (pcp_dcp_rank
1245+ == len (remote_handshake_port_list ) - 1
1246+ and i == self .num_need_pulls - 1 ))
12431247
12441248 if self .kv_send_thread is not None :
12451249 for req_id , delay_start_time in metadata .requests_to_send .items ():
0 commit comments