@@ -2050,59 +2050,76 @@ def validate_stats(
     results,
     pytorch_backend,
     max_tokens,
+    pp_size=1,
     use_overlap=False,
     enable_chunked_prefill=False,
     enable_iter_req_stats=False,
 ):
     assert results
+    for iter, result in enumerate(results):
+        ifbStats = result["inflightBatchingStats"]
+        print(f"iter: {iter}, ifbStats: {ifbStats}")
+
     expected_num_results = max_tokens if pytorch_backend else max_tokens + 1
     if enable_chunked_prefill:
         expected_num_results += 1
     assert len(results) == expected_num_results

     context_iterations = 2 if enable_chunked_prefill else 1
     generation_iterations = max_tokens - 1
+    microbatch_id = 0
     for iter, result in enumerate(results):
         ifbStats = result["inflightBatchingStats"]

         if iter < context_iterations:
-            assert ifbStats["numScheduledRequests"] == 1
-            assert ifbStats["numContextRequests"] == 1
-            assert ifbStats["numGenRequests"] == 0
-            assert result["numActiveRequests"] == 1
+            assert ifbStats["numScheduledRequests"] == 1, f"iter: {iter}"
+            assert ifbStats["numContextRequests"] == 1, f"iter: {iter}"
+            assert ifbStats["numGenRequests"] == 0, f"iter: {iter}"
+            assert result["numActiveRequests"] == 1, f"iter: {iter}"
+            assert ifbStats["microBatchId"] == microbatch_id, f"iter: {iter}"
         elif iter < (context_iterations + generation_iterations):
-            assert ifbStats["numScheduledRequests"] == 1
-            assert ifbStats["numContextRequests"] == 0
-            assert ifbStats["numGenRequests"] == 1
-            assert result["numActiveRequests"] == 1
+            assert ifbStats["numScheduledRequests"] == 1, f"iter: {iter}"
+            assert ifbStats["numContextRequests"] == 0, f"iter: {iter}"
+            assert ifbStats["numGenRequests"] == 1, f"iter: {iter}"
+            assert result["numActiveRequests"] == 1, f"iter: {iter}"
+            assert ifbStats["microBatchId"] == microbatch_id, f"iter: {iter}"
         else:
-            assert ifbStats["numScheduledRequests"] == 0
-            assert ifbStats["numContextRequests"] == 0
-            assert ifbStats["numGenRequests"] == 0
-            assert result["numActiveRequests"] == 0
+            assert ifbStats["numScheduledRequests"] == 0, f"iter: {iter}"
+            assert ifbStats["numContextRequests"] == 0, f"iter: {iter}"
+            assert ifbStats["numGenRequests"] == 0, f"iter: {iter}"
+            assert result["numActiveRequests"] == 0, f"iter: {iter}"
+            assert ifbStats["microBatchId"] == microbatch_id, f"iter: {iter}"
+
+        # In pipeline-parallel mode, increment microbatch_id for every context
+        # iteration except the last one, since each context chunk can be
+        # scheduled into its own microbatch.
+        if pp_size > 1 and iter < context_iterations - 1:
+            microbatch_id += 1

         if enable_iter_req_stats:
-            assert "requestStats" in result
+            assert "requestStats" in result, f"iter: {iter}"
             req_stats = result["requestStats"]
-            assert len(req_stats) == 1
+            assert len(req_stats) == 1, f"iter: {iter}"
             req_stat = req_stats[0]
             if iter < (context_iterations - 1):
                 # If use_overlap, the stats are one iteration ahead
                 assert req_stat[
-                    "stage"] == "GENERATION_IN_PROGRESS" if use_overlap else "CONTEXT_IN_PROGRESS"
+                    "stage"] == ("GENERATION_IN_PROGRESS" if use_overlap else
+                                 "CONTEXT_IN_PROGRESS"), f"iter: {iter}"
                 assert req_stat[
-                    "contextPrefillPosition"] == 54 if use_overlap else 32
-                assert req_stat["numGeneratedTokens"] == 0
+                    "contextPrefillPosition"] == (54 if use_overlap else
+                                                  32), f"iter: {iter}"
+                assert req_stat["numGeneratedTokens"] == 0, f"iter: {iter}"
             elif iter < (context_iterations - 1 + generation_iterations):
-                assert req_stat["stage"] == "GENERATION_IN_PROGRESS"
-                assert req_stat["contextPrefillPosition"] == 54
+                assert req_stat["stage"] == "GENERATION_IN_PROGRESS", f"iter: {iter}"
+                assert req_stat["contextPrefillPosition"] == 54, f"iter: {iter}"
                 assert req_stat["numGeneratedTokens"] == iter - (
-                    context_iterations - 1) + 1
+                    context_iterations - 1) + 1, f"iter: {iter}"
             else:
-                assert req_stat["stage"] == "GENERATION_COMPLETE"
-                assert req_stat["contextPrefillPosition"] == 54
-                assert req_stat["numGeneratedTokens"] == max_tokens
-                assert req_stat["scheduled"] == True
+                assert req_stat["stage"] == "GENERATION_COMPLETE", f"iter: {iter}"
+                assert req_stat["contextPrefillPosition"] == 54, f"iter: {iter}"
+                assert req_stat["numGeneratedTokens"] == max_tokens, f"iter: {iter}"
+                assert req_stat["scheduled"] == True, f"iter: {iter}"

         expected_num_completed = 1 if iter == len(results) - 1 else 0

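To make the pipeline-parallel bookkeeping above concrete, here is a minimal standalone sketch (not part of the diff; the helper name is hypothetical) of the microBatchId progression that the assertions encode: the id starts at 0 and, when pp_size > 1, advances once per context iteration except the last, after which every generation iteration stays on the final microbatch.

def expected_microbatch_ids(pp_size, context_iterations, total_iterations):
    """Hypothetical helper: expected ifbStats["microBatchId"] per iteration."""
    ids, microbatch_id = [], 0
    for it in range(total_iterations):
        ids.append(microbatch_id)
        # Mirrors the loop above: with pipeline parallelism, every context
        # iteration except the last one opens a new microbatch.
        if pp_size > 1 and it < context_iterations - 1:
            microbatch_id += 1
    return ids

# E.g. max_tokens=4 with chunked prefill on the PyTorch backend yields
# 2 context + 3 generation = 5 iterations:
assert expected_microbatch_ids(1, 2, 5) == [0, 0, 0, 0, 0]
assert expected_microbatch_ids(2, 2, 5) == [0, 1, 1, 1, 1]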
@@ -2178,6 +2195,7 @@ def llm_get_stats_test_harness(tp_size: int = 1,
     results = llm.get_stats(2)

     validate_stats(results=results,
+                   pp_size=pp_size,
                    pytorch_backend=pytorch_backend,
                    max_tokens=max_tokens,
                    use_overlap=use_overlap,
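For orientation, the harness is typically driven from parametrized tests. A hypothetical invocation exercising the new pp_size path might look like the sketch below; aside from tp_size, which appears in the hunk header above, the test name and parameter list are assumptions, not the repository's actual tests.

import pytest

@pytest.mark.parametrize("pp_size", [1, 2])
def test_llm_get_stats_pp(pp_size):
    # Hypothetical: assumes the harness accepts pp_size alongside tp_size,
    # mirroring the pp_size it forwards to validate_stats above.
    llm_get_stats_test_harness(tp_size=1, pp_size=pp_size)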
@@ -2328,6 +2346,7 @@ async def task1():
     assert results
     if not use_overlap:
         validate_stats(results=results,
+                       pp_size=pp_size,
                        pytorch_backend=pytorch_backend,
                        max_tokens=max_tokens,
                        use_overlap=use_overlap,
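For reference, a minimal sketch of the per-iteration result shape that validate_stats consumes, covering every field the assertions touch; the values are illustrative, not taken from a real run.

example_result = {
    "inflightBatchingStats": {
        "numScheduledRequests": 1,  # requests scheduled this iteration
        "numContextRequests": 0,    # requests in the context (prefill) phase
        "numGenRequests": 1,        # requests in the generation phase
        "microBatchId": 0,          # advances per context chunk when pp_size > 1
    },
    "numActiveRequests": 1,
    "requestStats": [{              # present when enable_iter_req_stats is set
        "stage": "GENERATION_IN_PROGRESS",
        "contextPrefillPosition": 54,
        "numGeneratedTokens": 1,
        "scheduled": True,
    }],
}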