@@ -2045,39 +2045,64 @@ def batch_task():
     batch_task()
 
 
-def validate_stats(results,
-                   pytorch_backend,
-                   max_tokens,
-                   enable_iter_req_stats=False):
+def validate_stats(
+    *,
+    results,
+    pytorch_backend,
+    max_tokens,
+    use_overlap=False,
+    enable_chunked_prefill=False,
+    enable_iter_req_stats=False,
+):
     assert results
-    assert len(results) == max_tokens if pytorch_backend else max_tokens + 1
+    expected_num_results = max_tokens if pytorch_backend else max_tokens + 1
+    if enable_chunked_prefill:
+        expected_num_results += 1
+    assert len(results) == expected_num_results
+
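+    # With chunked prefill the 54-token prompt is prefilled in two 32-token
+    # chunks, so the context phase spans two iterations instead of one.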
+    context_iterations = 2 if enable_chunked_prefill else 1
+    generation_iterations = max_tokens - 1
     for iter, result in enumerate(results):
         ifbStats = result["inflightBatchingStats"]
-        expected_num_scheduled = 1 if (iter < max_tokens) else 0
-        assert ifbStats["numScheduledRequests"] == expected_num_scheduled
-        if iter == 0:
+
+        if iter < context_iterations:
+            assert ifbStats["numScheduledRequests"] == 1
             assert ifbStats["numContextRequests"] == 1
             assert ifbStats["numGenRequests"] == 0
             assert result["numActiveRequests"] == 1
-        elif iter == max_tokens:
-            assert ifbStats["numContextRequests"] == 0
-            assert ifbStats["numGenRequests"] == 0
-            assert result["numActiveRequests"] == 0
-        else:
+        elif iter < (context_iterations + generation_iterations):
+            assert ifbStats["numScheduledRequests"] == 1
             assert ifbStats["numContextRequests"] == 0
             assert ifbStats["numGenRequests"] == 1
             assert result["numActiveRequests"] == 1
+        else:
+            assert ifbStats["numScheduledRequests"] == 0
+            assert ifbStats["numContextRequests"] == 0
+            assert ifbStats["numGenRequests"] == 0
+            assert result["numActiveRequests"] == 0
 
         if enable_iter_req_stats:
             assert "requestStats" in result
             req_stats = result["requestStats"]
             assert len(req_stats) == 1
             req_stat = req_stats[0]
-            assert req_stat["numGeneratedTokens"] == iter + 1
+            if iter < (context_iterations - 1):
+                # If use_overlap, the stats are one iteration ahead
+                expected_stage = ("GENERATION_IN_PROGRESS"
+                                  if use_overlap else "CONTEXT_IN_PROGRESS")
+                assert req_stat["stage"] == expected_stage
+                assert req_stat["contextPrefillPosition"] == (
+                    54 if use_overlap else 32)
+                assert req_stat["numGeneratedTokens"] == 0
+            elif iter < (context_iterations - 1 + generation_iterations):
+                assert req_stat["stage"] == "GENERATION_IN_PROGRESS"
+                assert req_stat["contextPrefillPosition"] == 54
+                assert req_stat["numGeneratedTokens"] == iter - (
+                    context_iterations - 1) + 1
+            else:
+                assert req_stat["stage"] == "GENERATION_COMPLETE"
+                assert req_stat["contextPrefillPosition"] == 54
+                assert req_stat["numGeneratedTokens"] == max_tokens
             assert req_stat["scheduled"] == True
-            assert req_stat[
-                "stage"] == "GENERATION_IN_PROGRESS" if iter + 1 < max_tokens else "GENERATION_COMPLETE"
-            assert req_stat["contextPrefillPosition"] == 4
 
         expected_num_completed = 1 if iter == len(results) - 1 else 0
 
@@ -2087,9 +2112,11 @@ def validate_stats(results,
 
 
 def llm_get_stats_test_harness(tp_size: int = 1,
+                               pp_size: int = 1,
                                return_context_logits: bool = False,
                                pytorch_backend: bool = False,
                                use_overlap: bool = False,
+                               enable_chunked_prefill: bool = False,
                                enable_iter_req_stats: bool = False):
 
     if return_context_logits and pytorch_backend:
@@ -2103,6 +2130,7 @@ def llm_get_stats_test_harness(tp_size: int = 1,
     print("return_context_logits: ", return_context_logits)
     print("pytorch_backend: ", pytorch_backend)
     print("use_overlap: ", use_overlap)
+    print("enable_chunked_prefill: ", enable_chunked_prefill)
     print("enable_iter_req_stats: ", enable_iter_req_stats)
     print("-------------")
 
@@ -2113,6 +2141,10 @@ def llm_get_stats_test_harness(tp_size: int = 1,
         llm_args_extra["gather_generation_logits"] = True
         sampling_args_extra["return_context_logits"] = True
 
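+    # max_num_tokens caps the tokens scheduled per iteration, so the 54-token
+    # prompt below must be prefilled in two 32-token chunks.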
+    if enable_chunked_prefill:
+        llm_args_extra["enable_chunked_prefill"] = True
+        llm_args_extra["max_num_tokens"] = 32
+
     if pytorch_backend:
         llm_args_extra.update(
             dict(enable_iter_perf_stats=True,
@@ -2125,27 +2157,38 @@ def llm_get_stats_test_harness(tp_size: int = 1,
     if not pytorch_backend:
         llm_args_extra["fast_build"] = True
 
-    llm = LLM_CLASS(model=llama_model_path,
-                    kv_cache_config=global_kvcache_config,
-                    tensor_parallel_size=tp_size,
-                    **llm_args_extra)
+    with LLM_CLASS(model=llama_model_path,
+                   kv_cache_config=global_kvcache_config,
+                   tensor_parallel_size=tp_size,
+                   pipeline_parallel_size=pp_size,
+                   **llm_args_extra) as llm:
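+        # The context manager ensures the engine is shut down even if one of
+        # the assertions below fails.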
 
-    max_tokens = 5
-    sampling_params = SamplingParams(max_tokens=max_tokens,
-                                     **sampling_args_extra)
+        max_tokens = 5
+        sampling_params = SamplingParams(max_tokens=max_tokens,
+                                         **sampling_args_extra)
 
-    for output in llm.generate(prompts, sampling_params=sampling_params):
-        print(output)
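+        # 52 single-letter words, which tokenize to the 54-token context
+        # length asserted via contextPrefillPosition in validate_stats.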
+        long_prompts = [
+            "A B C D E F G H I J K L M N O P Q R S T U V W X Y Z " * 2
+        ]
 
-    results = llm.get_stats(2)
+        for output in llm.generate(long_prompts,
+                                   sampling_params=sampling_params):
+            print(output)
+
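+        # 2 is the timeout in seconds, cf. get_stats_async(timeout=2) in the
+        # async harness below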
+        results = llm.get_stats(2)
 
-    validate_stats(results, pytorch_backend, max_tokens, enable_iter_req_stats)
+        validate_stats(results=results,
+                       pytorch_backend=pytorch_backend,
+                       max_tokens=max_tokens,
+                       use_overlap=use_overlap,
+                       enable_chunked_prefill=enable_chunked_prefill,
+                       enable_iter_req_stats=enable_iter_req_stats)
 
-    assert not llm.get_stats(2)
+        assert not llm.get_stats(2)
 
-    # test that IterationResult()._done is properly set
-    _ = llm.generate(prompts, sampling_params=sampling_params)
-    assert llm.get_stats(2)
+        # test that IterationResult()._done is properly set
+        _ = llm.generate(prompts, sampling_params=sampling_params)
+        assert llm.get_stats(2)
 
 
 @pytest.mark.parametrize("return_context_logits", [True, False])
@@ -2213,9 +2256,11 @@ def test_llm_get_queued_stats():
 
 
 def llm_get_stats_async_test_harness(tp_size: int = 1,
+                                     pp_size: int = 1,
                                      return_context_logits: bool = False,
                                      pytorch_backend: bool = False,
                                      use_overlap: bool = False,
+                                     enable_chunked_prefill: bool = False,
                                      enable_iter_req_stats: bool = False):
 
     if return_context_logits and pytorch_backend:
@@ -2229,6 +2274,7 @@ def llm_get_stats_async_test_harness(tp_size: int = 1,
     print("return_context_logits: ", return_context_logits)
     print("pytorch_backend: ", pytorch_backend)
     print("use_overlap: ", use_overlap)
+    print("enable_chunked_prefill: ", enable_chunked_prefill)
     print("enable_iter_req_stats: ", enable_iter_req_stats)
     print("-------------")
 
@@ -2238,6 +2284,10 @@ def llm_get_stats_async_test_harness(tp_size: int = 1,
         llm_args_extra["build_config"] = BuildConfig(gather_context_logits=True)
         sampling_args_extra["return_context_logits"] = True
 
+    if enable_chunked_prefill:
+        llm_args_extra["enable_chunked_prefill"] = True
+        llm_args_extra["max_num_tokens"] = 32
+
     if pytorch_backend:
         llm_args_extra.update(
             dict(enable_iter_perf_stats=True,
@@ -2248,38 +2298,47 @@ def llm_get_stats_async_test_harness(tp_size: int = 1,
         LLM_CLASS = LLM
         llm_args_extra["fast_build"] = True
 
-    llm = LLM_CLASS(model=llama_model_path,
-                    kv_cache_config=global_kvcache_config,
-                    tensor_parallel_size=tp_size,
-                    **llm_args_extra)
+    with LLM_CLASS(model=llama_model_path,
+                   kv_cache_config=global_kvcache_config,
+                   tensor_parallel_size=tp_size,
+                   pipeline_parallel_size=pp_size,
+                   **llm_args_extra) as llm:
 
-    max_tokens = 6
-    sampling_params = SamplingParams(max_tokens=max_tokens,
-                                     **sampling_args_extra)
+        max_tokens = 6
+        sampling_params = SamplingParams(max_tokens=max_tokens,
+                                         **sampling_args_extra)
 
-    async def task0():
-        async for output in llm.generate_async(prompts[0],
-                                               streaming=True,
-                                               sampling_params=sampling_params):
-            print(output)
+        long_prompts = [
+            "A B C D E F G H I J K L M N O P Q R S T U V W X Y Z " * 2
+        ]
 
-    async def task1():
-        results = []
-        await asyncio.sleep(
-            3)  # ensure there's stats to collect for the assertion
-        async for stats in llm.get_stats_async(timeout=2):
-            results.append(stats)
+        async def task0():
+            async for output in llm.generate_async(
+                    long_prompts[0],
+                    streaming=True,
+                    sampling_params=sampling_params):
+                print(output)
 
-        assert results
-        if not use_overlap:
-            validate_stats(results, pytorch_backend, max_tokens,
-                           enable_iter_req_stats)
+        async def task1():
+            results = []
+            await asyncio.sleep(
+                3)  # ensure there's stats to collect for the assertion
+            async for stats in llm.get_stats_async(timeout=2):
+                results.append(stats)
+
+            assert results
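+            # validate_stats assumes the fixed per-iteration layout that
+            # overlap scheduling shifts, so it is only run without overlap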
+            if not use_overlap:
+                validate_stats(results=results,
+                               pytorch_backend=pytorch_backend,
+                               max_tokens=max_tokens,
+                               use_overlap=use_overlap,
+                               enable_chunked_prefill=enable_chunked_prefill,
+                               enable_iter_req_stats=enable_iter_req_stats)
 
-    async def main():
-        for i in range(2):  # test recurrent usage
-            await asyncio.gather(task0(), task1())
+        async def main():
+            for i in range(2):  # test recurrent usage
+                await asyncio.gather(task0(), task1())
 
-    asyncio.run(main())
+        asyncio.run(main())
 
 
 @pytest.mark.parametrize("return_context_logits", [True, False])