diff --git a/byaldi/colpali.py b/byaldi/colpali.py index cc11dcb..91480e6 100644 --- a/byaldi/colpali.py +++ b/byaldi/colpali.py @@ -87,6 +87,10 @@ def __init__( or (isinstance(device, torch.device) and device.type == "cuda") else None ), + #Fix: Use attn_implementation='eager' for MPS compatibility + attn_implementation = "eager" if device == "mps" or ( + isinstance(device, torch.device) and device.type == "mps" + ) else None, token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"), ) self.model = self.model.eval()