@@ -106,6 +106,7 @@ class BatchState:
 class BatchStatePP(BatchState):
     microbatch_id: int = -1
     scheduled_ctx_reqs: list[LlmRequest] = None
+    finished_ctx_reqs: list[LlmRequest] = None


 class PyExecutor:
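
The new field's role, in short: each in-flight microbatch now records which context requests completed their final prefill chunk while it ran, so the later cleanup pass can erase exactly those IDs. A minimal sketch of the resulting shape, using stand-in types rather than the real TensorRT-LLM classes:

```python
from dataclasses import dataclass
from typing import Optional

class LlmRequest:  # stand-in for the real request type
    pass

@dataclass
class BatchState:
    sample_state: Optional[object] = None

@dataclass
class BatchStatePP(BatchState):
    microbatch_id: int = -1
    scheduled_ctx_reqs: Optional[list[LlmRequest]] = None
    # New: context requests whose *last* prefill chunk ran in this
    # microbatch; only these had their IDs added to the inflight set.
    finished_ctx_reqs: Optional[list[LlmRequest]] = None
```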
@@ -833,9 +834,13 @@ def _executor_loop_pp(self):
                 can_queue = self._can_queue(scheduled_batch)

                 if not can_queue:
+                    logger.debug(
+                        f"microbatch {microbatch_id} cannot be queued, skipping"
+                    )
                     self.micro_batches[microbatch_id] = None
                 else:
-                    self._add_inflight_ids(scheduled_batch)
+                    logger.debug(f"microbatch {microbatch_id} can be queued")
+                    finished_ctx_reqs = self._add_inflight_ids(scheduled_batch)

                     if self.kv_cache_transceiver:
                         # For generation requests which have completed KV cache transfer
@@ -895,6 +900,7 @@ def _executor_loop_pp(self):
                     iter_stats=iter_stats,
                     microbatch_id=microbatch_id,
                     scheduled_ctx_reqs=scheduled_batch.context_requests,
+                    finished_ctx_reqs=finished_ctx_reqs,
                 )

                 self.micro_batches[microbatch_id] = batch_state
@@ -949,6 +955,8 @@ def _executor_loop_pp(self):
                 finished_requests = []
                 if previous_batch is not None:
                     with torch.cuda.nvtx.range("_handle_previous_batch_pp"):
+                        sample_state = previous_batch.sample_state
+                        sample_state.scheduled_requests.context_requests = previous_batch.finished_ctx_reqs
                         self._update_requests(previous_batch.sample_state)

                         if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver:
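
The two added lines above narrow the context list before `_update_requests`: with chunked prefill, a context request that still has chunks pending produced no new token in this microbatch, so only the requests captured in `finished_ctx_reqs` should flow through the update (and, later, through `_remove_inflight_ids`). A hedged sketch of that filtering step, with stand-in names:

```python
def handle_previous_batch(previous_batch, update_requests):
    # Sketch only: `previous_batch` follows the BatchStatePP shape sketched
    # earlier; `update_requests` stands in for PyExecutor._update_requests.
    sample_state = previous_batch.sample_state
    # Requests still mid-prefill have no sampled token yet, so they are
    # dropped from the update path for this microbatch, mirroring the
    # reassignment in the diff above.
    sample_state.scheduled_requests.context_requests = (
        previous_batch.finished_ctx_reqs
    )
    update_requests(sample_state)
```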
@@ -980,7 +988,8 @@ def _executor_loop_pp(self):
                         self.resource_manager.update_resources(
                             previous_scheduled_batch, attn_metadata,
                             kv_cache_dtype_byte_size)
-                        self._remove_inflight_ids(previous_scheduled_batch)
+
+                        self._remove_inflight_ids(previous_batch)

                 self.wait_on_pp_send_handles(prev_microbatch_id)
                 self.micro_batches[prev_microbatch_id] = None
@@ -2486,12 +2495,32 @@ def _pause_requests(self, requests_to_pause):

     def _add_inflight_ids(self, scheduled_requests):
         """Add reqids of current requests to self.inflight_req_ids."""
-        for req in scheduled_requests.all_requests():
+        finished_ctx_reqs = []
+        for req in scheduled_requests.context_requests:
+            if req.is_last_context_chunk:
+                logger.debug(
+                    f"Context request with ID {req.request_id} added to DECODER model inflight set"
+                )
+                self.inflight_req_ids.insert(req.request_id)
+                finished_ctx_reqs.append(req)
+        for req in scheduled_requests.generation_requests:
+            logger.debug(
+                f"Generation request with ID {req.request_id} added to DECODER model inflight set"
+            )
             self.inflight_req_ids.insert(req.request_id)
+        return finished_ctx_reqs

-    def _remove_inflight_ids(self, scheduled_requests):
+    def _remove_inflight_ids(self, batch_state: BatchStatePP):
         """Remove reqids of current requests from self.inflight_req_ids."""
-        for req in scheduled_requests.all_requests():
+        for req in batch_state.finished_ctx_reqs:
+            logger.debug(
+                f"Context request with ID {req.request_id} removed from DECODER model inflight set"
+            )
+            self.inflight_req_ids.erase(req.request_id)
+        for req in batch_state.sample_state.scheduled_requests.generation_requests:
+            logger.debug(
+                f"Generation request with ID {req.request_id} removed from DECODER model inflight set"
+            )
             self.inflight_req_ids.erase(req.request_id)

     def _handle_speculative_decoding(self, scheduled_batch, previous_tensors,
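
Taken together, the new contract is: a context request enters the decoder's inflight set only on the microbatch that schedules its last context chunk, and removal is driven by the `finished_ctx_reqs` recorded in that microbatch's `BatchStatePP` plus the generation requests held in its `sample_state`. A minimal runnable sketch of that bookkeeping, assuming simplified stand-in types (a Python `set` plays the role of `inflight_req_ids`):

```python
from dataclasses import dataclass, field

@dataclass
class Req:
    request_id: int
    is_last_context_chunk: bool = False

@dataclass
class Scheduled:
    context_requests: list = field(default_factory=list)
    generation_requests: list = field(default_factory=list)

@dataclass
class SampleState:
    scheduled_requests: Scheduled = field(default_factory=Scheduled)

@dataclass
class BatchStatePP:
    sample_state: SampleState
    finished_ctx_reqs: list = field(default_factory=list)

inflight = set()  # stand-in for self.inflight_req_ids

def add_inflight_ids(scheduled: Scheduled) -> list:
    finished = []
    for req in scheduled.context_requests:
        if req.is_last_context_chunk:  # partial prefill chunks are skipped
            inflight.add(req.request_id)
            finished.append(req)
    for req in scheduled.generation_requests:
        inflight.add(req.request_id)
    return finished

def remove_inflight_ids(batch: BatchStatePP) -> None:
    for req in batch.finished_ctx_reqs:
        inflight.discard(req.request_id)
    for req in batch.sample_state.scheduled_requests.generation_requests:
        inflight.discard(req.request_id)

# Chunk 1 of request 7 is not the last chunk, so it never enters the set.
sched = Scheduled(context_requests=[Req(7, is_last_context_chunk=False)])
assert add_inflight_ids(sched) == [] and inflight == set()

# Chunk 2 is the last chunk: now it is tracked, and removal is symmetric.
sched = Scheduled(context_requests=[Req(7, is_last_context_chunk=True)])
batch = BatchStatePP(SampleState(), add_inflight_ids(sched))
assert inflight == {7}
remove_inflight_ids(batch)
assert inflight == set()
```

The asymmetry in the diff (insertion keyed on `is_last_context_chunk`, erasure keyed on the saved `finished_ctx_reqs`) is what keeps the set balanced across chunked prefill iterations.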