diff --git a/chandra/model/hf.py b/chandra/model/hf.py
index 50aa883..83ae587 100644
--- a/chandra/model/hf.py
+++ b/chandra/model/hf.py
@@ -1,5 +1,6 @@
 from typing import List
 
+import torch
 from qwen_vl_utils import process_vision_info
 from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor
 
@@ -28,7 +29,8 @@ def generate_hf(
         return_tensors="pt",
         padding_side="left",
     )
-    inputs = inputs.to("cuda")
+    device = resolve_device(model)
+    inputs = inputs.to(device)
 
     # Inference: Generation of the output
     generated_ids = model.generate(**inputs, max_new_tokens=max_output_tokens)
@@ -83,3 +85,32 @@ def load_model():
     processor = Qwen3VLProcessor.from_pretrained(settings.MODEL_CHECKPOINT)
     model.processor = processor
     return model
+
+
+def resolve_device(model) -> torch.device:
+    # An explicit override from settings wins.
+    if settings.TORCH_DEVICE:
+        return torch.device(settings.TORCH_DEVICE)
+
+    # Prefer the device the model weights actually live on.
+    try:
+        parameter_device = next(model.parameters()).device
+        if parameter_device.type != "meta":
+            return parameter_device
+    except (StopIteration, AttributeError):
+        pass
+
+    # Some wrappers expose a .device attribute instead of parameters.
+    model_device = getattr(model, "device", None)
+    if isinstance(model_device, torch.device) and model_device.type != "meta":
+        return model_device
+
+    # Otherwise fall back to the best available backend.
+    if torch.cuda.is_available():
+        return torch.device("cuda")
+
+    mps_backend = getattr(torch.backends, "mps", None)
+    if mps_backend and mps_backend.is_available():
+        return torch.device("mps")
+
+    return torch.device("cpu")
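
A quick sanity check of the fallback order: a minimal sketch, assuming settings.TORCH_DEVICE is unset and that resolve_device is importable from chandra.model.hf.

    import torch
    from torch import nn
    from chandra.model.hf import resolve_device

    toy = nn.Linear(4, 4)        # parameters live on CPU
    print(resolve_device(toy))   # -> cpu (taken from the parameter device)

    bare = nn.Module()           # no parameters and no .device attribute
    print(resolve_device(bare))  # -> cuda, mps, or cpu, depending on the host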