diff --git a/byaldi/colpali.py b/byaldi/colpali.py index cc11dcb..40c24fc 100644 --- a/byaldi/colpali.py +++ b/byaldi/colpali.py @@ -7,7 +7,7 @@ import srsly import torch -from colpali_engine.models import ColPali, ColPaliProcessor, ColQwen2, ColQwen2Processor +from colpali_engine.models import ColPali, ColPaliProcessor, ColQwen2, ColQwen2Processor, ColQwen2_5, ColQwen2_5_Processor from pdf2image import convert_from_path from PIL import Image @@ -35,6 +35,7 @@ def __init__( if ( "colpali" not in pretrained_model_name_or_path.lower() and "colqwen2" not in pretrained_model_name_or_path.lower() + and "colqwen2.5" not in pretrained_model_name_or_path.lower() ): raise ValueError( "This pre-release version of Byaldi only supports ColPali and ColQwen2 for now. Incorrect model name specified." @@ -77,6 +78,18 @@ def __init__( ), token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"), ) + elif "colqwen2.5" in pretrained_model_name_or_path.lower(): + self.model = ColQwen2_5.from_pretrained( + self.pretrained_model_name_or_path, + torch_dtype=torch.bfloat16, + device_map=( + "cuda" + if device == "cuda" + or (isinstance(device, torch.device) and device.type == "cuda") + else None + ), + token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"), + ) elif "colqwen2" in pretrained_model_name_or_path.lower(): self.model = ColQwen2.from_pretrained( self.pretrained_model_name_or_path, @@ -99,6 +112,14 @@ def __init__( token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"), ), ) + elif "colqwen2.5" in pretrained_model_name_or_path.lower(): + self.processor = cast( + ColQwen2_5_Processor, + ColQwen2_5_Processor.from_pretrained( + self.pretrained_model_name_or_path, + token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"), + ), + ) elif "colqwen2" in pretrained_model_name_or_path.lower(): self.processor = cast( ColQwen2Processor, @@ -742,3 +763,4 @@ def encode_query(self, query: Union[str, List[str]]) -> torch.Tensor: def get_doc_ids_to_file_names(self): return self.doc_ids_to_file_names +