Skip to content

Commit 39ccb4b

Browse files
Authored commit: Merge pull request #9 from abukhoy/pr-373
trust_remote_code enabled for grok1 only
2 parents 725da03 + 312de24 commit 39ccb4b

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

tests/transformers/models/test_causal_lm_models.py

Lines changed: 7 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -22,6 +22,7 @@
2222
from QEfficient.utils.device_utils import get_available_device_id
2323
from QEfficient.utils.run_utils import ApiRunner
2424

25+
extrenal_models = {"hpcai-tech/grok-1"}
2526
test_models_qaic = [
2627
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
2728
"gpt2",
@@ -80,10 +81,13 @@ def load_causal_lm_model(model_config):
8081
num_hidden_layers=model_config["n_layer"],
8182
attn_implementation="eager",
8283
low_cpu_mem_usage=False,
83-
trust_remote_code=True,
84-
) # Run models for single layers only
84+
trust_remote_code=model_config["model_name"] in extrenal_models,
85+
)
86+
# Convert to FP32 if model is in BF16
87+
if getattr(model_hf.config, "torch_dtype", None) == torch.bfloat16:
88+
model_hf = model_hf.to(torch.float32)
89+
8590
params = sum(p.numel() for p in model_hf.parameters())
86-
model_hf.to(torch.float32)
8791
model_hf.eval()
8892
return model_hf, params
8993

0 commit comments

Comments (0)