
Commit a8232ee

IvanYashchuk authored and facebook-github-bot committed
Sparse CSR CUDA: Add block torch.addmv when mat is sparse (pytorch#68708)
Summary:
Pull Request resolved: pytorch#68708

This PR adds block CSR matrix times dense vector multiplication.

cc nikitaved pearu cpuhrsch IvanYashchuk ngimel

Test Plan: Imported from OSS

Reviewed By: pbelevich

Differential Revision: D32647694

Pulled By: cpuhrsch

fbshipit-source-id: a1c120691c4350284b156fe4259eda684b734b66
1 parent 6df7b75 commit a8232ee
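
For context, the new kernel is reached through torch.addmv when the sparse argument is a CSR tensor whose values tensor holds square blocks (a BSR-style layout). Below is a minimal, hypothetical sketch of such a call on a CUDA build with cuSPARSE, modeled on the new test added in this commit; the shapes, block size, and use of the private torch._sparse_csr_tensor_unsafe constructor are illustrative, not part of the commit itself.

```python
import torch

# Hypothetical 4x6 matrix made of 2x2 blocks: 2 block rows, 3 block columns,
# 3 nonzero blocks. The values tensor is 3-D with shape (nnzb, block, block).
crow_indices = torch.tensor([0, 2, 3], dtype=torch.int32, device="cuda")
col_indices = torch.tensor([0, 2, 1], dtype=torch.int32, device="cuda")
values = torch.randn(3, 2, 2, dtype=torch.float64, device="cuda")
a = torch._sparse_csr_tensor_unsafe(crow_indices, col_indices, values, (4, 6))

x = torch.randn(6, dtype=torch.float64, device="cuda")
y = torch.randn(4, dtype=torch.float64, device="cuda")

# With this PR, the blocked sparse matrix-vector product below routes to
# cusparse<t>bsrmv (the block path added here).
result = torch.addmv(y, a, x, alpha=2.0, beta=0.5)  # 2 * (a @ x) + 0.5 * y
```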

File tree: 4 files changed, +215 -30 lines changed


aten/src/ATen/cuda/CUDASparseBlas.cpp

+81
@@ -311,6 +311,87 @@ void bsrmm<c10::complex<double>>(
       ldc));
 }
 
+template <>
+void bsrmv<float>(CUSPARSE_BSRMV_ARGTYPES(float)) {
+  TORCH_CUDASPARSE_CHECK(cusparseSbsrmv(
+      handle,
+      dirA,
+      transA,
+      mb,
+      nb,
+      nnzb,
+      alpha,
+      descrA,
+      bsrValA,
+      bsrRowPtrA,
+      bsrColIndA,
+      blockDim,
+      x,
+      beta,
+      y));
+}
+
+template <>
+void bsrmv<double>(CUSPARSE_BSRMV_ARGTYPES(double)) {
+  TORCH_CUDASPARSE_CHECK(cusparseDbsrmv(
+      handle,
+      dirA,
+      transA,
+      mb,
+      nb,
+      nnzb,
+      alpha,
+      descrA,
+      bsrValA,
+      bsrRowPtrA,
+      bsrColIndA,
+      blockDim,
+      x,
+      beta,
+      y));
+}
+
+template <>
+void bsrmv<c10::complex<float>>(CUSPARSE_BSRMV_ARGTYPES(c10::complex<float>)) {
+  TORCH_CUDASPARSE_CHECK(cusparseCbsrmv(
+      handle,
+      dirA,
+      transA,
+      mb,
+      nb,
+      nnzb,
+      reinterpret_cast<const cuComplex*>(alpha),
+      descrA,
+      reinterpret_cast<const cuComplex*>(bsrValA),
+      bsrRowPtrA,
+      bsrColIndA,
+      blockDim,
+      reinterpret_cast<const cuComplex*>(x),
+      reinterpret_cast<const cuComplex*>(beta),
+      reinterpret_cast<cuComplex*>(y)));
+}
+
+template <>
+void bsrmv<c10::complex<double>>(
+    CUSPARSE_BSRMV_ARGTYPES(c10::complex<double>)) {
+  TORCH_CUDASPARSE_CHECK(cusparseZbsrmv(
+      handle,
+      dirA,
+      transA,
+      mb,
+      nb,
+      nnzb,
+      reinterpret_cast<const cuDoubleComplex*>(alpha),
+      descrA,
+      reinterpret_cast<const cuDoubleComplex*>(bsrValA),
+      bsrRowPtrA,
+      bsrColIndA,
+      blockDim,
+      reinterpret_cast<const cuDoubleComplex*>(x),
+      reinterpret_cast<const cuDoubleComplex*>(beta),
+      reinterpret_cast<cuDoubleComplex*>(y)));
+}
+
 } // namespace sparse
 } // namespace cuda
 } // namespace at
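
For reference, cusparse<t>bsrmv computes y := alpha * op(A) * x + beta * y, where A is stored in block sparse row (BSR) format. The new test below validates the kernel against SciPy's BSR matrix type; the following standalone sketch mirrors that reference computation (the function name and argument layout are illustrative, not part of this commit):

```python
import scipy.sparse as sp

def bsrmv_reference(alpha, values, crow_indices, col_indices, shape, x, beta, y):
    """Dense reference for y := alpha * (A @ x) + beta * y with A given in BSR form.

    values:       (nnzb, blockDim, blockDim) array of nonzero blocks
    crow_indices: block-row pointer array (length mb + 1)
    col_indices:  block-column index array (length nnzb)
    """
    # SciPy takes the BSR components as (data, indices, indptr), exactly as the
    # new test assembles them from the CSR tensor's values/col_indices/crow_indices.
    a_bsr = sp.bsr_matrix((values, col_indices, crow_indices), shape=shape)
    return alpha * (a_bsr @ x) + beta * y
```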

aten/src/ATen/cuda/CUDASparseBlas.h

+24
@@ -130,6 +130,30 @@ void bsrmm<c10::complex<float>>(CUSPARSE_BSRMM_ARGTYPES(c10::complex<float>));
 template <>
 void bsrmm<c10::complex<double>>(CUSPARSE_BSRMM_ARGTYPES(c10::complex<double>));
 
+#define CUSPARSE_BSRMV_ARGTYPES(scalar_t)                                    \
+  cusparseHandle_t handle, cusparseDirection_t dirA,                         \
+      cusparseOperation_t transA, int mb, int nb, int nnzb,                  \
+      const scalar_t *alpha, const cusparseMatDescr_t descrA,                \
+      const scalar_t *bsrValA, const int *bsrRowPtrA, const int *bsrColIndA, \
+      int blockDim, const scalar_t *x, const scalar_t *beta, scalar_t *y
+
+template <typename scalar_t>
+inline void bsrmv(CUSPARSE_BSRMV_ARGTYPES(scalar_t)) {
+  TORCH_INTERNAL_ASSERT(
+      false,
+      "at::cuda::sparse::bsrmv: not implemented for ",
+      typeid(scalar_t).name());
+}
+
+template <>
+void bsrmv<float>(CUSPARSE_BSRMV_ARGTYPES(float));
+template <>
+void bsrmv<double>(CUSPARSE_BSRMV_ARGTYPES(double));
+template <>
+void bsrmv<c10::complex<float>>(CUSPARSE_BSRMV_ARGTYPES(c10::complex<float>));
+template <>
+void bsrmv<c10::complex<double>>(CUSPARSE_BSRMV_ARGTYPES(c10::complex<double>));
+
 } // namespace sparse
 } // namespace cuda
 } // namespace at

aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp

+67
@@ -109,6 +109,70 @@ void inline col_indices_and_values_resize_(const Tensor& input, int64_t nnz) {
       input.sizes());
 }
 
+void block_sparse_mv(
+    const at::sparse_csr::SparseCsrTensor& mat,
+    const Tensor& vec,
+    const Scalar& beta,
+    const Scalar& alpha,
+    const Tensor& result) {
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(mat.is_sparse_csr());
+  // values is expected to hold the blocks of the sparse matrix
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(mat.values().dim() == 3);
+  // blocks are expected to be square
+  TORCH_INTERNAL_ASSERT(mat.values().size(2) == mat.values().size(1));
+  // only blocks of size > 1 are supported by cuSPARSE
+  TORCH_INTERNAL_ASSERT(mat.values().size(-1) > 1);
+  // blocks are expected to be in row- or column-major order
+  TORCH_INTERNAL_ASSERT(
+      mat.values().is_contiguous() ||
+      mat.values().transpose(-2, -1).is_contiguous());
+
+  const cusparseDirection_t block_layout = mat.values().is_contiguous()
+      ? CUSPARSE_DIRECTION_ROW
+      : CUSPARSE_DIRECTION_COLUMN;
+
+  c10::MaybeOwned<Tensor> result_ = prepare_dense_vector_for_cusparse(result);
+  c10::MaybeOwned<Tensor> vec_ = prepare_dense_vector_for_cusparse(vec);
+
+  auto block_size = cuda_int_cast(mat.values().size(2), "block_size");
+  auto nnzb = cuda_int_cast(mat._nnz(), "nnzb");
+  auto mb = cuda_int_cast(mat.size(0), "mb") / block_size;
+  auto nb = cuda_int_cast(mat.size(1), "nb") / block_size;
+
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
+      result.scalar_type(), "block_sparse_mv", [&] {
+        auto beta_ = beta.to<scalar_t>();
+        auto alpha_ = alpha.to<scalar_t>();
+        auto handle = at::cuda::getCurrentCUDASparseHandle();
+        auto desc = at::cuda::sparse::CuSparseMatDescriptor();
+        auto values = mat.values();
+        auto values_data_ptr = values.data_ptr<scalar_t>();
+        auto crow_indices = mat.crow_indices().to(kInt);
+        auto crow_indices_data_ptr = crow_indices.data_ptr<int>();
+        auto col_indices = mat.col_indices().to(kInt);
+        auto col_indices_data_ptr = col_indices.data_ptr<int>();
+        at::cuda::sparse::bsrmv(
+            handle,
+            block_layout,
+            CUSPARSE_OPERATION_NON_TRANSPOSE,
+            mb,
+            nb,
+            nnzb,
+            &alpha_,
+            desc.descriptor(),
+            values_data_ptr,
+            crow_indices_data_ptr,
+            col_indices_data_ptr,
+            block_size,
+            vec_->data_ptr<scalar_t>(),
+            &beta_,
+            result_->data_ptr<scalar_t>());
+      });
+  if (!result.is_same(*result_)) {
+    result.copy_(*result_);
+  }
+}
+
 void block_sparse_mm(
     const at::sparse_csr::SparseCsrTensor& mat1,
     const Tensor& mat2,

@@ -500,6 +564,9 @@ void addmv_out_sparse_csr(
     const Scalar& beta,
     const Scalar& alpha,
     const Tensor& result) {
+  if (mat.values().dim() == 3 && mat.values().size(-1) > 1) {
+    return block_sparse_mv(mat, vec, beta, alpha, result);
+  }
 #if !AT_USE_CUSPARSE_GENERIC_API()
   TORCH_CHECK(
       false,
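
The block path is selected in addmv_out_sparse_csr whenever the CSR tensor's values are 3-D with blocks larger than 1x1. A rough Python mirror of how block_sparse_mv derives the bsrmv geometry from such a tensor (illustrative only; the authoritative logic is the C++ above):

```python
import torch

def bsrmv_geometry(mat):
    # mat is a sparse CSR tensor whose values hold square blocks: (nnzb, bd, bd)
    values = mat.values()
    assert values.dim() == 3 and values.size(-1) == values.size(-2)
    block_size = values.size(-1)          # blockDim passed to cuSPARSE
    nnzb = mat._nnz()                     # number of nonzero blocks
    mb = mat.size(0) // block_size        # number of block rows
    nb = mat.size(1) // block_size        # number of block columns
    # Contiguous values mean row-major blocks (CUSPARSE_DIRECTION_ROW);
    # otherwise the transposed view must be contiguous (column-major blocks).
    row_major_blocks = values.is_contiguous()
    return mb, nb, nnzb, block_size, row_major_blocks
```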

test/test_sparse_csr.py

+43 -30
@@ -585,48 +585,61 @@ def test_csr_matvec(self, device, dtype):
             with self.assertRaisesRegex(RuntimeError, err_msg):
                 csr.matmul(bad_vec)
 
+    def run_test_block_addmm_addmv(self, addmv_addmm, c, a, b, op_b=False, op_out=False, *, dtype=None, device=None):
+        alpha = complex(random.random(), random.random()) if dtype.is_complex else random.random()
+        beta = complex(random.random(), random.random()) if dtype.is_complex else random.random()
+        b = b.mH if (op_b and a.shape == b.shape) else b
+
+        actual = addmv_addmm(c, a, b, alpha=alpha, beta=beta)
+
+        out = torch.empty_like(c.mH if op_out and a.shape == b.shape else c)
+        addmv_addmm(c, a, b, alpha=alpha, beta=beta, out=out)
+
+        a_bsr = sp.bsr_matrix(
+            (
+                a.values().cpu().numpy(),
+                a.col_indices().cpu().numpy(),
+                a.crow_indices().cpu().numpy(),
+            ),
+            shape=a.shape,
+        )
+        expected = alpha * (a_bsr * b.cpu().numpy()) + beta * c.cpu().numpy()
+        self.assertEqual(actual, out)
+        self.assertEqual(actual, expected)
+
     @onlyCUDA
     @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
     @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
     def test_block_addmm(self, device, dtype):
-        def run_test(c, a, b, op_b, op_c, *, alpha=None, beta=None):
-            if dtype.is_complex:
-                alpha = random.random() + 0.3j if alpha is None else alpha
-                beta = random.random() + 0.6j if beta is None else beta
-            else:
-                alpha = random.random() if alpha is None else alpha
-                beta = random.random() if beta is None else beta
-
-            if op_b and a.shape == b.shape:
-                b = b.mH
-
-            actual = torch.addmm(c, a, b, alpha=alpha, beta=beta)
-
-            out = torch.empty_like(c if op_c and a.shape == b.shape else c.mH)
-            torch.addmm(c, a, b, alpha=alpha, beta=beta, out=out)
-
-            a_bsr = sp.bsr_matrix(
-                (
-                    a.values().cpu().numpy(),
-                    a.col_indices().cpu().numpy(),
-                    a.crow_indices().cpu().numpy(),
-                ),
-                shape=a.shape,
-            )
-            expected = alpha * (a_bsr * b.cpu().numpy()) + beta * c.cpu().numpy()
-            self.assertEqual(actual, out)
-            self.assertEqual(actual, expected)
-
         for index_dtype in [torch.int32, torch.int64]:
             for (m, n, k), block_size, noncontiguous in zip(itertools.product([1, 5], repeat=3), [1, 2, 3], [True, False]):
                 nnz = random.randint(0, m * k)
                 a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
                 a_data = make_tensor((nnz, block_size, block_size), dtype=dtype, device=device)
+                a_data = a_data.mT if noncontiguous else a_data  # Test column-major blocks
                 a = torch._sparse_csr_tensor_unsafe(a.crow_indices(), a.col_indices(), a_data, (m * block_size, k * block_size))
                 b = make_tensor((k * block_size, n * block_size), dtype=dtype, device=device, noncontiguous=noncontiguous)
                 c = make_tensor((m * block_size, n * block_size), dtype=dtype, device=device, noncontiguous=noncontiguous)
-                for op_b, op_c in itertools.product([True, False], repeat=2):
-                    run_test(c, a, b, op_b, op_c)
+                for op_b, op_out in itertools.product([True, False], repeat=2):
+                    self.run_test_block_addmm_addmv(torch.addmm, c, a, b, op_b, op_out, dtype=dtype, device=device)
+
+    @onlyCUDA
+    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
+    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
+    def test_block_addmv(self, device, dtype):
+        for index_dtype in [torch.int32, torch.int64]:
+            block_sizes = [1, 2, 3]
+            if TEST_WITH_ROCM or not TEST_CUSPARSE_GENERIC:
+                block_sizes = [2, 3]
+            for (m, k), block_size, noncontiguous in zip(itertools.product([1, 5], repeat=2), block_sizes, [True, False]):
+                nnz = random.randint(0, m * k)
+                a = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
+                a_data = make_tensor((nnz, block_size, block_size), dtype=dtype, device=device)
+                a_data = a_data.mT if noncontiguous else a_data  # Test column-major blocks
+                a = torch._sparse_csr_tensor_unsafe(a.crow_indices(), a.col_indices(), a_data, (m * block_size, k * block_size))
+                b = make_tensor((k * block_size,), dtype=dtype, device=device, noncontiguous=noncontiguous)
+                c = make_tensor((m * block_size,), dtype=dtype, device=device, noncontiguous=noncontiguous)
+                self.run_test_block_addmm_addmv(torch.addmv, c, a, b, dtype=dtype, device=device)
 
 
     @skipCPUIfNoMklSparse
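
The `a_data.mT` lines added above are what exercise the column-major block path: transposing the trailing two dimensions of the values tensor produces a non-contiguous view whose transpose is contiguous, which block_sparse_mv maps to CUSPARSE_DIRECTION_COLUMN. A small illustrative check of that property (not part of the commit):

```python
import torch

blocks = torch.randn(5, 3, 3)   # five row-major 3x3 blocks, contiguous
col_major = blocks.mT           # same storage, last two dims swapped

assert blocks.is_contiguous()
assert not col_major.is_contiguous()
# This is the condition block_sparse_mv checks before choosing
# CUSPARSE_DIRECTION_COLUMN for the block layout.
assert col_major.transpose(-2, -1).is_contiguous()
```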
