20
20
$ ./tools/nightly.py checkout -b my-nightly-branch --cuda
21
21
$ source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
22
22
23
+ To install the nightly binaries built with ROCm, you can pass in the flag --rocm::
24
+
25
+ $ ./tools/nightly.py checkout -b my-nightly-branch --rocm
26
+ $ source venv/bin/activate # or `& .\venv\Scripts\Activate.ps1` on Windows
27
+
23
28
You can also use this tool to pull the nightly commits into the current branch as
24
29
well. This can be done with::
25
30
@@ -134,6 +139,12 @@ class PipSource(NamedTuple):
134
139
supported_platforms = {"Linux" , "Windows" },
135
140
accelerator = "cuda" ,
136
141
),
142
+ "rocm-6.2.4" : PipSource (
143
+ name = "rocm-6.2.4" ,
144
+ index_url = f"{ PYTORCH_NIGHTLY_PIP_INDEX_URL } /rocm6.2.4" ,
145
+ supported_platforms = {"Linux" },
146
+ accelerator = "rocm" ,
147
+ ),
137
148
}
138
149
139
150
@@ -882,13 +893,26 @@ def find_executable(name: str) -> Path:
882
893
default = argparse .SUPPRESS ,
883
894
metavar = "VERSION" ,
884
895
)
896
+ subparser .add_argument (
897
+ "--rocm" ,
898
+ help = (
899
+ "ROCm version to install "
900
+ "(defaults to the latest version available on the platform)"
901
+ ),
902
+ dest = "rocm" ,
903
+ nargs = "?" ,
904
+ default = argparse .SUPPRESS ,
905
+ metavar = "VERSION" ,
906
+ )
885
907
return parser
886
908
887
909
888
910
def parse_arguments () -> argparse .Namespace :
889
911
parser = make_parser ()
890
912
args = parser .parse_args ()
891
913
args .branch = getattr (args , "branch" , None )
914
+ if hasattr (args , "cuda" ) and hasattr (args , "rocm" ):
915
+ parser .error ("Cannot specify both CUDA and ROCm versions." )
892
916
return args
893
917
894
918
@@ -901,26 +925,32 @@ def main() -> None:
901
925
sys .exit (status )
902
926
903
927
pip_source = None
904
- if hasattr (args , "cuda" ):
905
- available_sources = {
906
- src .name [len ("cuda-" ) :]: src
907
- for src in PIP_SOURCES .values ()
908
- if src .name .startswith ("cuda-" ) and PLATFORM in src .supported_platforms
909
- }
910
- if not available_sources :
911
- print (f"No CUDA versions available on platform { PLATFORM } ." )
912
- sys .exit (1 )
913
- if args .cuda is not None :
914
- pip_source = available_sources .get (args .cuda )
915
- if pip_source is None :
916
- print (
917
- f"CUDA { args .cuda } is not available on platform { PLATFORM } . "
918
- f"Available version(s): { ', ' .join (sorted (available_sources , key = Version ))} "
919
- )
928
+
929
+ for toolkit in ("CUDA" , "ROCm" ):
930
+ accel = toolkit .lower ()
931
+ if hasattr (args , accel ):
932
+ requested = getattr (args , accel )
933
+ available_sources = {
934
+ src .name [len (f"{ accel } -" ) :]: src
935
+ for src in PIP_SOURCES .values ()
936
+ if src .name .startswith (f"{ accel } -" )
937
+ and PLATFORM in src .supported_platforms
938
+ }
939
+ if not available_sources :
940
+ print (f"No { toolkit } versions available on platform { PLATFORM } ." )
920
941
sys .exit (1 )
921
- else :
922
- pip_source = available_sources [max (available_sources , key = Version )]
923
- else :
942
+ if requested is not None :
943
+ pip_source = available_sources .get (requested )
944
+ if pip_source is None :
945
+ print (
946
+ f"{ toolkit } { requested } is not available on platform { PLATFORM } . "
947
+ f"Available version(s): { ', ' .join (sorted (available_sources , key = Version ))} "
948
+ )
949
+ sys .exit (1 )
950
+ else :
951
+ pip_source = available_sources [max (available_sources , key = Version )]
952
+
953
+ if pip_source is None :
924
954
pip_source = PIP_SOURCES ["cpu" ] # always available
925
955
926
956
with logging_manager (debug = args .verbose ) as logger :
0 commit comments