File tree Expand file tree Collapse file tree 1 file changed +7
-3
lines changed
tests/transformers/models Expand file tree Collapse file tree 1 file changed +7
-3
lines changed Original file line number Diff line number Diff line change 22
22
from QEfficient .utils .device_utils import get_available_device_id
23
23
from QEfficient .utils .run_utils import ApiRunner
24
24
25
+ extrenal_models = {"hpcai-tech/grok-1" }
25
26
test_models_qaic = [
26
27
"TinyLlama/TinyLlama-1.1B-Chat-v1.0" ,
27
28
"gpt2" ,
@@ -80,10 +81,13 @@ def load_causal_lm_model(model_config):
80
81
num_hidden_layers = model_config ["n_layer" ],
81
82
attn_implementation = "eager" ,
82
83
low_cpu_mem_usage = False ,
83
- trust_remote_code = True ,
84
- ) # Run models for single layers only
84
+ trust_remote_code = model_config ["model_name" ] in extrenal_models ,
85
+ )
86
+ # Convert to FP32 if model is in BF16
87
+ if getattr (model_hf .config , "torch_dtype" , None ) == torch .bfloat16 :
88
+ model_hf = model_hf .to (torch .float32 )
89
+
85
90
params = sum (p .numel () for p in model_hf .parameters ())
86
- model_hf .to (torch .float32 )
87
91
model_hf .eval ()
88
92
return model_hf , params
89
93
You can’t perform that action at this time.
0 commit comments