
Commit 8b617f8

eqy authored and pytorchmergebot committed
[cuBLAS] Add an option to disable reduced precision reductions for BF16 GEMM (pytorch#89172)
Essentially the same change as pytorch#67946, except that the default is to disallow reduced precision reductions in `BFloat16` GEMMs (for now). If performance is severely regressed, we can change the default, but this option appears to be necessary to pass some `addmm` `BFloat16` tests on H100.

CC @ptrblck @ngimel

Pull Request resolved: pytorch#89172
Approved by: https://github.com/ngimel
1 parent 1c7e815 commit 8b617f8
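
The new flag mirrors the existing FP16 control and is exposed through `torch.backends.cuda.matmul` (see the docs, stub, and test changes below). A minimal usage sketch, not part of the commit, assuming a CUDA-enabled PyTorch build that includes this change and an Ampere-or-newer GPU for the BF16 GEMM:

    import torch

    # Query the new control (mirrors allow_fp16_reduced_precision_reduction).
    print(torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction)

    # Request full-precision (FP32) accumulation for BF16 GEMMs.
    torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False

    if torch.cuda.is_available():
        a = torch.randn(256, 256, device="cuda", dtype=torch.bfloat16)
        b = torch.randn(256, 256, device="cuda", dtype=torch.bfloat16)
        c = a @ b  # BF16 GEMM dispatched through cuBLAS with the flag applied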

File tree

10 files changed, +95 -4 lines changed

aten/src/ATen/Context.cpp

+9
@@ -250,6 +250,15 @@ void Context::setAllowFP16ReductionCuBLAS(bool b) {
   allow_fp16_reduction_cublas = b;
 }
 
+bool Context::allowBF16ReductionCuBLAS() const {
+  return allow_bf16_reduction_cublas;
+}
+
+void Context::setAllowBF16ReductionCuBLAS(bool b) {
+  allow_bf16_reduction_cublas = b;
+}
+
+
 bool Context::hasMKL() {
 #if AT_MKL_ENABLED()
   return true;

aten/src/ATen/Context.h

+3
@@ -241,6 +241,8 @@ class TORCH_API Context {
   void setFloat32MatmulPrecision(Float32MatmulPrecision p);
   bool allowFP16ReductionCuBLAS() const;
   void setAllowFP16ReductionCuBLAS(bool);
+  bool allowBF16ReductionCuBLAS() const;
+  void setAllowBF16ReductionCuBLAS(bool);
   at::QEngine qEngine() const;
   void setQEngine(at::QEngine e);
   static const std::vector<at::QEngine>& supportedQEngines();
@@ -288,6 +290,7 @@ class TORCH_API Context {
   int benchmark_limit_cudnn = 10;
   bool allow_tf32_cudnn = true;
   bool allow_fp16_reduction_cublas = true;
+  bool allow_bf16_reduction_cublas = true;
   bool enabled_mkldnn = true;
   at::LinalgBackend linalg_preferred_backend = at::LinalgBackend::Default;
 #ifdef C10_MOBILE

aten/src/ATen/cuda/CUDABlas.cpp

+6
@@ -538,6 +538,11 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
   float fbeta = beta;
   _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
   GEMM_CHECK_ARGVALUES(at::BFloat16);
+  cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
+  if (!at::globalContext().allowBF16ReductionCuBLAS()) {
+    cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
+  }
+  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
   TORCH_CUDABLAS_CHECK(cublasGemmEx(
       handle,
       opa,
@@ -558,6 +563,7 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
       ldc,
       CUDA_R_32F,
       CUBLAS_GEMM_DFALT_TENSOR_OP));
+  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
 }
 #endif // defined(CUDA_VERSION) && CUDA_VERSION >= 11000

docs/source/backends.rst

+4
@@ -34,6 +34,10 @@ torch.backends.cuda
 
     A :class:`bool` that controls whether reduced precision reductions (e.g., with fp16 accumulation type) are allowed with fp16 GEMMs.
 
+.. attribute:: torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
+
+    A :class:`bool` that controls whether reduced precision reductions are allowed with bf16 GEMMs.
+
 .. attribute:: torch.backends.cuda.cufft_plan_cache
 
     ``cufft_plan_cache`` caches the cuFFT plans

docs/source/notes/cuda.rst

+23 -1
@@ -169,12 +169,34 @@ If full precision reductions are needed, users can disable reduced precision red
 
     torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
 
-To toggle the reduced precision reduction flags in C++, you can do
+To toggle the reduced precision reduction flags in C++, one can do
 
 .. code:: C++
 
   at::globalContext().setAllowFP16ReductionCuBLAS(false);
 
+.. _bf16reducedprecision:
+
+Reduced Precision Reduction in BF16 GEMMs
+-----------------------------------------
+
+A similar flag (as above) exists for BFloat16 GEMMs. Note that this switch is
+set to `False` by default for BF16 as we have observed numerical instability in
+PyTorch CI tests (e.g., test/test_matmul_cuda.py).
+
+If reduced precision reductions are not desired, users can disable reduced
+precision reductions in bf16 GEMMs with:
+
+.. code:: python
+
+    torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
+
+To toggle the reduced precision reduction flags in C++, one can do
+
+.. code:: C++
+
+  at::globalContext().setAllowBF16ReductionCuBLAS(true);
+
 Asynchronous execution
 ----------------------
 
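Beyond the one-shot toggles shown in the doc above, a common pattern is to scope the flag to a particular computation and restore the previous value afterwards. A sketch, not part of the diff (the helper name is hypothetical), assuming a CUDA-enabled build containing this commit:

    import torch

    def matmul_full_precision_reductions(a, b):
        # Hypothetical helper: run one BF16 GEMM with reduced precision
        # reductions disabled, then restore the prior setting.
        matmul_cfg = torch.backends.cuda.matmul
        prev = matmul_cfg.allow_bf16_reduced_precision_reduction
        matmul_cfg.allow_bf16_reduced_precision_reduction = False
        try:
            return a @ b
        finally:
            matmul_cfg.allow_bf16_reduced_precision_reduction = prev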

docs/source/notes/numerical_accuracy.rst

+7 -3
@@ -98,13 +98,17 @@ If your network needs full float32 precision for both matrix multiplications and
 
 For more information see :ref:`TensorFloat32<tf32_on_ampere>`.
 
-Reduced Precision Reduction for FP16 GEMMs
-------------------------------------------
+Reduced Precision Reduction for FP16 and BF16 GEMMs
+----------------------------------------------------
 Half-precision GEMM operations are typically done with intermediate accumulations (reduction) in single-precision for numerical accuracy and improved resilience to overflow. For performance, certain GPU architectures, especially more recent ones, allow a few truncations of the intermediate accumulation results to the reduced precision (e.g., half-precision). This change is often benign from the perspective of model convergence, though it may lead to unexpected results (e.g., ``inf`` values when the final result should be representable in half-precision).
 If reduced-precision reductions are problematic, they can be turned off with
 ``torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False``
 
-For more information see :ref:`allow_fp16_reduced_precision_reduction<fp16reducedprecision>`
+A similar flag exists for BF16 GEMM operations and is turned off by default. If BF16
+reduced-precision reductions are problematic, they can be turned off with
+``torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False``
+
+For more information see :ref:`allow_fp16_reduced_precision_reduction<fp16reducedprecision>` and :ref:`allow_bf16_reduced_precision_reduction<bf16reducedprecision>`
 
 .. _fp16_on_mi200:

test/test_cuda.py

+8
@@ -636,6 +636,14 @@ def test_cublas_allow_fp16_reduced_precision_reduction_get_set(self):
         self.assertEqual(torch._C._get_cublas_allow_fp16_reduced_precision_reduction(), not orig)
         torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig
 
+    def test_cublas_allow_bf16_reduced_precision_reduction_get_set(self):
+        orig = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
+        self.assertEqual(torch._C._get_cublas_allow_bf16_reduced_precision_reduction(), orig)
+        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = not orig
+        self.assertEqual(torch._C._get_cublas_allow_bf16_reduced_precision_reduction(), not orig)
+        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig
+
+
     def test_cudnn_allow_tf32_get_set(self):
         with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=False):
             self.assertFalse(torch.backends.cudnn.allow_tf32)

torch/_C/__init__.pyi.in

+2
@@ -845,6 +845,8 @@ def _get_float32_matmul_precision() -> str: ... #THPModule_float32MatmulPrecisio
 def _set_float32_matmul_precision(arg: str) -> None: ... #THPModule_setFloat32MatmulPrecision
 def _get_cublas_allow_fp16_reduced_precision_reduction() -> _bool: ... #THPModule_allowFP16ReductionCuBLAS
 def _set_cublas_allow_fp16_reduced_precision_reduction(arg: _bool) -> None: ... #THPModule_setAllowFP16ReductionCuBLAS
+def _get_cublas_allow_bf16_reduced_precision_reduction() -> _bool: ... #THPModule_allowBF16ReductionCuBLAS
+def _set_cublas_allow_bf16_reduced_precision_reduction(arg: _bool) -> None: ... #THPModule_setAllowBF16ReductionCuBLAS
 def _set_conj(x: Tensor, conj: _bool) -> None: ...
 def _set_neg(x: Tensor, neg: _bool) -> None: ...
 def _set_meta_in_tls_dispatch_include(meta_in_tls: _bool) -> None: ...

torch/backends/cuda/__init__.py

+4
@@ -97,13 +97,17 @@ def __getattr__(self, name):
             return torch._C._get_cublas_allow_tf32()
         elif name == "allow_fp16_reduced_precision_reduction":
             return torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
+        elif name == "allow_bf16_reduced_precision_reduction":
+            return torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
         raise AssertionError("Unknown attribute " + name)
 
     def __setattr__(self, name, value):
         if name == "allow_tf32":
             return torch._C._set_cublas_allow_tf32(value)
         elif name == "allow_fp16_reduced_precision_reduction":
             return torch._C._set_cublas_allow_fp16_reduced_precision_reduction(value)
+        elif name == "allow_bf16_reduced_precision_reduction":
+            return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(value)
         raise AssertionError("Unknown attribute " + name)
 
 _LinalgBackends = {

torch/csrc/Module.cpp

+29
@@ -743,6 +743,27 @@ PyObject* THPModule_allowFP16ReductionCuBLAS(
   Py_RETURN_FALSE;
 }
 
+PyObject* THPModule_setAllowBF16ReductionCuBLAS(
+    PyObject* _unused,
+    PyObject* arg) {
+  THPUtils_assert(
+      PyBool_Check(arg),
+      "set_allow_bf16_reduction_cublas expects a bool, "
+      "but got %s",
+      THPUtils_typename(arg));
+  at::globalContext().setAllowBF16ReductionCuBLAS(arg == Py_True);
+  Py_RETURN_NONE;
+}
+
+PyObject* THPModule_allowBF16ReductionCuBLAS(
+    PyObject* _unused,
+    PyObject* noargs) {
+  if (at::globalContext().allowBF16ReductionCuBLAS()) {
+    Py_RETURN_TRUE;
+  }
+  Py_RETURN_FALSE;
+}
+
 PyObject* THPModule_setFlushDenormal(PyObject* _unused, PyObject* arg) {
   THPUtils_assert(
       PyBool_Check(arg),
@@ -1063,6 +1084,14 @@ static PyMethodDef TorchMethods[] = {
      THPModule_setAllowFP16ReductionCuBLAS,
      METH_O,
      nullptr},
+    {"_get_cublas_allow_bf16_reduced_precision_reduction",
+     THPModule_allowBF16ReductionCuBLAS,
+     METH_NOARGS,
+     nullptr},
+    {"_set_cublas_allow_bf16_reduced_precision_reduction",
+     THPModule_setAllowBF16ReductionCuBLAS,
+     METH_O,
+     nullptr},
    {"_vmapmode_increment_nesting",
      THPModule_vmapmode_increment_nesting,
      METH_NOARGS,
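
These method-table entries back the type stubs added to torch/_C/__init__.pyi.in, so the private bindings can also be exercised directly from Python, which is essentially what the new test in test/test_cuda.py does. A minimal round-trip sketch, not part of the commit, assuming a build that contains this change:

    import torch

    orig = torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
    torch._C._set_cublas_allow_bf16_reduced_precision_reduction(not orig)
    assert torch._C._get_cublas_allow_bf16_reduced_precision_reduction() == (not orig)
    torch._C._set_cublas_allow_bf16_reduced_precision_reduction(orig)  # restore the original setting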
