@@ -71,6 +71,120 @@ if CUSPARSE.version() ≥ v"11.7.2"
71
71
end
72
72
73
73
74
+ @testset " C = αAᵀBᵀ + βC" begin
75
+ A1 = CuSparseMatrixCSR {elty} (sprand (elty, k, m, p))
76
+ A2 = copy (A1)
77
+ A2. nzVal = CUDA. rand (elty, size (A2. nzVal)... )
78
+ A = cat (A1, A2; dims= 3 )
79
+
80
+ B = CUDA. rand (elty, n, k, 2 )
81
+ C = CUDA. rand (elty, m, n, 2 )
82
+ D = copy (C)
83
+
84
+ CUSPARSE. bmm! (' C' , ' C' , α, A, B, β, C, ' O' )
85
+
86
+ D[:,:,1 ] = α * A1' * B[:,:,1 ]' + β * D[:,:,1 ]
87
+ D[:,:,2 ] = α * A2' * B[:,:,2 ]' + β * D[:,:,2 ]
88
+
89
+ @test D ≈ C
90
+ end
91
+
92
+ @testset " extended batch-dims" begin
93
+ A1 = CuSparseMatrixCSR {elty} (sprand (elty, m, k, p))
94
+ A2 = copy (A1)
95
+ A2. nzVal = CUDA. rand (elty, size (A2. nzVal)... )
96
+ A3 = cat (A1, A2; dims= 3 )
97
+
98
+ A4 = copy (A3)
99
+ A4. nzVal = CUDA. rand (elty, size (A3. nzVal)... )
100
+
101
+ A5 = copy (A3)
102
+ A5. nzVal = CUDA. rand (elty, size (A3. nzVal)... )
103
+
104
+ A = cat (A3, A4, A5; dims= 4 )
105
+
106
+ B = CUDA. rand (elty, k, n, 2 , 3 )
107
+ C = CUDA. rand (elty, m, n, 2 , 3 )
108
+ D = copy (C)
109
+
110
+ CUSPARSE. bmm! (' N' , ' N' , α, A, B, β, C, ' O' )
111
+
112
+ for c in CartesianIndices ((2 ,3 ))
113
+ CUDA. @allowscalar D[:,:,c] = α * A[:,:,c. I... ] * B[:,:,c] + β* D[:,:,c]
114
+ end
115
+
116
+ @test D ≈ C
117
+ end
118
+ end
119
+
120
+ m = 1
121
+ n = 2
122
+ # error when n == 1 and batchsize > 1 as cusparseSpMM fallsback to cusparseSpMV, which doesn't do batched computations.
123
+ # see https://docs.nvidia.com/cuda/cusparse/#cusparsespmm
124
+ k = 1
125
+ p = 1.
126
+
127
+ @testset " Sparse-Dense $elty bmm! for small matrices" for elty in (Float64, Float32, ComplexF64, ComplexF32)
128
+ # check if #2296 returns
129
+ α = rand (elty)
130
+ β = rand (elty)
131
+
132
+ @testset " C = αAB + βC" begin
133
+ A1 = CuSparseMatrixCSR {elty} (sprand (elty, m, k, p))
134
+ A2 = copy (A1)
135
+ A2. nzVal = CUDA. rand (elty, size (A2. nzVal)... )
136
+ A = cat (A1, A2; dims= 3 )
137
+
138
+ B = CUDA. rand (elty, k, n, 2 )
139
+ C = CUDA. rand (elty, m, n, 2 )
140
+ D = copy (C)
141
+
142
+ CUSPARSE. bmm! (' N' , ' N' , α, A, B, β, C, ' O' )
143
+
144
+ D[:,:,1 ] = α * A1 * B[:,:,1 ] + β * D[:,:,1 ]
145
+ D[:,:,2 ] = α * A2 * B[:,:,2 ] + β * D[:,:,2 ]
146
+
147
+ @test D ≈ C
148
+ end
149
+
150
+ @testset " C = αAᵀB + βC" begin
151
+ A1 = CuSparseMatrixCSR {elty} (sprand (elty, k, m, p))
152
+ A2 = copy (A1)
153
+ A2. nzVal = CUDA. rand (elty, size (A2. nzVal)... )
154
+ A = cat (A1, A2; dims= 3 )
155
+
156
+ B = CUDA. rand (elty, k, n, 2 )
157
+ C = CUDA. rand (elty, m, n, 2 )
158
+ D = copy (C)
159
+
160
+ CUSPARSE. bmm! (' C' , ' N' , α, A, B, β, C, ' O' )
161
+
162
+ D[:,:,1 ] = α * A1' * B[:,:,1 ] + β * D[:,:,1 ]
163
+ D[:,:,2 ] = α * A2' * B[:,:,2 ] + β * D[:,:,2 ]
164
+
165
+ @test D ≈ C
166
+ end
167
+
168
+
169
+ @testset " C = αABᵀ + βC" begin
170
+ A1 = CuSparseMatrixCSR {elty} (sprand (elty, m, k, p))
171
+ A2 = copy (A1)
172
+ A2. nzVal = CUDA. rand (elty, size (A2. nzVal)... )
173
+ A = cat (A1, A2; dims= 3 )
174
+
175
+ B = CUDA. rand (elty, n, k, 2 )
176
+ C = CUDA. rand (elty, m, n, 2 )
177
+ D = copy (C)
178
+
179
+ CUSPARSE. bmm! (' N' , ' C' , α, A, B, β, C, ' O' )
180
+
181
+ D[:,:,1 ] = α * A1 * B[:,:,1 ]' + β * D[:,:,1 ]
182
+ D[:,:,2 ] = α * A2 * B[:,:,2 ]' + β * D[:,:,2 ]
183
+
184
+ @test D ≈ C
185
+ end
186
+
187
+
74
188
@testset " C = αAᵀBᵀ + βC" begin
75
189
A1 = CuSparseMatrixCSR {elty} (sprand (elty, k, m, p))
76
190
A2 = copy (A1)
0 commit comments