diff --git a/chandra/model/hf.py b/chandra/model/hf.py index 6be4e47..7d37ff7 100644 --- a/chandra/model/hf.py +++ b/chandra/model/hf.py @@ -1,5 +1,7 @@ from typing import List +import torch + from chandra.model.schema import BatchInputItem, GenerationResult from chandra.model.util import scale_to_fit from chandra.prompts import PROMPT_MAPPING @@ -33,7 +35,18 @@ def generate_hf( return_tensors="pt", padding_side="left", ) - inputs = inputs.to("cuda") + # Resolve device: settings.TORCH_DEVICE override first, then the model's own device, then CUDA, then MPS (Apple Silicon), else CPU + if settings.TORCH_DEVICE: + device = settings.TORCH_DEVICE + elif hasattr(model, 'device'): + device = model.device + elif torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" + inputs = inputs.to(device) # Inference: Generation of the output generated_ids = model.generate(**inputs, max_new_tokens=max_output_tokens)