Commit 38989e9

Removed device IDs from the test (#389)
Signed-off-by: Rishin Raj <[email protected]>
Parent: 5c471b6
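
This commit drops the explicit device_group argument from the QAICInferenceSession constructors and the generate() calls in both speculative-decoding tests, leaving device selection to the runtime's defaults (presumably automatic device pickup when no device IDs are supplied).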

File tree: 2 files changed (+5, -5 lines)


tests/transformers/spd/test_pld_inference.py (2 additions, 2 deletions)

@@ -262,7 +262,7 @@ def test_pld_spec_decode_inference(
         num_speculative_tokens=num_speculative_tokens,
     )
     # init qaic session
-    target_model_session = QAICInferenceSession(target_model_qpc_path, device_ids=device_group)
+    target_model_session = QAICInferenceSession(target_model_qpc_path)
     draft_model_session = None

     # skip inputs/outputs buffers

@@ -453,7 +453,7 @@ def test_pld_spec_decode_inference(
     del draft_model_session
     generated_ids = np.asarray(generated_ids[0]).flatten()
     gen_len = generated_ids.shape[0]
-    exec_info = target_model.generate(tokenizer, Constants.INPUT_STR, device_group)
+    exec_info = target_model.generate(tokenizer, Constants.INPUT_STR)
     cloud_ai_100_tokens = exec_info.generated_ids[0][
         :gen_len
     ]  # Because we always run for single input and single batch size
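
For reference, a minimal sketch of the new session setup, assuming QAICInferenceSession is importable from QEfficient.generation.cloud_infer and that device_ids is an optional parameter; automatic device selection when it is omitted is an assumption, not something this diff confirms:

    from QEfficient.generation.cloud_infer import QAICInferenceSession

    # Old call pinned the session to specific devices:
    #   QAICInferenceSession(target_model_qpc_path, device_ids=device_group)
    # New call omits device_ids; the session is expected to bind to an
    # available Cloud AI 100 device on its own (assumption).
    target_model_session = QAICInferenceSession(target_model_qpc_path)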

tests/transformers/spd/test_spd_inference.py (3 additions, 3 deletions)

@@ -157,8 +157,8 @@ def test_spec_decode_inference(
         full_batch_size=full_batch_size,
     )
     # init qaic session
-    target_model_session = QAICInferenceSession(target_model_qpc_path, device_ids=device_group)
-    draft_model_session = QAICInferenceSession(draft_model_qpc_path, device_ids=device_group)
+    target_model_session = QAICInferenceSession(target_model_qpc_path)
+    draft_model_session = QAICInferenceSession(draft_model_qpc_path)

     # skip inputs/outputs buffers
     target_model_session.skip_buffers(set([x for x in target_model_session.input_names if x.startswith("past_")]))

@@ -341,7 +341,7 @@ def test_spec_decode_inference(
     del draft_model_session
     generated_ids = np.asarray(generated_ids[0]).flatten()
     gen_len = generated_ids.shape[0]
-    exec_info = draft_model.generate(tokenizer, Constants.INPUT_STR, device_group)
+    exec_info = draft_model.generate(tokenizer, Constants.INPUT_STR)
     cloud_ai_100_tokens = exec_info.generated_ids[0][
         :gen_len
     ]  # Because we always run for single input and single batch size
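
Likewise, a minimal sketch of the updated generate call, assuming generate() treats the device list as an optional trailing argument and falls back to automatic device selection when it is not passed (an assumption; the diff only shows the argument being dropped):

    # tokenizer and the prompt are the only required inputs now; device
    # selection is left to the runtime (assumption).
    exec_info = draft_model.generate(tokenizer, Constants.INPUT_STR)
    cloud_ai_100_tokens = exec_info.generated_ids[0][:gen_len]  # single prompt, single batch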
