Skip to content

Commit 3d3b21d

Browse files
committed
Replace deprecated torch_dtype with dtype parameter
- Update all from_pretrained() calls to use 'dtype' instead of 'torch_dtype'
- Fixes deprecation warning from the transformers library
- Changes in HFPipelineBasedInferenceEngine, LlavaInferenceEngine, and HFPeftInferenceEngine

Signed-off-by: Yoav Katz <katz@il.ibm.com>
1 parent 2a73943 commit 3d3b21d

1 file changed

Lines changed: 6 additions & 6 deletions

File tree

src/unitxt/inference.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -728,9 +728,9 @@ def _get_model_args(self) -> Dict[str, Any]:
728728
args["quantization_config"] = quantization_config
729729
elif self.use_fp16:
730730
if self.device == torch.device("mps"):
731-
args["torch_dtype"] = torch.float16
731+
args["dtype"] = torch.float16
732732
else:
733-
args["torch_dtype"] = torch.bfloat16
733+
args["dtype"] = torch.bfloat16
734734

735735
# We do this, because in some cases, using device:auto will offload some weights to the cpu
736736
# (even though the model might *just* fit to a single gpu), even if there is a gpu available, and this will
@@ -937,7 +937,7 @@ def _init_model(self):
937937

938938
self.model = LlavaForConditionalGeneration.from_pretrained(
939939
self.model_name,
940-
torch_dtype=self._get_torch_dtype(),
940+
dtype=self._get_torch_dtype(),
941941
low_cpu_mem_usage=self.low_cpu_mem_usage,
942942
device_map=self.device_map,
943943
)
@@ -1108,7 +1108,7 @@ def _init_model(self):
11081108
trust_remote_code=True,
11091109
device_map=self.device_map,
11101110
low_cpu_mem_usage=self.low_cpu_mem_usage,
1111-
torch_dtype=self._get_torch_dtype(),
1111+
dtype=self._get_torch_dtype(),
11121112
)
11131113
self.model = self.model.to(
11141114
dtype=self._get_torch_dtype()
@@ -1197,9 +1197,9 @@ def _get_model_args(self) -> Dict[str, Any]:
11971197
args["quantization_config"] = quantization_config
11981198
elif self.use_fp16:
11991199
if self.device == torch.device("mps"):
1200-
args["torch_dtype"] = torch.float16
1200+
args["dtype"] = torch.float16
12011201
else:
1202-
args["torch_dtype"] = torch.bfloat16
1202+
args["dtype"] = torch.bfloat16
12031203

12041204
# We do this, because in some cases, using device:auto will offload some weights to the cpu
12051205
# (even though the model might *just* fit to a single gpu), even if there is a gpu available, and this will

0 commit comments

Comments (0)