diff --git a/src/huggingface_hub/errors.py b/src/huggingface_hub/errors.py
index f429db7cc4..1c5e71b569 100644
--- a/src/huggingface_hub/errors.py
+++ b/src/huggingface_hub/errors.py
@@ -84,6 +84,10 @@
     def append_to_message(self, additional_message: str) -> None:
         """Append additional information to the `HfHubHTTPError` initial message."""
         self.args = (self.args[0] + additional_message,) + self.args[1:]
 
+    def __reduce_ex__(self, protocol):
+        """Fix pickling of an Exception subclass with kwargs: we need to override the parent class's __reduce_ex__."""
+        return (self.__class__, (str(self),), {"response": self.response, "server_message": self.server_message})
+
 
 # INFERENCE CLIENT ERRORS
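As a quick illustration of what the `__reduce_ex__` override above enables: an `HfHubHTTPError` can now round-trip through pickle without losing its keyword-only state. A minimal sketch, assuming the current `HfHubHTTPError(message, response=None, *, server_message=None)` signature; building the error by hand with `response=None` is for illustration only (real instances are raised by `hf_raise_for_status`):

```python
import pickle

from huggingface_hub.errors import HfHubHTTPError

# Hand-built error for illustration; real ones come from hf_raise_for_status().
err = HfHubHTTPError("401 Client Error", response=None, server_message="Invalid credentials")

# Default exception pickling re-calls the constructor with self.args only,
# which can fail or silently drop state for subclasses taking kwargs.
restored = pickle.loads(pickle.dumps(err))
assert str(restored) == str(err)
assert restored.server_message == "Invalid credentials"
```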
diff --git a/src/huggingface_hub/hf_file_system.py b/src/huggingface_hub/hf_file_system.py
index 57fde3bbb2..614eb6cc15 100644
--- a/src/huggingface_hub/hf_file_system.py
+++ b/src/huggingface_hub/hf_file_system.py
@@ -1,8 +1,10 @@
 import os
 import re
 import tempfile
+import threading
 from collections import deque
 from contextlib import ExitStack
+from copy import deepcopy
 from dataclasses import dataclass, field
 from datetime import datetime
 from itertools import chain
@@ -21,6 +23,7 @@
 from .file_download import hf_hub_url, http_get
 from .hf_api import HfApi, LastCommitInfo, RepoFile
 from .utils import HFValidationError, hf_raise_for_status, http_backoff, http_stream_backoff
+from .utils.insecure_hashlib import md5
 
 
 # Regex used to match special revisions with "/" in them (see #1710)
@@ -56,7 +59,61 @@ def unresolve(self) -> str:
         return f"{repo_path}/{self.path_in_repo}".rstrip("/")
 
 
-class HfFileSystem(fsspec.AbstractFileSystem):
+# We need to improve fsspec.spec._Cached, which is AbstractFileSystem's metaclass
+_cached_base: Any = type(fsspec.AbstractFileSystem)
+
+
+class _Cached(_cached_base):
+    """
+    Metaclass for caching HfFileSystem instances according to the args.
+
+    This creates an additional reference to the filesystem, which prevents the
+    filesystem from being garbage collected when all *user* references go away.
+    A call to the :meth:`AbstractFileSystem.clear_instance_cache` must *also*
+    be made for a filesystem instance to be garbage collected.
+
+    This is a slightly modified version of `fsspec.spec._Cached` to improve it.
+    In particular, in `_tokenize` the pid isn't taken into account for the
+    `fs_token` used to identify cached instances. The `fs_token` logic is also
+    robust to default values and to the order of the args. Finally, new instances
+    reuse the state from sister instances in the main thread.
+    """
+
+    def __init__(cls, *args, **kwargs):
+        # Hack: override https://github.com/fsspec/filesystem_spec/blob/dcb167e8f50e6273d4cfdfc4cab8fc5aa4c958bf/fsspec/spec.py#L53
+        super().__init__(*args, **kwargs)
+        # Note: we intentionally create a reference here, to avoid garbage
+        # collecting instances when all other references are gone. To really
+        # delete a FileSystem, the cache must be cleared.
+        cls._cache = {}
+
+    def __call__(cls, *args, **kwargs):
+        # Hack: override https://github.com/fsspec/filesystem_spec/blob/dcb167e8f50e6273d4cfdfc4cab8fc5aa4c958bf/fsspec/spec.py#L65
+        skip = kwargs.pop("skip_instance_cache", False)
+        fs_token = cls._tokenize(cls, threading.get_ident(), *args, **kwargs)
+        fs_token_main_thread = cls._tokenize(cls, threading.main_thread().ident, *args, **kwargs)
+        if not skip and cls.cachable and fs_token in cls._cache:
+            # reuse cached instance
+            cls._latest = fs_token
+            return cls._cache[fs_token]
+        else:
+            # create new instance
+            obj = type.__call__(cls, *args, **kwargs)
+            if not skip and cls.cachable and fs_token_main_thread in cls._cache:
+                # reuse the state from the main thread instance in the new instance
+                instance_state = cls._cache[fs_token_main_thread]._get_instance_state()
+                for attr, state_value in instance_state.items():
+                    setattr(obj, attr, state_value)
+            obj._fs_token_ = fs_token
+            obj.storage_args = args
+            obj.storage_options = kwargs
+            if cls.cachable and not skip:
+                cls._latest = fs_token
+                cls._cache[fs_token] = obj
+            return obj
+
+
+class HfFileSystem(fsspec.AbstractFileSystem, metaclass=_Cached):
     """
     Access a remote Hugging Face Hub repository as if were a local file system.
 
@@ -119,6 +176,18 @@ def __init__(
         # Maps parent directory path to path infos
         self.dircache: dict[str, list[dict[str, Any]]] = {}
 
+    @classmethod
+    def _tokenize(cls, threading_ident: int, *args, **kwargs) -> str:
+        """Deterministic token for caching"""
+        # make fs_token robust to default values and to kwargs order
+        kwargs["endpoint"] = kwargs.get("endpoint") or constants.ENDPOINT
+        kwargs["token"] = kwargs.get("token")
+        kwargs = {key: kwargs[key] for key in sorted(kwargs)}
+        # contrary to fsspec, we don't include pid here
+        tokenize_args = (cls, threading_ident, args, kwargs)
+        h = md5(str(tokenize_args).encode())
+        return h.hexdigest()
+
     def _repo_and_revision_exist(
         self, repo_type: str, repo_id: str, revision: Optional[str]
     ) -> tuple[bool, Optional[Exception]]:
@@ -931,17 +1000,20 @@ def start_transaction(self):
         raise NotImplementedError("Transactional commits are not supported.")
 
     def __reduce__(self):
-        # re-populate the instance cache at HfFileSystem._cache and re-populate the cache attributes of every instance
+        # re-populate the instance cache at HfFileSystem._cache and re-populate the state of every instance
         return make_instance, (
             type(self),
             self.storage_args,
             self.storage_options,
-            {
-                "dircache": self.dircache,
-                "_repo_and_revision_exists_cache": self._repo_and_revision_exists_cache,
-            },
+            self._get_instance_state(),
         )
 
+    def _get_instance_state(self):
+        return {
+            "dircache": deepcopy(self.dircache),
+            "_repo_and_revision_exists_cache": deepcopy(self._repo_and_revision_exists_cache),
+        }
+
 
 class HfFileSystemFile(fsspec.spec.AbstractBufferedFile):
     def __init__(self, fs: HfFileSystem, path: str, revision: Optional[str] = None, **kwargs):
@@ -1178,8 +1250,8 @@ def _partial_read(response: httpx.Response, length: int = -1) -> bytes:
     return bytes(buf)  # may be < length if response ended
 
 
-def make_instance(cls, args, kwargs, instance_cache_attributes_dict):
+def make_instance(cls, args, kwargs, instance_state):
     fs = cls(*args, **kwargs)
-    for attr, cached_value in instance_cache_attributes_dict.items():
-        setattr(fs, attr, cached_value)
+    for attr, state_value in instance_state.items():
+        setattr(fs, attr, state_value)
     return fs
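To make the caching semantics concrete: within one thread, equivalent constructor calls (regardless of kwarg order or explicitly-passed defaults) return the same cached instance, while a new thread gets a distinct instance seeded with a deep-copied snapshot of the main-thread instance's state. A small sketch mirroring the `test_cache` test added below; it needs no network access since nothing touches the Hub:

```python
import threading

from huggingface_hub import HfFileSystem

HfFileSystem.clear_instance_cache()

main_fs = HfFileSystem()  # cached under the main thread's fs_token
assert HfFileSystem(token=None) is main_fs  # defaults don't change the token

main_fs.dircache["dummy"] = []  # simulate some cached directory listings

result = {}

def worker():
    fs = HfFileSystem()  # same args, but a different thread ident
    result["same_instance"] = fs is main_fs
    result["dircache"] = fs.dircache

t = threading.Thread(target=worker)
t.start()
t.join()

assert not result["same_instance"]  # one instance per thread, for thread safety
assert result["dircache"] == main_fs.dircache  # ...seeded from the main thread's state
```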
diff --git a/tests/test_hf_file_system.py b/tests/test_hf_file_system.py
index 65a189f9d4..1c22ab9d29 100644
--- a/tests/test_hf_file_system.py
+++ b/tests/test_hf_file_system.py
@@ -1,6 +1,8 @@
 import copy
 import datetime
 import io
+import multiprocessing
+import multiprocessing.pool
 import os
 import pickle
 import tempfile
@@ -644,6 +646,42 @@ def test_exists_after_repo_deletion():
     assert not hffs.exists(repo_id, refresh=True)
 
 
+def _get_fs_token_and_dircache(fs):
+    fs = HfFileSystem(endpoint=fs.endpoint, token=fs.token)
+    return fs._fs_token, fs.dircache
+
+
+def test_cache():
+    HfFileSystem.clear_instance_cache()
+    fs = HfFileSystem()
+    fs.dircache = {"dummy": []}
+
+    assert HfFileSystem() is fs
+    assert HfFileSystem(endpoint=constants.ENDPOINT) is fs
+    assert HfFileSystem(token=None, endpoint=constants.ENDPOINT) is fs
+
+    another_fs = HfFileSystem(endpoint="something-else")
+    assert another_fs is not fs
+    assert another_fs.dircache != fs.dircache
+
+    with multiprocessing.get_context("spawn").Pool() as pool:
+        (fs_token, dircache), (_, another_dircache) = pool.map(_get_fs_token_and_dircache, [fs, another_fs])
+        assert dircache == fs.dircache
+        assert another_dircache != fs.dircache
+
+    if os.name != "nt":  # "fork" is unavailable on windows
+        with multiprocessing.get_context("fork").Pool() as pool:
+            (fs_token, dircache), (_, another_dircache) = pool.map(_get_fs_token_and_dircache, [fs, another_fs])
+            assert dircache == fs.dircache
+            assert another_dircache != fs.dircache
+
+    with multiprocessing.pool.ThreadPool() as pool:
+        (fs_token, dircache), (_, another_dircache) = pool.map(_get_fs_token_and_dircache, [fs, another_fs])
+        assert dircache == fs.dircache
+        assert another_dircache != fs.dircache
+        assert fs_token != fs._fs_token  # use a different instance for thread safety
+
+
 @with_production_testing
 def test_hf_file_system_file_can_handle_gzipped_file():
     """Test that HfFileSystemStreamFile.read() can handle gzipped files."""
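Finally, a sketch of the serialization path these tests exercise: `__reduce__` now embeds a deep-copied snapshot of the instance state, so a filesystem shipped to a worker process arrives with its `dircache` intact, while unpickling in the same thread and process simply hits the instance cache. The `spawn` context and one-worker pool are assumptions mirroring `test_cache`:

```python
import multiprocessing
import pickle

from huggingface_hub import HfFileSystem

def _dircache_in_worker(fs):
    # In the spawned worker, make_instance() rebuilds the filesystem and
    # re-applies the state snapshot taken at pickling time.
    return fs.dircache

if __name__ == "__main__":
    fs = HfFileSystem()
    fs.dircache["dummy"] = []

    # Same thread and process: the instance cache returns the very same object.
    assert pickle.loads(pickle.dumps(fs)) is fs

    # Different process: the snapshotted state travels with the payload.
    with multiprocessing.get_context("spawn").Pool(1) as pool:
        assert pool.map(_dircache_in_worker, [fs]) == [{"dummy": []}]
```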