Skip to content

Commit

Permalink
model loading
Browse files Browse the repository at this point in the history
  • Loading branch information
Alessandro Sordoni committed Nov 10, 2024
1 parent 94f1d42 commit 144cbc9
Showing 1 changed file with 19 additions and 12 deletions.
31 changes: 19 additions & 12 deletions mttl/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ def model_loader_helper(
raise ValueError("Specify either 'load_in_4bit' or 'load_in_8bit' or neither.")

from transformers import (
AutoModel,
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
BitsAndBytesConfig,
LlamaForCausalLM,
PreTrainedModel,
Expand Down Expand Up @@ -144,17 +145,23 @@ def model_loader_helper(

model_name = os.environ["PHI_PATH"]
logger.info(f"Loading phi-2 model from {os.environ['PHI_PATH']}")
try:
model_object = AutoModel.from_pretrained(
model_name,
device_map=device_map,
trust_remote_code=True,
attn_implementation=attn_implementation,
quantization_config=bnb_config,
torch_dtype=torch_dtype,
)
except Exception as e:
logger.error(f"loading model: {e}")
else:
model_object = None
for klass in [AutoModelForCausalLM, AutoModelForSeq2SeqLM]:
try:
model_object = klass.from_pretrained(
model_name,
device_map=device_map,
trust_remote_code=True,
attn_implementation=attn_implementation,
quantization_config=bnb_config,
torch_dtype=torch_dtype,
)
break
except:
continue
if model_object is None:
raise ValueError(f"Couldn't load {model_name}!")

if bnb_config is not None:
model_object = prepare_model_for_kbit_training(model_object)
Expand Down

0 comments on commit 144cbc9

Please sign in to comment.