diff --git a/README.md b/README.md index 6e533f2..a0416d6 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ Byaldi is [RAGatouille](https://github.com/answerdotai/ragatouille)'s mini siste First, a warning: This is a pre-release library, using uncompressed indexes and lacking other kinds of refinements. -Currently, we support all models supported by the underlying [colpali-engine](https://github.com/illuin-tech/colpali), including the newer, and better, ColQwen2 checkpoints, such as `vidore/colqwen2-v1.0`. Broadly, the aim is for byaldi to support all ColVLM models. +Currently, we support all models supported by the underlying [colpali-engine](https://github.com/illuin-tech/colpali), including the newer, and better, ColQwen2 checkpoints, such as `vidore/colqwen2-v1.0`. You can also use `byaldi` to leverage ColSmol models if you have hardware constraints (`vidore/colSmol-256M`, `vidore/colSmol-500M`). Broadly, the aim is for `byaldi` to support all ColVLM models. Additional backends will be supported in future updates. As byaldi exists to facilitate the adoption of multi-modal retrievers, we intend to also add support for models such as [VisRAG](https://github.com/openbmb/visrag). 
diff --git a/byaldi/RAGModel.py b/byaldi/RAGModel.py index 32b66bf..4c0df48 100644 --- a/byaldi/RAGModel.py +++ b/byaldi/RAGModel.py @@ -4,7 +4,6 @@ from PIL import Image from byaldi.colpali import ColPaliModel - from byaldi.objects import Result # Optional langchain integration diff --git a/byaldi/colpali.py b/byaldi/colpali.py index cc11dcb..2db9a6b 100644 --- a/byaldi/colpali.py +++ b/byaldi/colpali.py @@ -7,7 +7,14 @@ import srsly import torch -from colpali_engine.models import ColPali, ColPaliProcessor, ColQwen2, ColQwen2Processor +from colpali_engine.models import ( + ColIdefics3, + ColIdefics3Processor, + ColPali, + ColPaliProcessor, + ColQwen2, + ColQwen2Processor, +) from pdf2image import convert_from_path from PIL import Image @@ -35,6 +42,7 @@ def __init__( if ( "colpali" not in pretrained_model_name_or_path.lower() and "colqwen2" not in pretrained_model_name_or_path.lower() + and "colsmol" not in pretrained_model_name_or_path.lower() ): raise ValueError( "This pre-release version of Byaldi only supports ColPali and ColQwen2 for now. Incorrect model name specified." 
@@ -89,6 +97,18 @@ def __init__( ), token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"), ) + elif "colsmol" in pretrained_model_name_or_path.lower(): + self.model = ColIdefics3.from_pretrained( + self.pretrained_model_name_or_path, + torch_dtype=torch.bfloat16, + device_map=( + "cuda" + if device == "cuda" + or (isinstance(device, torch.device) and device.type == "cuda") + else None + ), + token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"), + ) self.model = self.model.eval() if "colpali" in pretrained_model_name_or_path.lower(): @@ -107,6 +127,14 @@ def __init__( token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"), ), ) + elif "colsmol" in pretrained_model_name_or_path.lower(): + self.processor = cast( + ColIdefics3Processor, + ColIdefics3Processor.from_pretrained( + self.pretrained_model_name_or_path, + token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"), + ), + ) self.device = device if device != "cuda" and not ( diff --git a/byaldi/integrations/__init__.py b/byaldi/integrations/__init__.py index 5841288..f2cf6cd 100644 --- a/byaldi/integrations/__init__.py +++ b/byaldi/integrations/__init__.py @@ -1,7 +1,7 @@ _all__ = [] try: - from byaldi.integrations._langchain import ByaldiLangChainRetriever + from byaldi.integrations._langchain import ByaldiLangChainRetriever # noqa: F401 _all__.append("ByaldiLangChainRetriever") except ImportError: diff --git a/pyproject.toml b/pyproject.toml index bf1d624..fba2080 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,14 +28,14 @@ maintainers = [ ] dependencies = [ - "colpali-engine>=0.3.4,<0.4.0", + "colpali-engine>=0.3.7,<0.4.0", "ml-dtypes", "mteb==1.6.35", "ninja", "pdf2image", "srsly", "torch", - "transformers>=4.42.0", + "transformers>=4.47.0", ] [project.optional-dependencies] diff --git a/tests/test_colpali.py b/tests/test_colpali.py index f8e85ae..d016a20 100644 --- a/tests/test_colpali.py +++ b/tests/test_colpali.py @@ -12,7 +12,7 @@ def colpali_rag_model() -> 
Generator[RAGMultiModalModel, None, None]: device = get_torch_device("auto") print(f"Using device: {device}") - yield RAGMultiModalModel.from_pretrained("vidore/colpali-v1.2", device=device) + yield RAGMultiModalModel.from_pretrained("vidore/colpali-v1.3", device=device) tear_down_torch() diff --git a/tests/test_colqwen.py b/tests/test_colqwen.py index 6a68acc..459106a 100644 --- a/tests/test_colqwen.py +++ b/tests/test_colqwen.py @@ -12,7 +12,7 @@ def colqwen_rag_model() -> Generator[RAGMultiModalModel, None, None]: device = get_torch_device("auto") print(f"Using device: {device}") - yield RAGMultiModalModel.from_pretrained("vidore/colqwen2-v0.1", device=device) + yield RAGMultiModalModel.from_pretrained("vidore/colqwen2-v1.0", device=device) tear_down_torch() diff --git a/tests/test_colsmol.py b/tests/test_colsmol.py new file mode 100644 index 0000000..56d0bec --- /dev/null +++ b/tests/test_colsmol.py @@ -0,0 +1,23 @@ +from typing import Generator + +import pytest +from colpali_engine.models import ColIdefics3 +from colpali_engine.utils.torch_utils import get_torch_device, tear_down_torch + +from byaldi import RAGMultiModalModel +from byaldi.colpali import ColPaliModel + + +@pytest.fixture(scope="module") +def colsmol_rag_model() -> Generator[RAGMultiModalModel, None, None]: + device = get_torch_device("auto") + print(f"Using device: {device}") + yield RAGMultiModalModel.from_pretrained("vidore/colSmol-256M", device=device) + tear_down_torch() + + +@pytest.mark.slow +def test_load_colqwen_from_pretrained(colsmol_rag_model: RAGMultiModalModel): + assert isinstance(colsmol_rag_model, RAGMultiModalModel) + assert isinstance(colsmol_rag_model.model, ColPaliModel) + assert isinstance(colsmol_rag_model.model.model, ColIdefics3)