Skip to content

Commit 6adcdf1

Browse files
committed
Fix failed tests
1 parent 1292211 commit 6adcdf1

File tree

2 files changed

+66
-52
lines changed

2 files changed

+66
-52
lines changed

src/huggingface_hub/hf_api.py

Lines changed: 41 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -236,64 +236,72 @@ def repo_type_and_id_from_hf_id(hf_id: str, hub_url: Optional[str] = None) -> tu
236236
"""
237237
input_hf_id = hf_id
238238

239-
hub_url = hub_url if hub_url is not None else constants.ENDPOINT
240-
is_hf_url = hub_url in hf_id and "@" not in hf_id
241-
242-
hub_url = hub_url.rstrip("/")
243-
if hf_id.startswith(hub_url):
244-
hf_id = hf_id[len(hub_url):].lstrip("/")
245-
elif hf_id.startswith(hub_url.replace("https://", "").replace("http://", "")):
246-
# Handle urls like "localhost:8080/hf/model/xxx"
247-
# https://github.com/huggingface/huggingface_hub/issues/3494
248-
hf_id = hf_id[len(hub_url.replace("https://", "").replace("http://", "")):].lstrip("/")
239+
hub_url = hub_url or constants.ENDPOINT
240+
hub_url_no_proto = re.sub(r"^https?://", "", hub_url).rstrip("/")
241+
242+
hf_id_no_proto = re.sub(r"^https?://", "", hf_id)
243+
244+
is_hf_url = hf_id_no_proto.startswith(hub_url_no_proto) and "@" not in hf_id
245+
246+
if is_hf_url:
247+
hf_id = hf_id_no_proto[len(hub_url_no_proto):].lstrip("/")
249248

250249
HFFS_PREFIX = "hf://"
251250
if hf_id.startswith(HFFS_PREFIX): # Remove "hf://" prefix if exists
252-
hf_id = hf_id[len(HFFS_PREFIX) :]
251+
hf_id = hf_id[len(HFFS_PREFIX):]
253252

254-
url_segments = hf_id.split("/")
255-
is_hf_id = len(url_segments) <= 3
253+
url_segments = [s for s in hf_id.split("/") if s]
254+
seg_len = len(url_segments)
255+
256+
repo_type: Optional[str] = None
257+
namespace: Optional[str] = None
258+
repo_id: str
256259

257-
namespace: Optional[str]
258260
if is_hf_url:
259-
namespace, repo_id = url_segments[-2:]
260-
if namespace == hub_url:
261-
namespace = None
262-
if len(url_segments) > 2 and hub_url not in url_segments[-3]:
263-
repo_type = url_segments[-3]
264-
elif namespace in constants.REPO_TYPES_MAPPING:
265-
# Mean canonical dataset or model
266-
repo_type = constants.REPO_TYPES_MAPPING[namespace]
261+
if seg_len == 1:
262+
repo_id = url_segments[0]
267263
namespace = None
268-
else:
269264
repo_type = None
270-
elif is_hf_id:
271-
if len(url_segments) == 3:
265+
elif seg_len == 2:
266+
namespace, repo_id = url_segments
267+
repo_type = None
268+
else:
269+
namespace, repo_id = url_segments[-2:]
270+
repo_type = url_segments[-3] if seg_len >= 3 else None
271+
if namespace in constants.REPO_TYPES_MAPPING:
272+
# canonical dataset/model
273+
repo_type = constants.REPO_TYPES_MAPPING[namespace]
274+
namespace = None
275+
276+
elif seg_len <= 3:
277+
if seg_len == 3:
272278
# Passed <repo_type>/<user>/<model_id> or <repo_type>/<org>/<model_id>
273-
repo_type, namespace, repo_id = url_segments[-3:]
274-
elif len(url_segments) == 2:
279+
repo_type, namespace, repo_id = url_segments
280+
elif seg_len == 2:
275281
if url_segments[0] in constants.REPO_TYPES_MAPPING:
276282
# Passed '<model_id>' or 'datasets/<dataset_id>' for a canonical model or dataset
277283
repo_type = constants.REPO_TYPES_MAPPING[url_segments[0]]
278284
namespace = None
279-
repo_id = hf_id.split("/")[-1]
285+
repo_id = url_segments[1]
280286
else:
281287
# Passed <user>/<model_id> or <org>/<model_id>
282-
namespace, repo_id = hf_id.split("/")[-2:]
288+
namespace, repo_id = url_segments
283289
repo_type = None
284290
else:
285-
# Passed <model_id>
286291
repo_id = url_segments[0]
287-
namespace, repo_type = None, None
292+
namespace = None
293+
repo_type = None
288294
else:
289-
raise ValueError(f"Unable to retrieve user and repo ID from the passed HF ID: {hf_id}")
295+
raise ValueError(
296+
f"Unable to retrieve user and repo ID from the passed HF ID: {hf_id}"
297+
)
290298

291299
# Check if repo type is known (mapping "spaces" => "space" + empty value => `None`)
292300
if repo_type in constants.REPO_TYPES_MAPPING:
293301
repo_type = constants.REPO_TYPES_MAPPING[repo_type]
294302
if repo_type == "":
295303
repo_type = None
296-
if repo_type not in constants.REPO_TYPES:
304+
if repo_type not in constants.REPO_TYPES and repo_type is not None:
297305
raise ValueError(f"Unknown `repo_type`: '{repo_type}' ('{input_hf_id}')")
298306

299307
return repo_type, namespace, repo_id

tests/test_hf_api.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2760,28 +2760,34 @@ def test_git_push_end_to_end(self):
27602760
class ParseHFUrlTest(unittest.TestCase):
27612761
def test_repo_type_and_id_from_hf_id_on_correct_values(self):
27622762
possible_values = {
2763-
"http://localhost:8080/hf/user/id": [None, "user", "id"],
2764-
"http://localhost:8080/hf/datasets/user/id": ["dataset", "user", "id"],
2765-
"http://localhost:8080/hf/models/user/id": ["model", "user", "id"],
2766-
"https://huggingface.co/id": [None, None, "id"],
2767-
"https://huggingface.co/user/id": [None, "user", "id"],
2768-
"https://huggingface.co/datasets/user/id": ["dataset", "user", "id"],
2769-
"https://huggingface.co/spaces/user/id": ["space", "user", "id"],
2770-
"user/id": [None, "user", "id"],
2771-
"dataset/user/id": ["dataset", "user", "id"],
2772-
"space/user/id": ["space", "user", "id"],
2773-
"id": [None, None, "id"],
2774-
"hf://id": [None, None, "id"],
2775-
"hf://user/id": [None, "user", "id"],
2776-
"hf://model/user/name": ["model", "user", "name"], # 's' is optional
2777-
"hf://models/user/name": ["model", "user", "name"],
2763+
"saas": {
2764+
"https://huggingface.co/id": [None, None, "id"],
2765+
"https://huggingface.co/user/id": [None, "user", "id"],
2766+
"https://huggingface.co/datasets/user/id": ["dataset", "user", "id"],
2767+
"https://huggingface.co/spaces/user/id": ["space", "user", "id"],
2768+
"user/id": [None, "user", "id"],
2769+
"dataset/user/id": ["dataset", "user", "id"],
2770+
"space/user/id": ["space", "user", "id"],
2771+
"id": [None, None, "id"],
2772+
"hf://id": [None, None, "id"],
2773+
"hf://user/id": [None, "user", "id"],
2774+
"hf://model/user/name": ["model", "user", "name"], # 's' is optional
2775+
"hf://models/user/name": ["model", "user", "name"]
2776+
},
2777+
"self-hosted": {
2778+
"http://localhost:8080/hf/user/id": [None, "user", "id"],
2779+
"http://localhost:8080/hf/datasets/user/id": ["dataset", "user", "id"],
2780+
"http://localhost:8080/hf/models/user/id": ["model", "user", "id"],
2781+
},
27782782
}
27792783

27802784
for key, value in possible_values.items():
2781-
self.assertEqual(
2782-
repo_type_and_id_from_hf_id(key, hub_url=ENDPOINT_PRODUCTION),
2783-
tuple(value),
2784-
)
2785+
hub_url = ENDPOINT_PRODUCTION if key == "saas" else "http://localhost:8080/hf"
2786+
for key, value in value.items():
2787+
self.assertEqual(
2788+
repo_type_and_id_from_hf_id(key, hub_url=hub_url),
2789+
tuple(value),
2790+
)
27852791

27862792
def test_repo_type_and_id_from_hf_id_on_wrong_values(self):
27872793
for hub_id in [

0 commit comments

Comments
 (0)