
Commit ae6dd20

eqy authored and pytorchmergebot committed
[cuDNN V8 API] (reopen 2) Allow the number of kernels profiled under torch.backends.cudnn.benchmark = True to be limited (pytorch#78299)
Reopen of pytorch#77002 to address comments by @malfet
CC @ngimel @ptrblck
Pull Request resolved: pytorch#78299
Approved by: https://github.com/ngimel
1 parent 7fd0cf5 commit ae6dd20
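
For orientation, a minimal sketch of the user-facing knob this commit adds, assuming a PyTorch build with the experimental cuDNN v8 API enabled (tensor shapes are arbitrary):

```python
import torch

# Let cuDNN benchmark competing convolution plans, but cap how many
# plans are profiled per unique convolution configuration.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.benchmark_limit = 4  # 0 means "profile every plan"

conv = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1).cuda()
x = torch.randn(8, 3, 224, 224, device="cuda")
y = conv(x)  # the first call triggers the (now bounded) benchmark sweep
```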

File tree

8 files changed: +80 −6 lines

aten/src/ATen/Context.cpp

+8
@@ -144,6 +144,14 @@ void Context::setBenchmarkCuDNN(bool b) {
   benchmark_cudnn = b;
 }
 
+int Context::benchmarkLimitCuDNN() const {
+  return benchmark_limit_cudnn;
+}
+
+void Context::setBenchmarkLimitCuDNN(int b) {
+  benchmark_limit_cudnn = b;
+}
+
 bool Context::allowTF32CuBLAS() const {
   static bool allow_tf32_cublas_override = c10::utils::check_env("TORCH_ALLOW_TF32_CUBLAS_OVERRIDE") == true;
   return allow_tf32_cublas_override || float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST;

aten/src/ATen/Context.h

+3
@@ -121,6 +121,8 @@ class TORCH_API Context {
   void setUserEnabledMkldnn(bool e);
   bool benchmarkCuDNN() const;
   void setBenchmarkCuDNN(bool);
+  int benchmarkLimitCuDNN() const;
+  void setBenchmarkLimitCuDNN(int);
   bool deterministicCuDNN() const;
   void setDeterministicCuDNN(bool);
 
@@ -254,6 +256,7 @@ class TORCH_API Context {
   bool benchmark_cudnn = false;
   Float32MatmulPrecision float32_matmul_precision =
       at::Float32MatmulPrecision::HIGHEST;
+  int benchmark_limit_cudnn = 10;
   bool allow_tf32_cudnn = true;
   bool allow_fp16_reduction_cublas = true;
   bool enabled_mkldnn = true;
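
The new `benchmark_limit_cudnn` member above defaults to 10, so the Python-facing attribute should report that value out of the box; a quick sanity check, assuming a build where the attribute is exposed:

```python
import torch

# Mirrors Context::benchmark_limit_cudnn, initialized to 10 in Context.h.
print(torch.backends.cudnn.benchmark_limit)  # expected: 10
```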

aten/src/ATen/native/cudnn/Conv_v8.cpp

+5-3
@@ -344,7 +344,7 @@ void generate_and_filter_plans(const cudnnHandle_t handle, cudnn_frontend::Opera
       remove_invalid = true;
     }
   }
-  if (remove_invalid) {
+  if (remove_invalid || max_plans) {
     cudnn_frontend::executionPlans_t new_valid_plans;
     unsigned int plan_count = 0;
     for (auto &plan : valid_plans) {
@@ -370,7 +370,8 @@ auto get_plans_from_find(const cudnnHandle_t handle, const cudnnBackendDescripto
   cudnn_frontend::executionPlans_t valid_plans;
   c10::DeviceGuard g(x.options().device());
   at::DataPtr workspace_ptr;
-  generate_and_filter_plans(handle, opGraph, generator, x, valid_plans, workspace_ptr);
+  auto benchmark_limit = at::globalContext().benchmarkLimitCuDNN();
+  generate_and_filter_plans(handle, opGraph, generator, x, valid_plans, workspace_ptr, benchmark_limit);
   auto variantPack = cudnn_frontend::VariantPackBuilder()
       .setDataPointers(3, data_ptrs)
       .setUids(3, uids)
@@ -400,7 +401,8 @@ auto get_plans_from_find_fused(const cudnnHandle_t handle,
   cudnn_frontend::executionPlans_t valid_plans;
   c10::DeviceGuard g(x.options().device());
   at::DataPtr workspace_ptr;
-  generate_and_filter_plans(handle, opGraph, generator, x, valid_plans, workspace_ptr);
+  auto benchmark_limit = at::globalContext().benchmarkLimitCuDNN();
+  generate_and_filter_plans(handle, opGraph, generator, x, valid_plans, workspace_ptr, benchmark_limit);
   auto variantPack = cudnn_frontend::VariantPackBuilder()
       .setDataPointers(5, data_ptrs)
       .setUids(5, uids)
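
The `max_plans` argument now passed into `generate_and_filter_plans` caps how many valid execution plans survive filtering, with 0 meaning "no cap". A rough Python sketch of that truncation logic, with hypothetical stand-in names since the real code operates on cudnn_frontend plan objects:

```python
def filter_plans(plans, max_plans=0, is_valid=lambda plan: True):
    """Keep plans that pass the validity check, stopping once
    max_plans entries are kept; max_plans == 0 disables the cap."""
    kept = []
    for plan in plans:
        if not is_valid(plan):
            continue  # e.g. workspace allocation failed for this plan
        kept.append(plan)
        if max_plans and len(kept) >= max_plans:
            break  # benchmark_limit reached
    return kept

# Keep at most 2 of 5 candidate plans:
print(filter_plans(["p0", "p1", "p2", "p3", "p4"], max_plans=2))  # ['p0', 'p1']
```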

docs/source/backends.rst

+8
@@ -78,6 +78,14 @@ torch.backends.cudnn
     A :class:`bool` that, if True, causes cuDNN to benchmark multiple convolution algorithms
     and select the fastest.
 
+.. attribute::  torch.backends.cudnn.benchmark_limit
+
+    An :class:`int` that specifies the maximum number of cuDNN convolution algorithms to try when
+    `torch.backends.cudnn.benchmark` is True. Set `benchmark_limit` to zero to try every
+    available algorithm. Note that this setting only affects convolutions dispatched via the
+    cuDNN v8 API.
+
+
 torch.backends.mps
 ^^^^^^^^^^^^^^^^^^
 .. automodule:: torch.backends.mps
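
As the new documentation notes, zero disables the cap entirely; in sketch form:

```python
import torch

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.benchmark_limit = 0   # exhaustive: profile every plan
torch.backends.cudnn.benchmark_limit = 10  # default: profile at most 10 plans
```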

torch/CMakeLists.txt

+4
@@ -142,6 +142,10 @@ if(USE_ROCM)
   list(APPEND TORCH_PYTHON_INCLUDE_DIRECTORIES ${roctracer_INCLUDE_DIRS})
 endif()
 
+if(USE_EXPERIMENTAL_CUDNN_V8_API)
+  list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_EXPERIMENTAL_CUDNN_V8_API)
+endif()
+
 if(USE_CUDNN OR USE_ROCM)
   list(APPEND TORCH_PYTHON_SRCS
     ${TORCH_SRC_DIR}/csrc/cuda/shared/cudnn.cpp

torch/_C/__init__.pyi.in

+2
@@ -1053,6 +1053,8 @@ def _cuda_jiterator_compile_and_launch_kernel(code_string: str,
                                               num_outputs: _int,
                                               tensors: Tuple,
                                               kwargs: Dict[str, Union[_int, _float, _bool]]) -> Tensor: ...
+def _cuda_get_cudnn_benchmark_limit() -> _int: ...
+def _cuda_set_cudnn_benchmark_limit(arg: _int) -> None: ...
 def _nccl_version() -> _int: ...
 def _nccl_unique_id() -> bytes: ...
 def _nccl_init_rank(nranks: _int, comm_id: bytes, rank: _int) -> object: ...

torch/backends/cudnn/__init__.py

+10-3
@@ -102,15 +102,18 @@ def is_acceptable(tensor):
     return True
 
 
-def set_flags(_enabled=None, _benchmark=None, _deterministic=None, _allow_tf32=None):
+def set_flags(_enabled=None, _benchmark=None, _benchmark_limit=None, _deterministic=None, _allow_tf32=None):
     orig_flags = (torch._C._get_cudnn_enabled(),
                   torch._C._get_cudnn_benchmark(),
+                  None if not is_available() else torch._C._cuda_get_cudnn_benchmark_limit(),
                   torch._C._get_cudnn_deterministic(),
                   torch._C._get_cudnn_allow_tf32())
     if _enabled is not None:
         torch._C._set_cudnn_enabled(_enabled)
     if _benchmark is not None:
         torch._C._set_cudnn_benchmark(_benchmark)
+    if _benchmark_limit is not None and is_available():
+        torch._C._cuda_set_cudnn_benchmark_limit(_benchmark_limit)
     if _deterministic is not None:
         torch._C._set_cudnn_deterministic(_deterministic)
     if _allow_tf32 is not None:
@@ -119,9 +122,9 @@ def set_flags(_enabled=None, _benchmark=None, _deterministic=None, _allow_tf32=N
 
 
 @contextmanager
-def flags(enabled=False, benchmark=False, deterministic=False, allow_tf32=True):
+def flags(enabled=False, benchmark=False, benchmark_limit=10, deterministic=False, allow_tf32=True):
     with __allow_nonbracketed_mutation():
-        orig_flags = set_flags(enabled, benchmark, deterministic, allow_tf32)
+        orig_flags = set_flags(enabled, benchmark, benchmark_limit, deterministic, allow_tf32)
     try:
         yield
     finally:
@@ -141,6 +144,9 @@ def __init__(self, m, name):
     enabled = ContextProp(torch._C._get_cudnn_enabled, torch._C._set_cudnn_enabled)
     deterministic = ContextProp(torch._C._get_cudnn_deterministic, torch._C._set_cudnn_deterministic)
     benchmark = ContextProp(torch._C._get_cudnn_benchmark, torch._C._set_cudnn_benchmark)
+    benchmark_limit = None
+    if is_available():
+        benchmark_limit = ContextProp(torch._C._cuda_get_cudnn_benchmark_limit, torch._C._cuda_set_cudnn_benchmark_limit)
     allow_tf32 = ContextProp(torch._C._get_cudnn_allow_tf32, torch._C._set_cudnn_allow_tf32)
 
     # This is the sys.modules replacement trick, see
@@ -152,3 +158,4 @@ def __init__(self, m, name):
 deterministic: bool
 benchmark: bool
 allow_tf32: bool
+benchmark_limit: int
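
With `benchmark_limit` threaded through `set_flags` and the `flags` context manager, the setting can be scoped temporarily; a usage sketch, assuming a CUDA build where `is_available()` returns True:

```python
import torch

# Enable benchmarking with a tight plan cap only inside this block;
# the previous flags (including benchmark_limit) are restored on exit.
with torch.backends.cudnn.flags(enabled=True, benchmark=True, benchmark_limit=2):
    conv = torch.nn.Conv2d(3, 8, kernel_size=3).cuda()
    out = conv(torch.randn(1, 3, 32, 32, device="cuda"))
```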

torch/csrc/cuda/Module.cpp

+40
@@ -1,4 +1,10 @@
 #include <ATen/ATen.h>
+#include <ATen/cuda/CUDAConfig.h>
+#if AT_CUDNN_ENABLED()
+
+#include <ATen/native/cudnn/Macros.h>
+
+#endif
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/CUDAGeneratorImpl.h>
 #include <ATen/cuda/CachingHostAllocator.h>
@@ -727,6 +733,32 @@ static PyObject* THCPModule_isCurrentStreamCapturing_wrap(
   END_HANDLE_TH_ERRORS
 }
 
+PyObject* THCPModule_setBenchmarkLimitCuDNN(PyObject* _unused, PyObject* arg) {
+  THPUtils_assert(
+      THPUtils_checkLong(arg),
+      "set_benchmark_limit_cudnn expects an int, "
+      "but got %s",
+      THPUtils_typename(arg));
+  auto benchmark_limit = static_cast<int>(THPUtils_unpackLong(arg));
+#if defined(USE_ROCM)
+  TORCH_WARN_ONCE(
+      "cuDNN Benchmark limit is not supported in MIOpen and will have no effect.");
+#endif
+#if AT_CUDNN_ENABLED()
+#if HAS_CUDNN_V8()
+  at::globalContext().setBenchmarkLimitCuDNN(benchmark_limit);
+#else
+  TORCH_WARN_ONCE(
+      "cuDNN Benchmark limit is not supported with cuDNN v7 API and will have no effect.");
+#endif
+#endif
+  Py_RETURN_NONE;
+}
+
+PyObject* THCPModule_benchmarkLimitCuDNN(PyObject* _unused, PyObject* noargs) {
+  return THPUtils_packInt32(at::globalContext().benchmarkLimitCuDNN());
+}
+
 // NOLINTNEXTLINE(modernize-avoid-c-arrays,
 // cppcoreguidelines-avoid-non-const-global-variables,
 // cppcoreguidelines-avoid-c-arrays)
@@ -814,6 +846,14 @@ static struct PyMethodDef _THCPModule_methods[] = {
      THCPModule_cudaJiteratorCompileAndLaunchKernel,
      METH_VARARGS,
      nullptr},
+    {"_cuda_get_cudnn_benchmark_limit",
+     THCPModule_benchmarkLimitCuDNN,
+     METH_NOARGS,
+     nullptr},
+    {"_cuda_set_cudnn_benchmark_limit",
+     THCPModule_setBenchmarkLimitCuDNN,
+     METH_O,
+     nullptr},
 #ifdef USE_NCCL
     {"_nccl_version", THCPModule_nccl_version, METH_NOARGS, nullptr},
     {"_nccl_unique_id", THCPModule_nccl_unique_id, METH_NOARGS, nullptr},
