File tree 3 files changed +11
-2
lines changed
3 files changed +11
-2
lines changed Original file line number Diff line number Diff line change @@ -38,6 +38,11 @@ def pytest_addoption(parser):
38
38
action = "store_true" ,
39
39
help = "Run benchmarks on mps only and ignore machine configuration checks" ,
40
40
)
41
+ parser .addoption (
42
+ "--hpu_only" ,
43
+ action = "store_true" ,
44
+ help = "Run benchmarks on hpu only and ignore machine configuration checks" ,
45
+ )
41
46
42
47
43
48
def set_fuser (fuser ):
Original file line number Diff line number Diff line change 18
18
)
19
19
from torchbenchmark .util .metadata_utils import skip_by_metadata
20
20
21
-
22
21
# Some of the models have very heavyweight setup, so we have to set a very
23
22
# generous limit. That said, we don't want the entire test suite to hang if
24
23
# a single test encounters an extreme failure, so we give up after a test is
@@ -175,6 +174,8 @@ def _load_tests():
175
174
model_paths = _list_model_paths ()
176
175
if os .getenv ("USE_CANARY_MODELS" ):
177
176
model_paths .extend (_list_canary_model_paths ())
177
+ if hasattr (torch , "hpu" ) and torch .hpu .is_available ():
178
+ devices .append ("hpu" )
178
179
for path in model_paths :
179
180
# TODO: skipping quantized tests for now due to BC-breaking changes for prepare
180
181
# api, enable after PyTorch 1.13 release
Original file line number Diff line number Diff line change 26
26
def pytest_generate_tests (metafunc ):
27
27
# This is where the list of models to test can be configured
28
28
# e.g. by using info in metafunc.config
29
- devices = ["cpu" , "cuda" ]
29
+ devices = ["cpu" , "cuda" , "hpu" ]
30
30
31
31
if hasattr (torch .backends , "mps" ) and torch .backends .mps .is_available ():
32
32
devices .append ("mps" )
@@ -40,6 +40,9 @@ def pytest_generate_tests(metafunc):
40
40
if metafunc .config .option .mps_only :
41
41
devices = ["mps" ]
42
42
43
+ if metafunc .config .option .hpu_only :
44
+ devices = ["hpu" ]
45
+
43
46
if metafunc .cls and metafunc .cls .__name__ == "TestBenchNetwork" :
44
47
paths = _list_model_paths ()
45
48
metafunc .parametrize (
You can’t perform that action at this time.
0 commit comments