diff --git a/.github/workflows/update-s3-html.yml b/.github/workflows/update-s3-html.yml
index 85ef57d5d7..a64d33855b 100644
--- a/.github/workflows/update-s3-html.yml
+++ b/.github/workflows/update-s3-html.yml
@@ -17,6 +17,7 @@ jobs:
     strategy:
       matrix:
         prefix: ["whl", "whl/test", "whl/nightly", "libtorch", "libtorch/nightly"]
+        subdir-pattern: ['cu[0-9]+', 'rocm[0-9]+\.[0-9]+', 'cpu', 'xpu']
       fail-fast: False
     container:
       image: continuumio/miniconda3:23.10.0-1
@@ -42,4 +43,4 @@ jobs:
 
           # Install requirements
           pip install -r s3_management/requirements.txt
-          python s3_management/manage.py ${{ matrix.prefix }}
+          python s3_management/manage.py --subdir-pattern '${{ matrix.subdir-pattern }}' ${{ matrix.prefix }}
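For context, the new matrix dimension multiplies the job count, since GitHub Actions runs one job per combination: 5 prefixes x 4 subdir patterns = 20 jobs per trigger. A minimal Python sketch of that fan-out, using the matrix values above (the printed commands are illustrative, not executed by the workflow):

from itertools import product

prefixes = ["whl", "whl/test", "whl/nightly", "libtorch", "libtorch/nightly"]
subdir_patterns = ["cu[0-9]+", r"rocm[0-9]+\.[0-9]+", "cpu", "xpu"]

# One CI job per (prefix, pattern) pair, 20 in total.
for prefix, pattern in product(prefixes, subdir_patterns):
    print(f"python s3_management/manage.py --subdir-pattern '{pattern}' {prefix}")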
diff --git a/s3_management/assume-role-output.txt b/s3_management/assume-role-output.txt
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/s3_management/manage.py b/s3_management/manage.py
index 1f5ae9e16c..5262e933bc 100755
--- a/s3_management/manage.py
+++ b/s3_management/manage.py
@@ -8,7 +8,7 @@
 import time
 from collections import defaultdict
 from os import makedirs, path
-from re import match, sub
+from re import compile, match, sub
 from typing import Dict, Iterable, List, Optional, Set, TypeVar
 
 import boto3  # type: ignore[import]
@@ -630,15 +630,12 @@ def grant_public_read(cls, key: str) -> None:
         CLIENT.put_object_acl(Bucket=BUCKET.name, Key=key, ACL="public-read")
 
     @classmethod
-    def fetch_object_names(cls, prefix: str) -> List[str]:
+    def fetch_object_names(cls, prefix: str, pattern: str) -> List[str]:
         obj_names = []
         for obj in BUCKET.objects.filter(Prefix=prefix):
             is_acceptable = any(
                 [path.dirname(obj.key) == prefix]
-                + [
-                    match(f"{prefix}/{pattern}", path.dirname(obj.key))
-                    for pattern in ACCEPTED_SUBDIR_PATTERNS
-                ]
+                + [match(compile(f"{prefix}/{pattern}"), path.dirname(obj.key))]
             ) and obj.key.endswith(ACCEPTED_FILE_EXTENSIONS)
             if not is_acceptable:
                 continue
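With the comprehension over ACCEPTED_SUBDIR_PATTERNS gone, each call now accepts a key only if it sits directly under the prefix or under a subdir matching the one supplied pattern. A minimal sketch of the filter, assuming prefix "whl/nightly" and pattern "cu[0-9]+"; the sample keys are invented for illustration, and the real method additionally checks ACCEPTED_FILE_EXTENSIONS:

from os import path
from re import compile, match

prefix, pattern = "whl/nightly", "cu[0-9]+"
sample_keys = [  # hypothetical keys, not real bucket contents
    "whl/nightly/torch-2.1.0-cp310-none-linux_x86_64.whl",               # dirname == prefix: kept
    "whl/nightly/cu121/torch-2.1.0%2Bcu121-cp310-linux_x86_64.whl",      # matches the pattern: kept
    "whl/nightly/rocm5.6/torch-2.1.0%2Brocm5.6-cp310-linux_x86_64.whl",  # other subdir: skipped
]
for key in sample_keys:
    is_acceptable = any(
        [path.dirname(key) == prefix]
        + [match(compile(f"{prefix}/{pattern}"), path.dirname(key))]
    )
    print(key, "->", bool(is_acceptable))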
@@ -706,9 +703,11 @@ def _fetch_metadata(key: str) -> str:
                 self.objects[idx].pep658 = response
 
     @classmethod
-    def from_S3(cls, prefix: str, with_metadata: bool = True) -> "S3Index":
+    def from_S3(
+        cls, prefix: str, pattern: str, with_metadata: bool = True
+    ) -> "S3Index":
         prefix = prefix.rstrip("/")
-        obj_names = cls.fetch_object_names(prefix)
+        obj_names = cls.fetch_object_names(prefix, pattern)
 
         def sanitize_key(key: str) -> str:
             return key.replace("+", "%2B")
@@ -749,6 +748,12 @@ def undelete_prefix(cls, prefix: str) -> None:
 def create_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser("Manage S3 HTML indices for PyTorch")
     parser.add_argument("prefix", type=str, choices=PREFIXES + ["all"])
+    parser.add_argument(
+        "--subdir-pattern",
+        type=str,
+        choices=ACCEPTED_SUBDIR_PATTERNS + ["all"],
+        default="all",
+    )
     parser.add_argument("--do-not-upload", action="store_true")
     parser.add_argument("--compute-sha256", action="store_true")
     return parser
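argparse maps --subdir-pattern to args.subdir_pattern, and choices rejects anything outside ACCEPTED_SUBDIR_PATTERNS plus the "all" default. A self-contained sketch of the new flag, with the constants mirrored from the workflow matrix above rather than imported from manage.py:

import argparse

PREFIXES = ["whl", "whl/test", "whl/nightly", "libtorch", "libtorch/nightly"]
ACCEPTED_SUBDIR_PATTERNS = ["cu[0-9]+", r"rocm[0-9]+\.[0-9]+", "cpu", "xpu"]

parser = argparse.ArgumentParser("Manage S3 HTML indices for PyTorch")
parser.add_argument("prefix", type=str, choices=PREFIXES + ["all"])
parser.add_argument(
    "--subdir-pattern",
    type=str,
    choices=ACCEPTED_SUBDIR_PATTERNS + ["all"],
    default="all",
)

args = parser.parse_args(["--subdir-pattern", "cu[0-9]+", "whl/nightly"])
assert args.subdir_pattern == "cu[0-9]+" and args.prefix == "whl/nightly"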
@@ -761,14 +766,22 @@ def main() -> None:
     if args.compute_sha256:
         action = "Computing checksums"
 
+    patterns = (
+        ACCEPTED_SUBDIR_PATTERNS
+        if args.subdir_pattern == "all"
+        else [args.subdir_pattern]
+    )
     prefixes = PREFIXES if args.prefix == "all" else [args.prefix]
     for prefix in prefixes:
         generate_pep503 = prefix.startswith("whl")
         print(f"INFO: {action} for '{prefix}'")
         stime = time.time()
-        idx = S3Index.from_S3(
-            prefix=prefix, with_metadata=generate_pep503 or args.compute_sha256
-        )
+        for pattern in patterns:
+            idx = S3Index.from_S3(
+                prefix=prefix,
+                pattern=pattern,
+                with_metadata=generate_pep503 or args.compute_sha256,
+            )
         etime = time.time()
         print(
             f"DEBUG: Fetched {len(idx.objects)} objects for '{prefix}' in {etime-stime:.2f} seconds"