Skip to content

Commit

Permalink
use AutoModel
Browse files Browse the repository at this point in the history
  • Loading branch information
Alessandro Sordoni committed Nov 10, 2024
1 parent 28a1947 commit 94f1d42
Showing 1 changed file with 2 additions and 11 deletions.
13 changes: 2 additions & 11 deletions mttl/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,7 @@ def model_loader_helper(
raise ValueError("Specify either 'load_in_4bit' or 'load_in_8bit' or neither.")

from transformers import (
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoModel,
BitsAndBytesConfig,
LlamaForCausalLM,
PreTrainedModel,
Expand Down Expand Up @@ -146,7 +145,7 @@ def model_loader_helper(
model_name = os.environ["PHI_PATH"]
logger.info(f"Loading phi-2 model from {os.environ['PHI_PATH']}")
try:
model_object = AutoModelForCausalLM.from_pretrained(
model_object = AutoModel.from_pretrained(
model_name,
device_map=device_map,
trust_remote_code=True,
Expand All @@ -156,14 +155,6 @@ def model_loader_helper(
)
except Exception as e:
logger.error(f"loading model: {e}")
model_object = AutoModelForSeq2SeqLM.from_pretrained(
model_name,
device_map=device_map,
trust_remote_code=True,
attn_implementation=attn_implementation,
quantization_config=bnb_config,
torch_dtype=torch_dtype,
)

if bnb_config is not None:
model_object = prepare_model_for_kbit_training(model_object)
Expand Down

0 comments on commit 94f1d42

Please sign in to comment.