2 files changed (+5 −5 lines)

@@ -262,7 +262,7 @@ def test_pld_spec_decode_inference(
         num_speculative_tokens=num_speculative_tokens,
     )
     # init qaic session
-    target_model_session = QAICInferenceSession(target_model_qpc_path, device_ids=device_group)
+    target_model_session = QAICInferenceSession(target_model_qpc_path)
     draft_model_session = None

     # skip inputs/outputs buffers
@@ -453,7 +453,7 @@ def test_pld_spec_decode_inference(
     del draft_model_session
     generated_ids = np.asarray(generated_ids[0]).flatten()
     gen_len = generated_ids.shape[0]
-    exec_info = target_model.generate(tokenizer, Constants.INPUT_STR, device_group)
+    exec_info = target_model.generate(tokenizer, Constants.INPUT_STR)
     cloud_ai_100_tokens = exec_info.generated_ids[0][
         :gen_len
     ]  # Because we always run for single input and single batch size
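For context, here is a minimal sketch (not the repository's test itself) of the session setup after this change, assuming QAICInferenceSession falls back to its default device selection when device_ids is omitted. The import path and the init_pld_session helper are illustrative; the other identifiers mirror the diff.

    # Sketch of the post-change setup: device selection is left to the
    # session's defaults rather than an explicit device_group.
    from QEfficient.generation.cloud_infer import QAICInferenceSession  # import path assumed

    def init_pld_session(target_model_qpc_path: str) -> QAICInferenceSession:
        # Previously: QAICInferenceSession(target_model_qpc_path, device_ids=device_group)
        target_model_session = QAICInferenceSession(target_model_qpc_path)
        # Skip the KV-cache ("past_*") input buffers, as the test does.
        target_model_session.skip_buffers(
            set(x for x in target_model_session.input_names if x.startswith("past_"))
        )
        return target_model_session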
@@ -157,8 +157,8 @@ def test_spec_decode_inference(
         full_batch_size=full_batch_size,
     )
     # init qaic session
-    target_model_session = QAICInferenceSession(target_model_qpc_path, device_ids=device_group)
-    draft_model_session = QAICInferenceSession(draft_model_qpc_path, device_ids=device_group)
+    target_model_session = QAICInferenceSession(target_model_qpc_path)
+    draft_model_session = QAICInferenceSession(draft_model_qpc_path)

     # skip inputs/outputs buffers
     target_model_session.skip_buffers(set([x for x in target_model_session.input_names if x.startswith("past_")]))
@@ -341,7 +341,7 @@ def test_spec_decode_inference(
     del draft_model_session
     generated_ids = np.asarray(generated_ids[0]).flatten()
     gen_len = generated_ids.shape[0]
-    exec_info = draft_model.generate(tokenizer, Constants.INPUT_STR, device_group)
+    exec_info = draft_model.generate(tokenizer, Constants.INPUT_STR)
     cloud_ai_100_tokens = exec_info.generated_ids[0][
         :gen_len
     ]  # Because we always run for single input and single batch size
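Similarly, a hedged sketch of the updated tail of test_spec_decode_inference: generate is now called without device_group, and the Cloud AI 100 output is truncated to the speculative run's length before comparison. The compare_tokens helper and the final assertion are illustrative; the remaining identifiers come from the diff.

    # Sketch of the trailing comparison after the change.
    import numpy as np

    def compare_tokens(draft_model, tokenizer, input_str, generated_ids):
        generated_ids = np.asarray(generated_ids[0]).flatten()
        gen_len = generated_ids.shape[0]
        # Previously: draft_model.generate(tokenizer, input_str, device_group)
        exec_info = draft_model.generate(tokenizer, input_str)
        # Single input, single batch: compare only the first gen_len tokens.
        cloud_ai_100_tokens = exec_info.generated_ids[0][:gen_len]
        assert (generated_ids == cloud_ai_100_tokens).all()  # assumed check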