Skip to content

Commit 12a7578

Browse files
arathi-hlab authored and facebook-github-bot committed
Adding support for hpu device (#2595)
Summary: Adding a new device support named 'hpu' Pull Request resolved: #2595 Reviewed By: xuzhao9 Differential Revision: D73366654 Pulled By: atalman fbshipit-source-id: 4f925560da9c24b5bc6bf2437b997f4dd89bdd3e
1 parent 87d954e commit 12a7578

File tree

3 files changed

+11
-2
lines changed

3 files changed

+11
-2
lines changed

conftest.py

+5
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ def pytest_addoption(parser):
3838
action="store_true",
3939
help="Run benchmarks on mps only and ignore machine configuration checks",
4040
)
41+
parser.addoption(
42+
"--hpu_only",
43+
action="store_true",
44+
help="Run benchmarks on hpu only and ignore machine configuration checks",
45+
)
4146

4247

4348
def set_fuser(fuser):

test.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
)
1919
from torchbenchmark.util.metadata_utils import skip_by_metadata
2020

21-
2221
# Some of the models have very heavyweight setup, so we have to set a very
2322
# generous limit. That said, we don't want the entire test suite to hang if
2423
# a single test encounters an extreme failure, so we give up after a test is
@@ -175,6 +174,8 @@ def _load_tests():
175174
model_paths = _list_model_paths()
176175
if os.getenv("USE_CANARY_MODELS"):
177176
model_paths.extend(_list_canary_model_paths())
177+
if hasattr(torch, "hpu") and torch.hpu.is_available():
178+
devices.append("hpu")
178179
for path in model_paths:
179180
# TODO: skipping quantized tests for now due to BC-breaking changes for prepare
180181
# api, enable after PyTorch 1.13 release

test_bench.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
def pytest_generate_tests(metafunc):
2727
# This is where the list of models to test can be configured
2828
# e.g. by using info in metafunc.config
29-
devices = ["cpu", "cuda"]
29+
devices = ["cpu", "cuda", "hpu"]
3030

3131
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
3232
devices.append("mps")
@@ -40,6 +40,9 @@ def pytest_generate_tests(metafunc):
4040
if metafunc.config.option.mps_only:
4141
devices = ["mps"]
4242

43+
if metafunc.config.option.hpu_only:
44+
devices = ["hpu"]
45+
4346
if metafunc.cls and metafunc.cls.__name__ == "TestBenchNetwork":
4447
paths = _list_model_paths()
4548
metafunc.parametrize(

0 commit comments

Comments (0)