Skip to content

Commit 58b7cdc

Browse files
committed
[feat] Overlap context chunks in pipeline parallel mode
- Added `finished_ctx_reqs` to `BatchStatePP` to track completed context requests. - Updated `_add_inflight_ids` to return finished context requests for better state management. - Enhanced `_remove_inflight_ids` to utilize finished context requests from `BatchStatePP`. - Added debug logging for queuing decisions and inflight request management. Signed-off-by: Robin Kobus <[email protected]>
1 parent c8145ff commit 58b7cdc

File tree

1 file changed

+34
-5
lines changed

1 file changed

+34
-5
lines changed

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ class BatchState:
106106
class BatchStatePP(BatchState):
107107
microbatch_id: int = -1
108108
scheduled_ctx_reqs: list[LlmRequest] = None
109+
finished_ctx_reqs: list[LlmRequest] = None
109110

110111

111112
class PyExecutor:
@@ -833,9 +834,13 @@ def _executor_loop_pp(self):
833834
can_queue = self._can_queue(scheduled_batch)
834835

835836
if not can_queue:
837+
logger.debug(
838+
f"microbatch {microbatch_id} cannot be queued, skipping"
839+
)
836840
self.micro_batches[microbatch_id] = None
837841
else:
838-
self._add_inflight_ids(scheduled_batch)
842+
logger.debug(f"microbatch {microbatch_id} can be queued")
843+
finished_ctx_reqs = self._add_inflight_ids(scheduled_batch)
839844

840845
if self.kv_cache_transceiver:
841846
# For generation requests which have completed KV cache transfer
@@ -895,6 +900,7 @@ def _executor_loop_pp(self):
895900
iter_stats=iter_stats,
896901
microbatch_id=microbatch_id,
897902
scheduled_ctx_reqs=scheduled_batch.context_requests,
903+
finished_ctx_reqs=finished_ctx_reqs,
898904
)
899905

900906
self.micro_batches[microbatch_id] = batch_state
@@ -949,6 +955,8 @@ def _executor_loop_pp(self):
949955
finished_requests = []
950956
if previous_batch is not None:
951957
with torch.cuda.nvtx.range("_handle_previous_batch_pp"):
958+
sample_state = previous_batch.sample_state
959+
sample_state.scheduled_requests.context_requests = previous_batch.finished_ctx_reqs
952960
self._update_requests(previous_batch.sample_state)
953961

954962
if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver:
@@ -980,7 +988,8 @@ def _executor_loop_pp(self):
980988
self.resource_manager.update_resources(
981989
previous_scheduled_batch, attn_metadata,
982990
kv_cache_dtype_byte_size)
983-
self._remove_inflight_ids(previous_scheduled_batch)
991+
992+
self._remove_inflight_ids(previous_batch)
984993

985994
self.wait_on_pp_send_handles(prev_microbatch_id)
986995
self.micro_batches[prev_microbatch_id] = None
@@ -2486,12 +2495,32 @@ def _pause_requests(self, requests_to_pause):
24862495

24872496
def _add_inflight_ids(self, scheduled_requests):
24882497
"""Add reqids of current requests to self.inflight_req_ids."""
2489-
for req in scheduled_requests.all_requests():
2498+
finished_ctx_reqs = []
2499+
for req in scheduled_requests.context_requests:
2500+
if req.is_last_context_chunk:
2501+
logger.debug(
2502+
f"Context request with ID {req.request_id} added to DECODER model inflight set"
2503+
)
2504+
self.inflight_req_ids.insert(req.request_id)
2505+
finished_ctx_reqs.append(req)
2506+
for req in scheduled_requests.generation_requests:
2507+
logger.debug(
2508+
f"Generation request with ID {req.request_id} added to DECODER model inflight set"
2509+
)
24902510
self.inflight_req_ids.insert(req.request_id)
2511+
return finished_ctx_reqs
24912512

2492-
def _remove_inflight_ids(self, scheduled_requests):
2513+
def _remove_inflight_ids(self, batch_state: BatchStatePP):
24932514
"""Remove reqids of current requests from self.inflight_req_ids."""
2494-
for req in scheduled_requests.all_requests():
2515+
for req in batch_state.finished_ctx_reqs:
2516+
logger.debug(
2517+
f"Context request with ID {req.request_id} removed from DECODER model inflight set"
2518+
)
2519+
self.inflight_req_ids.erase(req.request_id)
2520+
for req in batch_state.sample_state.scheduled_requests.generation_requests:
2521+
logger.debug(
2522+
f"Generation request with ID {req.request_id} removed from DECODER model inflight set"
2523+
)
24952524
self.inflight_req_ids.erase(req.request_id)
24962525

24972526
def _handle_speculative_decoding(self, scheduled_batch, previous_tensors,

0 commit comments

Comments (0)