diff --git a/byaldi/RAGModel.py b/byaldi/RAGModel.py index 32b66bf..401d8ad 100644 --- a/byaldi/RAGModel.py +++ b/byaldi/RAGModel.py @@ -45,6 +45,7 @@ def from_pretrained( index_root: str = ".byaldi", device: str = "cuda", verbose: int = 1, + cache_dir: Optional[str] = "/cache_dir/models", ): """Load a ColPali model from a pre-trained checkpoint. @@ -61,6 +62,7 @@ def from_pretrained( index_root=index_root, device=device, verbose=verbose, + cache_dir=cache_dir, ) return instance diff --git a/byaldi/colpali.py b/byaldi/colpali.py index cc11dcb..139ab30 100644 --- a/byaldi/colpali.py +++ b/byaldi/colpali.py @@ -76,6 +76,7 @@ def __init__( else None ), token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"), + cache_dir=kwargs.get("cache_dir", None), ) elif "colqwen2" in pretrained_model_name_or_path.lower(): self.model = ColQwen2.from_pretrained(