Commit 43390d8

Peter Y. Yeh authored and pytorchmergebot committed
ROCm Sparsity through HipSparseLT (pytorch#150578)
TL;DR:
- This pull request introduces support for hipSPARSELt in ROCm; the current use case is semi-structured sparsity.
- Requires **ROCm 6.4** and **gfx942/gfx950**.
- The average performance uplift (compared to the dense operation) is ~20% on ROCm 6.4, with further performance gains expected along the way.

### Dense vs. Sparse Performance Comparison

#### **NT (Row-major)**

**Average Uplift**: `1.20`

| M | N | K | hipsparselt-bench (us) | hipblaslt-bench get all (us) | Uplift |
|-------|--------|-------|------------------------|------------------------------|--------|
| 14336 | 8 | 4096 | 20.05 | 25.3 | 1.26 |
| 4096 | 8 | 14336 | 21.07 | 25.28 | 1.20 |
| 3072 | 3072 | 10240 | 299.05 | 351.82 | 1.18 |
| 3072 | 1536 | 768 | 18.56 | 20.05 | 1.08 |
| 3072 | 17664 | 768 | 163.13 | 173.91 | 1.07 |
| 3072 | 196608 | 768 | 1717.30 | 1949.63 | 1.14 |
| 3072 | 24576 | 768 | 206.84 | 242.98 | 1.17 |
| 3072 | 6144 | 768 | 53.90 | 56.88 | 1.06 |
| 3072 | 98304 | 768 | 833.77 | 962.28 | 1.15 |
| 768 | 1536 | 768 | 8.53 | 19.65 | 2.30 |
| 768 | 17664 | 768 | 46.02 | 46.84 | 1.02 |
| 768 | 196608 | 768 | 463.15 | 540.46 | 1.17 |
| 768 | 24576 | 768 | 54.32 | 59.55 | 1.10 |
| 768 | 6144 | 768 | 19.47 | 20.15 | 1.03 |
| 768 | 98304 | 768 | 231.88 | 258.73 | 1.12 |

---

#### **NN (Row-major)**

**Average Uplift**: `1.13`

| M | N | K | hipsparselt-bench (us) | hipblaslt-bench get all (us) | Uplift |
|-----|--------|------|------------------------|------------------------------|--------|
| 768 | 1536 | 3072 | 27.50 | 28.78 | 1.05 |
| 768 | 17664 | 3072 | 125.06 | 158.94 | 1.27 |
| 768 | 196608 | 3072 | 1568.38 | 1767.12 | 1.13 |
| 768 | 24576 | 3072 | 171.05 | 203.49 | 1.19 |
| 768 | 6144 | 3072 | 58.72 | 60.39 | 1.03 |
| 768 | 98304 | 3072 | 787.15 | 887.60 | 1.13 |

---

This pull request introduces support for hipSPARSELt in ROCm, alongside various updates and improvements to the codebase and test suite. The changes primarily involve adding configuration flags, updating conditional checks, and ensuring compatibility with hipSPARSELt.

### ROCm and hipSPARSELt Support:
* `BUILD.bazel`: Added the `@AT_HIPSPARSELT_ENABLED@` substitution to enable hipSPARSELt support.
* `aten/CMakeLists.txt`: Introduced a conditional flag that enables hipSPARSELt support based on the ROCm version.
* `aten/src/ATen/CMakeLists.txt`: Added the `AT_HIPSPARSELT_ENABLED` configuration.
* `aten/src/ATen/cuda/CUDAConfig.h.in`: Defined the `AT_HIPSPARSELT_ENABLED` macro.
* `caffe2/CMakeLists.txt`, `cmake/Dependencies.cmake`, `cmake/public/LoadHIP.cmake`: Included hipSPARSELt in the ROCm dependencies.

### Codebase Updates:
* `aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp`: Added hipSPARSELt support checks and initialization functions, and updated various methods to conditionally handle hipSPARSELt.

### Test Suite Updates:
* `test/test_sparse_semi_structured.py`: Added checks for hipSPARSELt availability and updated test conditions to skip tests that are not supported on ROCm.

Pull Request resolved: pytorch#150578
Approved by: https://github.com/jeffdaily
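As a usage-level illustration of the semi-structured sparsity path this change accelerates, here is a minimal sketch (not part of the commit): it assumes a PyTorch build with cuSPARSELt or hipSPARSELt enabled (for ROCm, version 6.4+ on a gfx942/gfx950 GPU); the shapes and tolerances are arbitrary illustrative values.

```python
import torch
from torch.sparse import to_sparse_semi_structured

dtype = torch.float16
# Build a weight with a 2:4 pattern (2 non-zeros in every contiguous group of 4).
mask = torch.tensor([0, 0, 1, 1], dtype=torch.bool, device="cuda").tile((3072, 192))
dense_weight = torch.randn(3072, 768, dtype=dtype, device="cuda") * mask

# Convert to the compressed semi-structured representation
# (on the cuSPARSELt/hipSPARSELt backend this goes through _cslt_compress).
sparse_weight = to_sparse_semi_structured(dense_weight)

x = torch.randn(768, 8, dtype=dtype, device="cuda")

# The matmul dispatches to the sparse backend (_cslt_sparse_mm on cuSPARSELt/hipSPARSELt).
out_sparse = sparse_weight @ x
out_dense = dense_weight @ x
torch.testing.assert_close(out_sparse, out_dense, rtol=1e-2, atol=1e-2)
```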
1 parent ad26ec6 commit 43390d8

9 files changed (+130 lines, -10 lines)


BUILD.bazel

Lines changed: 1 addition & 0 deletions
@@ -290,6 +290,7 @@ header_template_rule(
     substitutions = {
         "@AT_CUDNN_ENABLED@": "1",
         "@AT_CUSPARSELT_ENABLED@": "0",
+        "@AT_HIPSPARSELT_ENABLED@": "0",
         "@AT_ROCM_ENABLED@": "0",
         "@AT_MAGMA_ENABLED@": "0",
         "@NVCC_FLAGS_EXTRA@": "",

aten/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
@@ -101,6 +101,13 @@ else()
   set(AT_CUSPARSELT_ENABLED 1)
 endif()
 
+# Add hipSPARSELt support flag
+if(USE_ROCM AND ROCM_VERSION VERSION_GREATER_EQUAL "6.4.0")
+  set(AT_HIPSPARSELT_ENABLED 1)
+else()
+  set(AT_HIPSPARSELT_ENABLED 0)
+endif()
+
 list(APPEND ATen_CPU_INCLUDE
      ${CMAKE_CURRENT_SOURCE_DIR}/src)
 add_subdirectory(src/ATen)

aten/src/ATen/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ set_bool(AT_MAGMA_ENABLED USE_MAGMA)
 set_bool(CAFFE2_STATIC_LINK_CUDA_INT CAFFE2_STATIC_LINK_CUDA)
 set_bool(AT_CUDNN_ENABLED CAFFE2_USE_CUDNN)
 set_bool(AT_CUSPARSELT_ENABLED CAFFE2_USE_CUSPARSELT)
+set_bool(AT_HIPSPARSELT_ENABLED CAFFE2_USE_HIPSPARSELT)
 
 configure_file(Config.h.in "${CMAKE_CURRENT_SOURCE_DIR}/Config.h")
 # TODO: Do not generate CUDAConfig.h for ROCm BUILDS

aten/src/ATen/cuda/CUDAConfig.h.in

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@
 // only be included from C++ files.
 #define AT_CUDNN_ENABLED() @AT_CUDNN_ENABLED@
 #define AT_CUSPARSELT_ENABLED() @AT_CUSPARSELT_ENABLED@
+#define AT_HIPSPARSELT_ENABLED() @AT_HIPSPARSELT_ENABLED@
 #define AT_ROCM_ENABLED() @AT_ROCM_ENABLED@
 #define AT_MAGMA_ENABLED() @AT_MAGMA_ENABLED@

aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp

Lines changed: 59 additions & 6 deletions
@@ -1,5 +1,7 @@
 #include <ATen/native/sparse/cuda/cuSPARSELtOps.h>
-
+#include <unordered_map>
+#include <mutex>
+#include <string_view>
 #if AT_CUSPARSELT_ENABLED()
 
 namespace at::native {
@@ -15,6 +17,45 @@ namespace at::native {
 thread_local cusparseLtHandle_t handle;
 thread_local bool handle_initialized = false;
 
+#ifdef USE_ROCM
+// Single global flag for platform-wide hipSparseLt support
+c10::once_flag g_hipSparseLtSupportInitFlag;
+static bool g_hipSparseLtSupported = false;
+
+// Initialize the hipSparseLt support status once for the platform
+static void initHipSparseLtSupport() {
+  // Default to not supported
+  g_hipSparseLtSupported = false;
+
+  // Check only the first available device
+  try {
+    if (at::cuda::device_count() > 0) {
+      g_hipSparseLtSupported = at::detail::getCUDAHooks().isGPUArch({"gfx950", "gfx942"}, 0);
+    }
+  } catch (const std::exception&) {
+    // If an exception occurs during device property check, we assume hipSparseLt is not supported
+    // This could happen due to driver issues, device access problems, or other runtime errors
+    g_hipSparseLtSupported = false;
+    TORCH_WARN("Exception occurred while checking hipSparseLt support. Assuming not supported.");
+  }
+}
+
+static bool isHipSparseLtSupported() {
+  // Initialize support check only once
+  c10::call_once(g_hipSparseLtSupportInitFlag, initHipSparseLtSupport);
+
+  // Return cached result (platform-wide)
+  if (!g_hipSparseLtSupported) {
+    TORCH_CHECK(
+        false,
+        "hipSparseLt not supported on this device, supported architectures: "
+        "gfx950, gfx942. "
+        "required ROCM version: 6.4.0 or later.");
+  }
+  return g_hipSparseLtSupported;
+}
+#endif
+
 at::Tensor _cslt_compress(const Tensor& sparse_input) {
   if (!handle_initialized) {
     TORCH_CUDASPARSE_CHECK(cusparseLtInit(&handle));
@@ -25,6 +66,10 @@ at::Tensor _cslt_compress(const Tensor& sparse_input) {
   cudaDataType type;
   auto compression_factor = 9;
 
+#ifdef USE_ROCM
+  TORCH_CHECK(isHipSparseLtSupported());
+#endif
+
   switch (sparse_input.scalar_type()) {
     case at::ScalarType::Char:
       type = CUDA_R_8I;
@@ -36,17 +81,19 @@ at::Tensor _cslt_compress(const Tensor& sparse_input) {
     case at::ScalarType::BFloat16:
       type = CUDA_R_16BF;
       break;
+#ifndef USE_ROCM
     case at::ScalarType::Float:
      type = CUDA_R_32F;
      break;
-#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602
+#endif
+#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 && !defined(USE_ROCM)
     case at::ScalarType::Float8_e4m3fn:
       type = CUDA_R_8F_E4M3;
       compression_factor = 10;
       break;
 #endif
     default:
-      TORCH_CHECK(false, "Unsupported dtype for cuSPARSELt compressed matrix");
+      TORCH_CHECK(false, "Unsupported dtype for cuSPARSELt/hipSparseLt compressed matrix");
       break;
   }
 
@@ -120,6 +167,10 @@ std::tuple<at::Tensor, int64_t, int64_t, int64_t, int64_t> _cslt_sparse_mm_impl(
   cusparseComputeType compute_type;
   auto compression_factor = 9;
 
+#ifdef USE_ROCM
+  TORCH_CHECK(isHipSparseLtSupported());
+#endif
+
   switch (compressed_A.scalar_type()) {
     case at::ScalarType::Char:
       input_type = CUDA_R_8I;
@@ -131,7 +182,7 @@ std::tuple<at::Tensor, int64_t, int64_t, int64_t, int64_t> _cslt_sparse_mm_impl(
 
   // cuSPARSELt v0.5.2 onwards changes CUSPARSE_COMPUTE_TF32, CUSPARSE_COMPUT_16F
   // to CUSPARSE_COMPUTE_32F
-#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 502
+#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 502 || defined(USE_ROCM)
     case at::ScalarType::Half:
       input_type = CUDA_R_16F;
       output_type = CUDA_R_16F;
@@ -144,14 +195,16 @@ std::tuple<at::Tensor, int64_t, int64_t, int64_t, int64_t> _cslt_sparse_mm_impl(
       C_type = CUDA_R_16BF;
       compute_type = CUSPARSE_COMPUTE_32F;
       break;
+#ifndef USE_ROCM
     case at::ScalarType::Float:
       input_type = CUDA_R_32F;
       output_type = CUDA_R_32F;
       C_type = CUDA_R_32F;
       compute_type = CUSPARSE_COMPUTE_32F;
       break;
+#endif
     // if cuSPARSELt >= 6.2.3, we can add Float8 support
-#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602
+#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 && !defined(USE_ROCM)
     case at::ScalarType::Float8_e4m3fn:
       input_type = CUDA_R_8F_E4M3;
       output_type = CUDA_R_8F_E4M3;
@@ -214,7 +267,7 @@ std::tuple<at::Tensor, int64_t, int64_t, int64_t, int64_t> _cslt_sparse_mm_impl(
     }
   }
   // cslt 0.6.2+: fp8 fp8 -> {fp8, fp16, bf16, fp32} support
-#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602
+#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 && !defined(USE_ROCM)
   else if (input_type == CUDA_R_8F_E4M3) {
     switch (out_dtype) {
       case at::ScalarType::Float8_e4m3fn:
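The diff above gates every hipSPARSELt entry point behind a one-time, cached architecture probe (`c10::call_once` plus a global flag). A minimal Python sketch of the same "check once, cache the result" pattern is shown below; it is illustrative only, and `gcnArchName` on the device properties is assumed to be available in ROCm builds. The supported-arch list mirrors the gfx942/gfx950 requirement stated in the commit message.

```python
import functools

import torch

_SUPPORTED_ARCHS = ("gfx942", "gfx950")


@functools.lru_cache(maxsize=1)
def hipsparselt_supported() -> bool:
    """Return True if the first visible GPU reports a supported gfx architecture."""
    if torch.version.hip is None or not torch.cuda.is_available():
        return False
    try:
        # e.g. "gfx942:sramecc+:xnack-" on MI300-class hardware (assumed attribute).
        arch = torch.cuda.get_device_properties(0).gcnArchName
    except Exception:
        # Mirror the C++ fallback: any failure while probing the device means "not supported".
        return False
    return any(arch.startswith(a) for a in _SUPPORTED_ARCHS)
```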

cmake/Dependencies.cmake

Lines changed: 1 addition & 1 deletion
@@ -1063,7 +1063,7 @@ if(USE_ROCM)
 
   # Math libraries
   list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS
-    roc::hipblas roc::rocblas hip::hipfft hip::hiprand roc::hipsparse roc::hipsolver roc::hipblaslt)
+    roc::hipblas roc::rocblas hip::hipfft hip::hiprand roc::hipsparse roc::hipsparselt roc::hipsolver roc::hipblaslt)
 
   # ---[ Kernel asserts
   # Kernel asserts is disabled for ROCm by default.

cmake/public/LoadHIP.cmake

Lines changed: 1 addition & 0 deletions
@@ -151,6 +151,7 @@ if(HIP_FOUND)
   find_package_and_print_version(miopen REQUIRED)
   find_package_and_print_version(hipfft REQUIRED)
   find_package_and_print_version(hipsparse REQUIRED)
+  find_package_and_print_version(hipsparselt REQUIRED)
   find_package_and_print_version(rocprim REQUIRED)
   find_package_and_print_version(hipcub REQUIRED)
   find_package_and_print_version(rocthrust REQUIRED)

test/test_sparse_semi_structured.py

Lines changed: 21 additions & 3 deletions
@@ -47,17 +47,18 @@
 
 _IS_SM8X = False
 _IS_SM9X = False
+_IS_HIPSPARSELT_AVAILABLE = False
 
 if torch.cuda.is_available():
     _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
     _IS_SM9X = torch.cuda.get_device_capability(0)[0] == 9
-
+    _IS_HIPSPARSELT_AVAILABLE = torch.version.hip is not None and tuple(int(v) for v in torch.version.hip.split('.')[:2]) > (6, 4)
     # CUTLASS kernels only work for Ampere
     if _IS_SM8X:
         SEMI_STRUCTURED_SUPPORTED_BACKENDS["cutlass"] = SparseSemiStructuredTensorCUTLASS
 
 # add cuSPASRELt tests if available
-if torch.backends.cusparselt.is_available() and (_IS_SM8X or _IS_SM9X):
+if torch.backends.cusparselt.is_available() and (_IS_SM8X or _IS_SM9X or _IS_HIPSPARSELT_AVAILABLE):
     SEMI_STRUCTURED_SUPPORTED_BACKENDS["cusparselt"] = SparseSemiStructuredTensorCUSPARSELT
 
 inference_dtypes = dtypes(torch.float16, torch.bfloat16, torch.int8)
@@ -223,6 +224,7 @@ def forward(self, x):
 
     @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
     @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_mlp_contiguous_relu_compile_cusparselt(self):
         """
         test for cuSPASRELt meta registrations (_cslt_sparse_mm) + torch.compile
@@ -233,6 +235,7 @@ def test_mlp_contiguous_relu_compile_cusparselt(self):
 
     @unittest.skipIf("cutlass" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cutlass not supported on this machine")
     @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_mlp_contiguous_relu_compile_cutlass(self):
         """
         test for CUTLASS meta registrations (_sparse_semi_structured_addmm) + torch.compile
@@ -243,6 +246,7 @@ def test_mlp_contiguous_relu_compile_cutlass(self):
 
     @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
    @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_sp24_compile(self) -> None:
         x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True)
 
@@ -571,6 +575,7 @@ def setUp(self):
 
 
     @training_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_prune_dense_static_sort(self, dtype) -> None:
         # Ideally we would like to clone and compare, but that won't work because the sorting order will be different
         # instead we pass the pruned matrix to the CUDA implementation and preserve the sparsity pattern.
@@ -615,6 +620,7 @@ def test_prune_dense_static_sort(self, dtype) -> None:
 
     @training_dtypes
     @parametrize_backends
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_pruning_algo_largest_abs_values_greedy(self, dtype, backend) -> None:
         inp = torch.tensor(
             [[4, 3, 2, 1], [-1, -3, 0.6, 0.5], [1, 2, 3, 4], [10, 2, -1, 5]],
@@ -651,6 +657,7 @@ def test_gemm(self, dtype) -> None:
 
     @training_dtypes
     @parametrize_backends
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_pack_both_ways_meta_correctness(self, dtype, backend) -> None:
         M, N = 128, 256
         # Construct x to make sure we always have exactly 8 elements per 4x4 tile
@@ -684,6 +691,7 @@ def test_pack_both_ways_meta_correctness(self, dtype, backend) -> None:
         torch.testing.assert_close(ref_gemm, pack_gemm, **atol_rtol_kw[dtype])
 
     @training_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_pack_both_ways_id(self, dtype) -> None:
         N = 512
         torch.manual_seed(0)
@@ -718,6 +726,7 @@ def test_pack_both_ways_id(self, dtype) -> None:
         ), f"packed_t is wrong at pos: ({max_diff // N}, {max_diff % N})"
 
     @training_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_pack_both_ways_edge_case1(self, dtype) -> None:
         # In this case, the heuristic will keep 7 values out of 16
         # instead of 8. let's see how the kernel handles this
@@ -742,6 +751,7 @@ def test_pack_both_ways_edge_case1(self, dtype) -> None:
         assert packed_t[0, 1].item() == 0
 
     @training_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_sp24_apply(self, dtype) -> None:
         M, N = 256, 1024
         x = torch.randn([M, N], dtype=dtype, device="cuda")
@@ -757,6 +767,7 @@ def test_sp24_apply(self, dtype) -> None:
         torch.testing.assert_close(packed_t, packed_t2)
 
     @training_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_sp24_apply_dense(self, dtype) -> None:
         M, N = 256, 1024
         x = torch.randn([M, N], dtype=dtype, device="cuda")
@@ -794,6 +805,7 @@ def test_sp24_apply_dense(self, dtype) -> None:
 
 
     @training_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_sp24_matmuls(self, dtype) -> None:
         M, N, K = 64, 256, 1024
         a = torch.randn([M, K], device="cuda", dtype=dtype)
@@ -828,6 +840,7 @@ def test_sp24_matmuls(self, dtype) -> None:
             a_s.t() @ a, (a * a_m).t() @ a, rtol=1e-1, atol=1e-1
         )
 
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_sp24_matmuls_mat_vec(self) -> None:
         a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
         b = torch.randn([128], device="cuda", dtype=torch.float16)
@@ -837,7 +850,7 @@ def test_sp24_matmuls_mat_vec(self) -> None:
         with pytest.raises(NotImplementedError):
             torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])
 
-
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_sp24_matmuls_bmm(self) -> None:
         a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
         b = torch.randn([5, 6, 128], device="cuda", dtype=torch.float16)
@@ -988,6 +1001,7 @@ def run_test(m, n, k, device, dtype, dtype_out, use_input, rtol, atol):
 
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @inference_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_conversions(self, device, dtype):
 
         def run_test(r, c, device, dtype):
@@ -1016,6 +1030,7 @@ def run_test(r, c, device, dtype):
 
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @inference_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_conversions_all_patterns(self, device, dtype):
         r, c = 32, 128
 
@@ -1135,6 +1150,7 @@ def test_cslt_sparse_mm_mixed_dtype(self, dense_input_shape, out_dtype, device):
 
     @unittest.skip("cuSPARSELt v0.6.x does not support bfloat/float16 alpha scaling")
     @training_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_cslt_sparse_mm_alpha(self, dtype, device):
         A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(dtype).cuda()
         B = torch.ones((256, 128), device=device).to(dtype)
@@ -1151,6 +1167,7 @@ def test_cslt_sparse_mm_alpha(self, dtype, device):
         torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
 
     @parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.int32])
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_cslt_sparse_mm_alpha_compile_autotune(self, device, out_dtype):
         A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(torch.int8).to(device)
         B = torch.ones((128, 256), device=device, dtype=torch.int8).t()
@@ -1172,6 +1189,7 @@ def get_dense_result():
         torch.testing.assert_close(sparse_result.cpu(), get_dense_result(), rtol=1e-3, atol=1e-3)
 
     @parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.int32])
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_cslt_sparse_mm_alpha_mixed_dtype(self, out_dtype, device):
         A = torch.Tensor([0, 0, 10, 10]).tile((128, 64)).to(torch.int8).cuda()
         B = torch.ones((128, 256), device=device).to(torch.int8).t()
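The test gating above combines the backend availability check with a ROCm version parse. A small runtime probe along the same lines (a sketch, not part of the commit) can report which sparse backend a given machine would exercise:

```python
import torch


def describe_sparselt_support() -> str:
    """Report whether the cuSPARSELt/hipSPARSELt backend is usable in this build."""
    if not torch.backends.cusparselt.is_available():
        return "cuSPARSELt/hipSPARSELt backend not available in this build"
    if torch.version.hip is not None:
        # torch.version.hip looks like "6.4.xxxxx"; only the major/minor parts matter here.
        major, minor = (int(v) for v in torch.version.hip.split(".")[:2])
        return f"hipSPARSELt backend available (ROCm {major}.{minor})"
    return "cuSPARSELt backend available (CUDA build)"


if __name__ == "__main__":
    print(describe_sparselt_support())
```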
