Skip to content

Commit c4bff71

Browse files
XuehaiPanpytorchmergebot
authored andcommitted
[Easy] Add ROCm support to nightly pull tool (pytorch#141282)
Pull Request resolved: pytorch#141282 Approved by: https://github.com/malfet ghstack dependencies: pytorch#143263
1 parent 8059d56 commit c4bff71

File tree

4 files changed

+64
-20
lines changed

4 files changed

+64
-20
lines changed

.github/scripts/generate_binary_build_matrix.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
CUDA_ARCHES_FULL_VERSION = {"11.8": "11.8.0", "12.4": "12.4.1", "12.6": "12.6.3"}
2121
CUDA_ARCHES_CUDNN_VERSION = {"11.8": "9", "12.4": "9", "12.6": "9"}
2222

23+
# NOTE: Also update the ROCm sources in tools/nightly.py when changing this list
2324
ROCM_ARCHES = ["6.2.4", "6.3"]
2425

2526
XPU_ARCHES = ["xpu"]

CONTRIBUTING.md

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,9 @@ git clone [email protected]:<USERNAME>/pytorch.git
7878
cd pytorch
7979
git remote add upstream [email protected]:pytorch/pytorch.git
8080

81-
make setup-env # or make setup-env-cuda for pre-built CUDA binaries
81+
make setup-env
82+
# Or run `make setup-env-cuda` for pre-built CUDA binaries
83+
# Or run `make setup-env-rocm` for pre-built ROCm binaries
8284
source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
8385
```
8486

@@ -193,6 +195,13 @@ To install the nightly binaries built with CUDA, you can pass in the flag `--cud
193195
source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
194196
```
195197

198+
To install the nightly binaries built with ROCm, you can pass in the flag `--rocm`:
199+
200+
```bash
201+
./tools/nightly.py checkout -b my-nightly-branch --rocm
202+
source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
203+
```
204+
196205
You can also use this tool to pull the nightly commits into the current branch:
197206

198207
```bash

Makefile

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,12 @@ setup-env: ensure-branch-clean
3535
setup-env-cuda:
3636
$(MAKE) setup-env PYTHON="$(PYTHON)" NIGHTLY_TOOL_OPTS="$(NIGHTLY_TOOL_OPTS) --cuda"
3737

38+
setup-env-rocm:
39+
$(MAKE) setup-env PYTHON="$(PYTHON)" NIGHTLY_TOOL_OPTS="$(NIGHTLY_TOOL_OPTS) --rocm"
40+
3841
setup_env: setup-env
3942
setup_env_cuda: setup-env-cuda
43+
setup_env_rocm: setup-env-rocm
4044

4145
setup-lint:
4246
$(PIP) install lintrunner

tools/nightly.py

Lines changed: 49 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@
2020
$ ./tools/nightly.py checkout -b my-nightly-branch --cuda
2121
$ source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
2222
23+
To install the nightly binaries built with ROCm, you can pass in the flag --rocm::
24+
25+
$ ./tools/nightly.py checkout -b my-nightly-branch --rocm
26+
$ source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
27+
2328
You can also use this tool to pull the nightly commits into the current branch as
2429
well. This can be done with::
2530
@@ -134,6 +139,12 @@ class PipSource(NamedTuple):
134139
supported_platforms={"Linux", "Windows"},
135140
accelerator="cuda",
136141
),
142+
"rocm-6.2.4": PipSource(
143+
name="rocm-6.2.4",
144+
index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/rocm6.2.4",
145+
supported_platforms={"Linux"},
146+
accelerator="rocm",
147+
),
137148
}
138149

139150

@@ -882,13 +893,26 @@ def find_executable(name: str) -> Path:
882893
default=argparse.SUPPRESS,
883894
metavar="VERSION",
884895
)
896+
subparser.add_argument(
897+
"--rocm",
898+
help=(
899+
"ROCm version to install "
900+
"(defaults to the latest version available on the platform)"
901+
),
902+
dest="rocm",
903+
nargs="?",
904+
default=argparse.SUPPRESS,
905+
metavar="VERSION",
906+
)
885907
return parser
886908

887909

888910
def parse_arguments() -> argparse.Namespace:
889911
parser = make_parser()
890912
args = parser.parse_args()
891913
args.branch = getattr(args, "branch", None)
914+
if hasattr(args, "cuda") and hasattr(args, "rocm"):
915+
parser.error("Cannot specify both CUDA and ROCm versions.")
892916
return args
893917

894918

@@ -901,26 +925,32 @@ def main() -> None:
901925
sys.exit(status)
902926

903927
pip_source = None
904-
if hasattr(args, "cuda"):
905-
available_sources = {
906-
src.name[len("cuda-") :]: src
907-
for src in PIP_SOURCES.values()
908-
if src.name.startswith("cuda-") and PLATFORM in src.supported_platforms
909-
}
910-
if not available_sources:
911-
print(f"No CUDA versions available on platform {PLATFORM}.")
912-
sys.exit(1)
913-
if args.cuda is not None:
914-
pip_source = available_sources.get(args.cuda)
915-
if pip_source is None:
916-
print(
917-
f"CUDA {args.cuda} is not available on platform {PLATFORM}. "
918-
f"Available version(s): {', '.join(sorted(available_sources, key=Version))}"
919-
)
928+
929+
for toolkit in ("CUDA", "ROCm"):
930+
accel = toolkit.lower()
931+
if hasattr(args, accel):
932+
requested = getattr(args, accel)
933+
available_sources = {
934+
src.name[len(f"{accel}-") :]: src
935+
for src in PIP_SOURCES.values()
936+
if src.name.startswith(f"{accel}-")
937+
and PLATFORM in src.supported_platforms
938+
}
939+
if not available_sources:
940+
print(f"No {toolkit} versions available on platform {PLATFORM}.")
920941
sys.exit(1)
921-
else:
922-
pip_source = available_sources[max(available_sources, key=Version)]
923-
else:
942+
if requested is not None:
943+
pip_source = available_sources.get(requested)
944+
if pip_source is None:
945+
print(
946+
f"{toolkit} {requested} is not available on platform {PLATFORM}. "
947+
f"Available version(s): {', '.join(sorted(available_sources, key=Version))}"
948+
)
949+
sys.exit(1)
950+
else:
951+
pip_source = available_sources[max(available_sources, key=Version)]
952+
953+
if pip_source is None:
924954
pip_source = PIP_SOURCES["cpu"] # always available
925955

926956
with logging_manager(debug=args.verbose) as logger:

0 commit comments

Comments
 (0)