_IS_SM8X = False
_IS_SM9X = False
+_IS_HIPSPARSELT_AVAILABLE = False

if torch.cuda.is_available():
    _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
    _IS_SM9X = torch.cuda.get_device_capability(0)[0] == 9
-
+    _IS_HIPSPARSELT_AVAILABLE = torch.version.hip is not None and tuple(int(v) for v in torch.version.hip.split('.')[:2]) > (6, 4)

    # CUTLASS kernels only work for Ampere
    if _IS_SM8X:
        SEMI_STRUCTURED_SUPPORTED_BACKENDS["cutlass"] = SparseSemiStructuredTensorCUTLASS

    # add cuSPASRELt tests if available
-    if torch.backends.cusparselt.is_available() and (_IS_SM8X or _IS_SM9X):
+    if torch.backends.cusparselt.is_available() and (_IS_SM8X or _IS_SM9X or _IS_HIPSPARSELT_AVAILABLE):
        SEMI_STRUCTURED_SUPPORTED_BACKENDS["cusparselt"] = SparseSemiStructuredTensorCUSPARSELT

inference_dtypes = dtypes(torch.float16, torch.bfloat16, torch.int8)
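For context, the new `_IS_HIPSPARSELT_AVAILABLE` flag gates the hipSPARSELt path on the ROCm release reported by `torch.version.hip`, requiring a version strictly newer than 6.4. A minimal sketch of how that tuple comparison evaluates (the helper name and the sample version strings are illustrative, not part of the change):

    # Hypothetical standalone version of the gate added above.
    def hipsparselt_supported(hip_version):
        if hip_version is None:  # CUDA builds expose no HIP version
            return False
        return tuple(int(v) for v in hip_version.split('.')[:2]) > (6, 4)

    assert hipsparselt_supported(None) is False
    assert hipsparselt_supported("6.4.1") is False  # 6.4.x is still excluded
    assert hipsparselt_supported("6.5.0") is True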
@@ -223,6 +224,7 @@ def forward(self, x):

    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
    @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_mlp_contiguous_relu_compile_cusparselt(self):
        """
        test for cuSPASRELt meta registrations (_cslt_sparse_mm) + torch.compile
@@ -233,6 +235,7 @@ def test_mlp_contiguous_relu_compile_cusparselt(self):

    @unittest.skipIf("cutlass" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cutlass not supported on this machine")
    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_mlp_contiguous_relu_compile_cutlass(self):
        """
        test for CUTLASS meta registrations (_sparse_semi_structured_addmm) + torch.compile
@@ -243,6 +246,7 @@ def test_mlp_contiguous_relu_compile_cutlass(self):

    @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
    @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_sp24_compile(self) -> None:
        x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True)
@@ -571,6 +575,7 @@ def setUp(self):

    @training_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_prune_dense_static_sort(self, dtype) -> None:
        # Ideally we would like to clone and compare, but that won't work because the sorting order will be different
        # instead we pass the pruned matrix to the CUDA implementation and preserve the sparsity pattern.
@@ -615,6 +620,7 @@ def test_prune_dense_static_sort(self, dtype) -> None:

    @training_dtypes
    @parametrize_backends
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_pruning_algo_largest_abs_values_greedy(self, dtype, backend) -> None:
        inp = torch.tensor(
            [[4, 3, 2, 1], [-1, -3, 0.6, 0.5], [1, 2, 3, 4], [10, 2, -1, 5]],
@@ -651,6 +657,7 @@ def test_gemm(self, dtype) -> None:

    @training_dtypes
    @parametrize_backends
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_pack_both_ways_meta_correctness(self, dtype, backend) -> None:
        M, N = 128, 256
        # Construct x to make sure we always have exactly 8 elements per 4x4 tile
@@ -684,6 +691,7 @@ def test_pack_both_ways_meta_correctness(self, dtype, backend) -> None:
        torch.testing.assert_close(ref_gemm, pack_gemm, **atol_rtol_kw[dtype])

    @training_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_pack_both_ways_id(self, dtype) -> None:
        N = 512
        torch.manual_seed(0)
@@ -718,6 +726,7 @@ def test_pack_both_ways_id(self, dtype) -> None:
        ), f"packed_t is wrong at pos: ({max_diff // N}, {max_diff % N})"

    @training_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_pack_both_ways_edge_case1(self, dtype) -> None:
        # In this case, the heuristic will keep 7 values out of 16
        # instead of 8. let's see how the kernel handles this
@@ -742,6 +751,7 @@ def test_pack_both_ways_edge_case1(self, dtype) -> None:
        assert packed_t[0, 1].item() == 0

    @training_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_sp24_apply(self, dtype) -> None:
        M, N = 256, 1024
        x = torch.randn([M, N], dtype=dtype, device="cuda")
@@ -757,6 +767,7 @@ def test_sp24_apply(self, dtype) -> None:
        torch.testing.assert_close(packed_t, packed_t2)

    @training_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_sp24_apply_dense(self, dtype) -> None:
        M, N = 256, 1024
        x = torch.randn([M, N], dtype=dtype, device="cuda")
@@ -794,6 +805,7 @@ def test_sp24_apply_dense(self, dtype) -> None:

    @training_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_sp24_matmuls(self, dtype) -> None:
        M, N, K = 64, 256, 1024
        a = torch.randn([M, K], device="cuda", dtype=dtype)
@@ -828,6 +840,7 @@ def test_sp24_matmuls(self, dtype) -> None:
            a_s.t() @ a, (a * a_m).t() @ a, rtol=1e-1, atol=1e-1
        )

+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_sp24_matmuls_mat_vec(self) -> None:
        a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
        b = torch.randn([128], device="cuda", dtype=torch.float16)
@@ -837,7 +850,7 @@ def test_sp24_matmuls_mat_vec(self) -> None:
        with pytest.raises(NotImplementedError):
            torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])

-
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_sp24_matmuls_bmm(self) -> None:
        a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
        b = torch.randn([5, 6, 128], device="cuda", dtype=torch.float16)
@@ -988,6 +1001,7 @@ def run_test(m, n, k, device, dtype, dtype_out, use_input, rtol, atol):

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @inference_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_conversions(self, device, dtype):

        def run_test(r, c, device, dtype):
@@ -1016,6 +1030,7 @@ def run_test(r, c, device, dtype):

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @inference_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_conversions_all_patterns(self, device, dtype):
        r, c = 32, 128
@@ -1135,6 +1150,7 @@ def test_cslt_sparse_mm_mixed_dtype(self, dense_input_shape, out_dtype, device):

    @unittest.skip("cuSPARSELt v0.6.x does not support bfloat/float16 alpha scaling")
    @training_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_cslt_sparse_mm_alpha(self, dtype, device):
        A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(dtype).cuda()
        B = torch.ones((256, 128), device=device).to(dtype)
@@ -1151,6 +1167,7 @@ def test_cslt_sparse_mm_alpha(self, dtype, device):
        torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)

    @parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.int32])
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_cslt_sparse_mm_alpha_compile_autotune(self, device, out_dtype):
        A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(torch.int8).to(device)
        B = torch.ones((128, 256), device=device, dtype=torch.int8).t()
@@ -1172,6 +1189,7 @@ def get_dense_result():
        torch.testing.assert_close(sparse_result.cpu(), get_dense_result(), rtol=1e-3, atol=1e-3)

    @parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.int32])
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
    def test_cslt_sparse_mm_alpha_mixed_dtype(self, out_dtype, device):
        A = torch.Tensor([0, 0, 10, 10]).tile((128, 64)).to(torch.int8).cuda()
        B = torch.ones((128, 256), device=device).to(torch.int8).t()
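The remaining hunks all apply the same `@unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")` guard to each affected test. For readers unfamiliar with the pattern, a minimal, self-contained sketch of how such a conditional skip behaves is below; the `TEST_WITH_ROCM` definition here is a simplified stand-in, not the flag the test file actually imports from PyTorch's internal test utilities:

    import unittest

    import torch

    # Simplified stand-in for the real TEST_WITH_ROCM flag; the actual flag in
    # the test suite also honors environment overrides.
    TEST_WITH_ROCM = torch.version.hip is not None


    class ExampleSkipTest(unittest.TestCase):
        @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
        def test_cuda_only_path(self):
            # Only executes on non-ROCm builds; otherwise reported as skipped.
            self.assertTrue(True)


    if __name__ == "__main__":
        unittest.main()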