@@ -368,6 +368,37 @@ def check_slots_status():
368368 # assert match_regex(re_content, res.body["content"])
369369
370370
371+ @pytest .mark .parametrize (
372+ "n_ctx,n_slots,n_predict_vals,expected_success" ,
373+ [
374+ (256 , 4 , [80 , 40 , 80 , 80 ], [True , True , True , True ]),
375+ (256 , 4 , [70 , 70 , 70 , 70 ], [False , False , False , False ]),
376+ (256 , 4 , [90 , 90 , 40 , 90 ], [False , False , True , False ]),
377+ (256 , 4 , [90 , 90 , 40 , 80 ], [True , True , True , True ]),
378+ ],
379+ )
380+ def test_completion_unified (n_ctx , n_slots , n_predict_vals , expected_success ):
381+ global server
382+ server .n_slots = n_slots
383+ server .kv_unified = True
384+ server .n_ctx = n_ctx
385+ server .start ()
386+ prompt = "A"
387+ tasks = []
388+ for n_predict in n_predict_vals :
389+ tasks .append ((server .make_request , ("POST" , "/completion" , {"prompt" : prompt , "n_predict" : n_predict })))
390+ results = parallel_function_calls (tasks )
391+ for res , n_predict , expect_ok in zip (results , n_predict_vals , expected_success ):
392+ if expect_ok :
393+ assert res .status_code == 200
394+ assert "content" in res .body
395+ if "timings" in res .body :
396+ assert res .body ["timings" ]["predicted_n" ] == n_predict
397+ else :
398+ assert res .status_code == 500
399+ assert "content" not in res .body
400+
401+
371402@pytest .mark .parametrize (
372403 "prompt,n_predict,response_fields" ,
373404 [
0 commit comments