From 5f31192e02e5b9bf2654a631ff37d86788b6833d Mon Sep 17 00:00:00 2001 From: Murshid Orion MBP Date: Sat, 15 Feb 2025 23:24:26 +0530 Subject: [PATCH 1/2] Fix: Use attn_implementation='eager' for MPS compatibility --- byaldi/colpali.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/byaldi/colpali.py b/byaldi/colpali.py index cc11dcb..bcf3bf7 100644 --- a/byaldi/colpali.py +++ b/byaldi/colpali.py @@ -87,6 +87,10 @@ def __init__( or (isinstance(device, torch.device) and device.type == "cuda") else None ), + #Fix: Use attn_implementation='eager' for MPS compatibility + attn_implementation = "eager" if device == "mps" or ( + isinstance(device, torch.device) and device.type == "mps" + ) else None token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"), ) self.model = self.model.eval() From 93804d65dd4be69c64c67a23bd7d46be5bd10a68 Mon Sep 17 00:00:00 2001 From: Murshid Orion MBP Date: Sat, 15 Feb 2025 23:54:07 +0530 Subject: [PATCH 2/2] Fix: Correct the comma missing issue in the pull request --- byaldi/colpali.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/byaldi/colpali.py b/byaldi/colpali.py index bcf3bf7..91480e6 100644 --- a/byaldi/colpali.py +++ b/byaldi/colpali.py @@ -90,7 +90,7 @@ def __init__( #Fix: Use attn_implementation='eager' for MPS compatibility attn_implementation = "eager" if device == "mps" or ( isinstance(device, torch.device) and device.type == "mps" - ) else None + ) else None, token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"), ) self.model = self.model.eval()