10 changes: 10 additions & 0 deletions src/huggingface_hub/errors.py
@@ -84,6 +84,16 @@ def append_to_message(self, additional_message: str) -> None:
"""Append additional information to the `HfHubHTTPError` initial message."""
self.args = (self.args[0] + additional_message,) + self.args[1:]

def __reduce_ex__(self, protocol):
"""Fix pickling of Exception subclass with kwargs"""
args = (str(self),)
kwargs = {"response": self.response, "server_message": self.server_message}
return self._from_args, (args, kwargs)

@classmethod
def _from_args(cls, args, kwargs):
return cls(*args, **kwargs)
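
Together, these two hooks let the error survive a pickle round-trip even though `__init__` takes keyword arguments. A minimal sketch of the expected behavior (the message and `server_message` values are made up; `response` is left as `None` since an HTTP response object may not itself be picklable):

import pickle

from huggingface_hub.errors import HfHubHTTPError

err = HfHubHTTPError("404 Client Error", response=None, server_message="Repo not found")
restored = pickle.loads(pickle.dumps(err))
# the message and the kwargs-based attributes survive the round-trip
assert str(restored) == str(err)
assert restored.server_message == err.server_message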


# INFERENCE CLIENT ERRORS

92 changes: 83 additions & 9 deletions src/huggingface_hub/hf_file_system.py
@@ -1,10 +1,13 @@
import os
import re
import tempfile
import threading
from collections import deque
from contextlib import ExitStack
from copy import deepcopy
from dataclasses import dataclass, field
from datetime import datetime
from hashlib import md5
from itertools import chain
from pathlib import Path
from typing import Any, Iterator, NoReturn, Optional, Union
Expand Down Expand Up @@ -56,7 +59,59 @@ def unresolve(self) -> str:
return f"{repo_path}/{self.path_in_repo}".rstrip("/")


class HfFileSystem(fsspec.AbstractFileSystem):
# We need to improve fsspec.spec._Cached, which is AbstractFileSystem's metaclass
_cached_base: Any = type(fsspec.AbstractFileSystem)


class _Cached(_cached_base):
"""
Metaclass for caching HfFileSystem instances according to the args.

This creates an additional reference to the filesystem, which prevents the
filesystem from being garbage collected when all *user* references go away.
A call to the :meth:`AbstractFileSystem.clear_instance_cache` must *also*
be made for a filesystem instance to be garbage collected.

This is a slightly modified version of `fsspec.spec._Cached`. In particular,
`_tokenize` doesn't take the pid into account in the `fs_token` used to
identify cached instances, and the `fs_token` logic is robust to default
values and to the order of the kwargs. Finally, new instances reuse the
state of the sister instance in the main thread.
"""

def __init__(cls, *args, **kwargs):
super().__init__(*args, **kwargs)
# Note: we intentionally create a reference here, to avoid garbage
# collecting instances when all other references are gone. To really
# delete a FileSystem, the cache must be cleared.
cls._cache = {}

def __call__(cls, *args, **kwargs):
skip = kwargs.pop("skip_instance_cache", False)
fs_token = cls._tokenize(threading.get_ident(), *args, **kwargs)
fs_token_main_thread = cls._tokenize(threading.main_thread().ident, *args, **kwargs)
if not skip and cls.cachable and fs_token in cls._cache:
# reuse cached instance
cls._latest = fs_token
return cls._cache[fs_token]
else:
# create new instance
obj = type.__call__(cls, *args, **kwargs)
if not skip and cls.cachable and fs_token_main_thread in cls._cache:
# reuse the cache from the main thread instance in the new instance
instance_state = cls._cache[fs_token_main_thread]._get_instance_state()
for attr, state_value in instance_state.items():
setattr(obj, attr, state_value)
obj._fs_token_ = fs_token
obj.storage_args = args
obj.storage_options = kwargs
if cls.cachable and not skip:
cls._latest = fs_token
cls._cache[fs_token] = obj
return obj
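
In practice the constructor then behaves as sketched below (an illustration, not part of the PR): repeated calls in one thread return the same cached object, while a new thread gets its own instance seeded with a copy of the main thread's state rather than sharing mutable attributes with it.

import threading

from huggingface_hub import HfFileSystem

fs_main = HfFileSystem()
assert HfFileSystem() is fs_main  # same thread, same args -> cached instance

def worker():
    fs_thread = HfFileSystem()
    # different thread -> different instance, seeded with a copy of the
    # main thread instance's state
    assert fs_thread is not fs_main
    assert fs_thread.dircache == fs_main.dircache

t = threading.Thread(target=worker)
t.start()
t.join()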


class HfFileSystem(fsspec.AbstractFileSystem, metaclass=_Cached):
"""
Access a remote Hugging Face Hub repository as if it were a local file system.

@@ -119,6 +174,22 @@ def __init__(
# Maps parent directory path to path infos
self.dircache: dict[str, list[dict[str, Any]]] = {}

@classmethod
def _tokenize(cls, threading_ident: int, *args, **kwargs) -> str:
"""Deterministic token for caching"""
# make fs_token robust to default values and to kwargs order
kwargs["endpoint"] = kwargs.get("endpoint") or constants.ENDPOINT
kwargs["token"] = kwargs.get("token")
kwargs = {key: kwargs[key] for key in sorted(kwargs)}
# contrary to fsspec, we don't include the pid here
tokenize_args = (cls, threading_ident, args, kwargs)
try:
h = md5(str(tokenize_args).encode())
except ValueError:
# FIPS systems: https://github.com/fsspec/filesystem_spec/issues/380
h = md5(str(tokenize_args).encode(), usedforsecurity=False)
return h.hexdigest()
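
Because defaults are normalized and the kwargs are sorted before hashing, spelling the constructor arguments differently still lands on the same cache slot. A small illustrative check (not part of the PR; it assumes the call convention above, where the classmethod receives the thread ident followed by the constructor kwargs):

import threading

from huggingface_hub import HfFileSystem, constants

ident = threading.get_ident()

# omitted kwargs, explicit defaults, and reordered kwargs all hash identically
t1 = HfFileSystem._tokenize(ident)
t2 = HfFileSystem._tokenize(ident, endpoint=constants.ENDPOINT, token=None)
t3 = HfFileSystem._tokenize(ident, token=None, endpoint=constants.ENDPOINT)
assert t1 == t2 == t3

# a different endpoint (or a different thread ident) yields a different token
assert HfFileSystem._tokenize(ident, endpoint="https://example.co") != t1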

def _repo_and_revision_exist(
self, repo_type: str, repo_id: str, revision: Optional[str]
) -> tuple[bool, Optional[Exception]]:
@@ -931,17 +1002,20 @@ def start_transaction(self):
raise NotImplementedError("Transactional commits are not supported.")

def __reduce__(self):
# re-populate the instance cache at HfFileSystem._cache and re-populate the cache attributes of every instance
# re-populate the instance cache at HfFileSystem._cache and re-populate the state of every instance
return make_instance, (
type(self),
self.storage_args,
self.storage_options,
{
"dircache": self.dircache,
"_repo_and_revision_exists_cache": self._repo_and_revision_exists_cache,
},
self._get_instance_state(),
)

def _get_instance_state(self):
return {
"dircache": deepcopy(self.dircache),
"_repo_and_revision_exists_cache": deepcopy(self._repo_and_revision_exists_cache),
}
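
`__reduce__` plus `_get_instance_state` is what lets the tests below ship a filesystem to `multiprocessing` workers: pickling captures the constructor args and a deep copy of the caches, and unpickling rebuilds an instance (or, in the same process, fetches the cached one) and restores that state. A quick same-process sketch, with a made-up `dircache` entry:

import pickle

from huggingface_hub import HfFileSystem

fs = HfFileSystem()
fs.dircache["datasets/dummy"] = []  # hypothetical cached listing

restored = pickle.loads(pickle.dumps(fs))
# in the same process the instance cache hands back the same object,
# with its state replaced by the deep copy captured at pickling time
assert restored is fs
assert "datasets/dummy" in restored.dircache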


class HfFileSystemFile(fsspec.spec.AbstractBufferedFile):
def __init__(self, fs: HfFileSystem, path: str, revision: Optional[str] = None, **kwargs):
@@ -1178,8 +1252,8 @@ def _partial_read(response: httpx.Response, length: int = -1) -> bytes:
return bytes(buf) # may be < length if response ended


def make_instance(cls, args, kwargs, instance_cache_attributes_dict):
def make_instance(cls, args, kwargs, instance_state):
fs = cls(*args, **kwargs)
for attr, cached_value in instance_cache_attributes_dict.items():
setattr(fs, attr, cached_value)
for attr, state_value in instance_state.items():
setattr(fs, attr, state_value)
return fs
38 changes: 38 additions & 0 deletions tests/test_hf_file_system.py
@@ -1,6 +1,8 @@
import copy
import datetime
import io
import multiprocessing
import multiprocessing.pool
import os
import pickle
import tempfile
@@ -644,6 +646,42 @@ def test_exists_after_repo_deletion():
assert not hffs.exists(repo_id, refresh=True)


def _get_fs_token_and_dircache(fs):
fs = HfFileSystem(endpoint=fs.endpoint, token=fs.token)
return fs._fs_token, fs.dircache


def test_cache():
HfFileSystem.clear_instance_cache()
fs = HfFileSystem()
fs.dircache = {"dummy": []}

assert HfFileSystem() is fs
assert HfFileSystem(endpoint=constants.ENDPOINT) is fs
assert HfFileSystem(token=None, endpoint=constants.ENDPOINT) is fs

another_fs = HfFileSystem(endpoint="something-else")
assert another_fs is not fs
assert another_fs.dircache != fs.dircache

with multiprocessing.get_context("spawn").Pool() as pool:
(fs_token, dircache), (_, another_dircache) = pool.map(_get_fs_token_and_dircache, [fs, another_fs])
assert dircache == fs.dircache
assert another_dircache != fs.dircache

if os.name != "nt": # "fork" is unavailable on windows
with multiprocessing.get_context("fork").Pool() as pool:
(fs_token, dircache), (_, another_dircache) = pool.map(_get_fs_token_and_dircache, [fs, another_fs])
assert dircache == fs.dircache
assert another_dircache != fs.dircache

with multiprocessing.pool.ThreadPool() as pool:
(fs_token, dircache), (_, another_dircache) = pool.map(_get_fs_token_and_dircache, [fs, another_fs])
assert dircache == fs.dircache
assert another_dircache != fs.dircache
assert fs_token != fs._fs_token # use a different instance for thread safety


@with_production_testing
def test_hf_file_system_file_can_handle_gzipped_file():
"""Test that HfFileSystemStreamFile.read() can handle gzipped files."""