From 419af0793fb4e3edb707cd02dd2b9deff4684dab Mon Sep 17 00:00:00 2001 From: Mark Seo Date: Mon, 29 Dec 2025 17:18:06 +0900 Subject: [PATCH] fix: add Apple Silicon MPS support for device detection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The generate_hf function had hardcoded "cuda" device, causing AssertionError on non-NVIDIA systems. This change adds automatic device detection with the following priority: 1. TORCH_DEVICE setting (if configured) 2. model.device attribute (if available) 3. CUDA (if torch.cuda.is_available()) 4. MPS (if torch.backends.mps.is_available()) - Apple Silicon 5. CPU (fallback) This enables running chandra on Apple Silicon Macs with Metal acceleration via MPS backend. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- chandra/model/hf.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/chandra/model/hf.py b/chandra/model/hf.py index 6be4e47..7d37ff7 100644 --- a/chandra/model/hf.py +++ b/chandra/model/hf.py @@ -1,5 +1,7 @@ from typing import List +import torch + from chandra.model.schema import BatchInputItem, GenerationResult from chandra.model.util import scale_to_fit from chandra.prompts import PROMPT_MAPPING @@ -33,7 +35,18 @@ def generate_hf( return_tensors="pt", padding_side="left", ) - inputs = inputs.to("cuda") + # Auto-detect device: MPS for Apple Silicon, CUDA for NVIDIA, else CPU + if settings.TORCH_DEVICE: + device = settings.TORCH_DEVICE + elif hasattr(model, 'device'): + device = model.device + elif torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" + inputs = inputs.to(device) # Inference: Generation of the output generated_ids = model.generate(**inputs, max_new_tokens=max_output_tokens)