
Commit 1af40d5

eqy authored and pytorchmergebot committed
[cublas][cublasLt] Fall back to unfused addmm for 2-byte-aligned inputs (pytorch#92201)
Fix for this issue surfaced from the discuss forum: https://discuss.pytorch.org/t/cuda-error-cublas-status-not-supported-when-calling-cublasltmatmul-from-torch-nn-functional-linear/170214

Note that PyTorch builds before pytorch#71200 should not be affected, as there was no `cublasLt` dispatch path. Additionally, the provided repro has the quirk of using a 3D input, which means it will not dispatch to `cublasLt`-backed `addmm` until builds that include pytorch#72728. Changing the input to 2D by trivially removing the size-`1` dimension surfaces the failure on builds after pytorch#71200.

Interestingly, the use case where _all_ inputs are 2-byte aligned is supported (it runs without crashing), but the case where some inputs are more than 2-byte aligned and others are exactly 2-byte aligned is not. This behavior suggests that the `cuBlasLt` heuristics are incorrect, as the heuristic function has visibility of the raw pointer values via the descriptors when it is called. We will follow up with `cuBlasLt`, but this fix is needed to prevent unnecessary crashes for now.

CC @ptrblck @ngimel

Pull Request resolved: pytorch#92201
Approved by: https://github.com/ngimel
1 parent a74c8df commit 1af40d5
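
For context, here is a minimal sketch (not part of the commit) of the mixed-alignment scenario described in the message above. It assumes a CUDA build; the shapes mirror the forum repro, and the one-element offset used to produce a 2-byte-aligned weight is illustrative:

import torch

# Sketch of the failure mode: a half-precision weight that is only 2-byte aligned
# (because it starts one element into a larger buffer) combined with well-aligned
# input and bias tensors.
if torch.cuda.is_available():
    dtype, device = torch.half, "cuda"
    buf = torch.rand(5120 * 2560 + 1, dtype=dtype, device=device)
    weight = buf[1:].reshape(5120, 2560)                   # data pointer offset by 2 bytes
    x = torch.rand(26, 2560, dtype=dtype, device=device)   # 2D input -> cublasLt addmm path
    bias = torch.rand(5120, dtype=dtype, device=device)
    print(weight.data_ptr() % 4, x.data_ptr() % 4, bias.data_ptr() % 4)  # typically 2 0 0
    # On builds after pytorch#71200 and before this fix, this call could fail with
    # CUDA error: CUBLAS_STATUS_NOT_SUPPORTED inside cublasLtMatmul; with the fix,
    # the mixed-alignment case falls back to the unfused addmm path and succeeds.
    out = torch.nn.functional.linear(x, weight, bias)

The same idea is exercised by the new test added below (test_cublas_addmm_alignment).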

File tree

2 files changed (+37, -0 lines)


aten/src/ATen/native/cuda/Blas.cpp

+25
@@ -146,6 +146,18 @@ static bool getDisableAddmmCudaLt() {
   return false;
 }
 
+uint8_t getAlignment(const Tensor &t) {
+  // alignment are in bytes
+  uint8_t alignment = 1;
+  uintptr_t address = reinterpret_cast<uintptr_t>(t.data_ptr());
+  for (; alignment < 4; alignment *= 2) {
+    if (address % (alignment * 2)) {
+      return alignment;
+    }
+  }
+  return alignment;
+}
+
 Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None) {
   // Make sure to keep addmm_cuda below in sync with this code; it
   // preflights a check to try to avoid actually needing to call
@@ -173,13 +185,26 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
   // leading dim >> rows when they are sliced from a large tensor
   // see fbcode/caffe2/test/test_linalg.py:test_corner_cases_of_cublasltmatmul
   if (!disable_addmm_cuda_lt) {
+    auto self_alignment = getAlignment(self);
+    auto mat1_alignment = getAlignment(mat1);
+    auto mat2_alignment = getAlignment(mat2);
+    // due to a heuristic bug, cuBlasLt requires all alignments > 2 or the same ( == 2)
+    // should we err on the side of caution and remove the second dispatch path?
+    bool alignment_ok = (self_alignment > 2 &&
+                         mat1_alignment > 2 &&
+                         mat2_alignment > 2) ||
+                        (self_alignment == 2 &&
+                         mat1_alignment == 2 &&
+                         mat2_alignment == 2);
+
     useLtInterface = beta.toComplexDouble() == 1.0 && self.dim() == 1 &&
         result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] &&
         self.is_contiguous() &&
         (scalar_type == at::ScalarType::Double ||
          scalar_type == at::ScalarType::Float ||
          scalar_type == at::ScalarType::Half ||
          scalar_type == at::ScalarType::BFloat16) &&
+        alignment_ok &&
         mat2_sizes[0] > 1 && mat2_sizes[1] > 1 &&
         mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
         mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
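
To make the gating logic above easier to follow, here is a small Python mirror of `getAlignment` and the `alignment_ok` predicate (illustrative only, not part of the commit; the real check runs in C++ on the raw CUDA pointers inside `addmm_out_cuda_impl`, and the integer addresses below are made up):

# Illustrative Python mirror of the C++ gating logic above.
def get_alignment(address: int) -> int:
    # mirrors getAlignment(): returns 1, 2, or 4 bytes, capped at 4
    alignment = 1
    while alignment < 4:
        if address % (alignment * 2):
            return alignment
        alignment *= 2
    return alignment

def alignment_ok(a: int, b: int, c: int) -> bool:
    # cuBlasLt path is used only if all operands are more than 2-byte aligned,
    # or all are exactly 2-byte aligned (the mixed case triggers the bug)
    als = [get_alignment(p) for p in (a, b, c)]
    return all(x > 2 for x in als) or all(x == 2 for x in als)

assert get_alignment(0x7f0000000002) == 2   # pointer offset by one fp16 element
assert get_alignment(0x7f0000000200) == 4   # well-aligned allocation (capped at 4)
assert not alignment_ok(0x7f0000000002, 0x7f0000000200, 0x7f0000000400)  # mixed -> fall back
assert alignment_ok(0x7f0000000002, 0x7f0000000102, 0x7f0000000202)      # all 2-byte -> Lt ok

Capping the reported alignment at 4 bytes is enough here, because the predicate only needs to distinguish "exactly 2-byte aligned" from "more than 2-byte aligned".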

test/test_matmul_cuda.py

+12
@@ -100,6 +100,18 @@ def test_cublas_addmm(self, size: int, dtype: torch.dtype):
         self.assertEqual(res_cpu, res_cuda)
         torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig
 
+    @onlyCUDA
+    def test_cublas_addmm_alignment(self):
+        dtype = torch.half
+        device = 'cuda'
+        A = torch.rand((5120 * 2560 + 1), requires_grad=True, dtype=dtype, device=device)
+        A = A[1:].reshape(5120, 2560)
+        # check that heuristic does not fail on 2-byte alignment
+        X = torch.rand((26, 1, 2560), requires_grad=True, dtype=dtype, device=device)
+        B = torch.rand((5120), requires_grad=True, dtype=dtype, device=device)
+        out = torch.nn.functional.linear(X, A, B)
+        self.assertEqual(out, torch.matmul(X, A.transpose(1, 0)) + B)
+
     @onlyCUDA
     @unittest.skipIf(not CUDA11OrLater, "Only CUDA 11+ is supported")
     @toleranceOverride({torch.float32: xtol(atol=1e-5, rtol=1e-5)})
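
Why this test produces the problematic alignment mix (an illustrative note, not part of the diff): the fresh fp16 allocation is strongly aligned by the CUDA caching allocator, so slicing off its first element shifts the data pointer by exactly 2 bytes, while `X` and `B` keep their original alignment. A quick check, assuming a CUDA device:

import torch

# Illustrative: reproduce the alignment the test sets up for the weight matrix A.
A = torch.rand(5120 * 2560 + 1, dtype=torch.half, device="cuda")
assert A.data_ptr() % 4 == 0        # fresh allocation: at least 4-byte aligned
assert A[1:].data_ptr() % 4 == 2    # dropping one 2-byte element leaves a 2-byte-aligned view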
