22
22
from QEfficient .utils .device_utils import get_available_device_id
23
23
from QEfficient .utils .run_utils import ApiRunner
24
24
25
+ extrenal_models = {"hpcai-tech/grok-1" }
25
26
test_models_qaic = [
26
27
"TinyLlama/TinyLlama-1.1B-Chat-v1.0" ,
27
28
"gpt2" ,
61
62
]
62
63
63
64
64
- def load_causal_lm_model (model_config , model_name ):
65
+ def load_causal_lm_model (model_config ):
65
66
"""
66
67
Function to load model from huggingface and transform to KV model
67
68
--------
@@ -80,11 +81,13 @@ def load_causal_lm_model(model_config, model_name):
80
81
num_hidden_layers = model_config ["n_layer" ],
81
82
attn_implementation = "eager" ,
82
83
low_cpu_mem_usage = False ,
83
- trust_remote_code = True if model_name == "hpcai-tech/grok-1" else False ,
84
- ) # Run models for single layers only
84
+ trust_remote_code = model_config ["model_name" ] in extrenal_models ,
85
+ )
86
+ # Convert to FP32 if model is in BF16
87
+ if getattr (model_hf .config , "torch_dtype" , None ) == torch .bfloat16 :
88
+ model_hf = model_hf .to (torch .float32 )
89
+
85
90
params = sum (p .numel () for p in model_hf .parameters ())
86
- if model_name == "hpcai-tech/grok-1" :
87
- model_hf .to (torch .float32 )
88
91
model_hf .eval ()
89
92
return model_hf , params
90
93
@@ -111,7 +114,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
111
114
model_config = {"model_name" : model_name }
112
115
model_config ["n_layer" ] = n_layer
113
116
114
- model_hf , _ = load_causal_lm_model (model_config , model_name )
117
+ model_hf , _ = load_causal_lm_model (model_config )
115
118
116
119
tokenizer = load_hf_tokenizer (pretrained_model_name_or_path = model_name )
117
120
config = model_hf .config
@@ -172,7 +175,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
172
175
if prefill_only is not None :
173
176
return
174
177
# testing for CB models
175
- model_hf , _ = load_causal_lm_model (model_config , model_name )
178
+ model_hf , _ = load_causal_lm_model (model_config )
176
179
full_batch_size = 4
177
180
fbs_prompts = Constants .INPUT_STR * 4
178
181
api_runner = ApiRunner (
0 commit comments