
Commit fc66521

eqy authored and pytorchmergebot committed
[cuDNN] [cuDNN v8 API] Support cuDNN Errata Filter (pytorch#73934)
Not originally mentioned in the tracking issue pytorch#58414, but this is a nice-to-have feature. In summary, the errata filter allows known problematic kernels to be skipped, via a JSON file supplied at run time, instead of irrecoverably crashing a CUDA context (e.g., through an illegal memory access).

cuDNN frontend description: https://github.com/NVIDIA/cudnn-frontend#errata-filter

Sample errata filter JSON:

```
{
  "version" : 1,
  "rules" : [
    {
      "rule_id" : "avoid_bad_bwd_data",
      "operation" : "ConvBwdData",
      "engine" : 12,
      "cudnn_version_start" : 8000,
      "cudnn_version_end" : 9000
    }
  ]
}
```

CC @ngimel @zasdfgbnm @ptrblck

Pull Request resolved: pytorch#73934
Approved by: https://github.com/ngimel
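For readers unfamiliar with the errata-filter API, here is a minimal sketch of how the check is wired up, mirroring the `plan_errata_exception` helper added in the diff below. The includes are illustrative, and the fallback to a `CUDNN_ERRATA_JSON_FILE` environment variable when an empty path is passed is an assumption based on the cudnn-frontend documentation, not something stated in this commit.

```cpp
// Minimal sketch (assumptions noted in comments): mirrors the
// plan_errata_exception helper added in this commit.
#include <cudnn_frontend.h>   // brings in cudnnHandle_t and the errata helpers
#include <nlohmann/json.hpp>  // JSON type used by cudnn-frontend
#include <string>

bool plan_errata_exception(const cudnnHandle_t handle, const std::string& executionPlanTag) {
  // Parsed errata rules; loaded once and reused for every plan that is checked.
  static nlohmann::json errata_json_handle;
  // Passing an empty path lets cudnn-frontend pick the JSON file up from the
  // CUDNN_ERRATA_JSON_FILE environment variable (assumption, per its docs).
  static bool has_json = cudnn_frontend::load_from_config(errata_json_handle, "");
  if (!has_json) {
    return false;  // no errata file supplied: never skip a plan
  }
  // True when the plan's tag matches a rule in the JSON and the trailing
  // predicate (always true here) confirms the rule should apply.
  return cudnn_frontend::check_errata(errata_json_handle, executionPlanTag, handle,
                                      []() { return true; });
}
```

Under that assumption, a user would save the sample JSON above to a file and point `CUDNN_ERRATA_JSON_FILE` at it before launching the process; any execution plan whose tag matches a rule is then skipped rather than run.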
1 parent c29df68 · commit fc66521

1 file changed (+32, −12 lines)


aten/src/ATen/native/cudnn/Conv_v8.cpp

Lines changed: 32 additions & 12 deletions
```diff
@@ -305,9 +305,20 @@ size_t get_available_workspace() {
   return max_block_size;
 }
 
+static nlohmann::json errata_json_handle;
+
+bool plan_errata_exception(const cudnnHandle_t handle, const std::string & executionPlanTag) {
+  static bool has_json = cudnn_frontend::load_from_config(errata_json_handle, "");
+  if (!has_json) {
+    return false;
+  } else {
+    return cudnn_frontend::check_errata(errata_json_handle, executionPlanTag, handle, [](){return true;});
+  }
+}
+
 void generate_and_filter_plans(const cudnnHandle_t handle, cudnn_frontend::OperationGraph& opGraph, cudnn_frontend::EngineConfigGenerator& generator, const Tensor& x, cudnn_frontend::executionPlans_t& valid_plans, at::DataPtr& workspace_ptr, unsigned int max_plans = 0) {
   auto initial_predicate_function = [&](cudnn_frontend::ExecutionPlan const& plan) -> bool {
-    return false;
+    return plan_errata_exception(handle, plan.getTag());
   };
   auto plans = generator.cudnnGetPlan(handle, opGraph, initial_predicate_function);
   size_t max_block_size = get_available_workspace();
@@ -407,8 +418,9 @@ auto get_plans_from_find_fused(const cudnnHandle_t handle,
 
 
 // We only get configs from this stage to avoid building unnecessary plans that are never executed
-auto get_configs_from_heuristics(const cudnnHandle_t handle, const cudnnBackendDescriptorType_t desc, const Tensor& x, const Tensor& y, const Tensor& w, const CacheKey& key, const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, const bool deterministic, const bool allow_tf32) {
+auto get_configs_from_heuristics(const cudnnHandle_t handle, const cudnnBackendDescriptorType_t desc, std::string& opgraph_tag, const Tensor& x, const Tensor& y, const Tensor& w, const CacheKey& key, const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, const bool deterministic, const bool allow_tf32) {
   auto opGraph = build_opgraph(handle, desc, x, y, w, key, padding, stride, dilation);
+  opgraph_tag = opGraph.getTag();
   auto heuristic_mode = at::native::cudnnv8_use_heur_mode_b() ? CUDNN_HEUR_MODE_B : CUDNN_HEUR_MODE_INSTANT;
   auto sources = get_generator_sources(desc, x, deterministic, allow_tf32, heuristic_mode);
 
@@ -417,8 +429,9 @@ auto get_configs_from_heuristics(const cudnnHandle_t handle, const cudnnBackendD
   return configs;
 }
 
-auto get_configs_from_heuristics_fused(const cudnnHandle_t handle, const Tensor& x, const Tensor& y, const Tensor& w, const Tensor& z, const Tensor& b, const float alpha, const CacheKeyFused& key, const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, const bool deterministic, const bool allow_tf32) {
+auto get_configs_from_heuristics_fused(const cudnnHandle_t handle, std::string& opgraph_tag, const Tensor& x, const Tensor& y, const Tensor& w, const Tensor& z, const Tensor& b, const float alpha, const CacheKeyFused& key, const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, const bool deterministic, const bool allow_tf32) {
   auto opGraph = build_opgraph_fused(handle, x, y, w, z, b, alpha, key, padding, stride, dilation);
+  opgraph_tag = opGraph.getTag();
   auto heuristic_mode = at::native::cudnnv8_use_heur_mode_b() ? CUDNN_HEUR_MODE_B : CUDNN_HEUR_MODE_INSTANT;
   auto sources = get_generator_sources(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR, x, deterministic, allow_tf32, heuristic_mode);
 
@@ -455,13 +468,16 @@ void try_plans_fused(cudnn_frontend::executionPlans_t& plans, const CacheKeyFuse
   TORCH_CHECK(false, "FIND was unable to find an engine to execute this computation");
 }
 
-void try_configs(cudnn_frontend::EngineConfigList& configs, const CacheKey& key, const cudnnHandle_t handle, const Tensor& x, const Tensor& y, const Tensor& w) {
+void try_configs(cudnn_frontend::EngineConfigList& configs, const std::string& opgraph_tag, const CacheKey& key, const cudnnHandle_t handle, const Tensor& x, const Tensor& y, const Tensor& w) {
   for (auto & config : configs) {
     try {
       auto plan = cudnn_frontend::ExecutionPlanBuilder()
                       .setHandle(handle)
-                      .setEngineConfig(config)
+                      .setEngineConfig(config, opgraph_tag)
                       .build();
+      if (plan_errata_exception(handle, plan.getTag())) {
+        continue;
+      }
       run_conv_plan(handle, x, y, w, plan);
       benchmark_cache.emplace(key, plan);
       return;
@@ -473,13 +489,16 @@ void try_configs(cudnn_frontend::EngineConfigList& configs, const CacheKey& key,
   TORCH_CHECK(false, "GET was unable to find an engine to execute this computation");
 }
 
-void try_configs_fused(cudnn_frontend::EngineConfigList& configs, const CacheKeyFused& key, const cudnnHandle_t handle, const Tensor& x, const Tensor& y, const Tensor& w, const Tensor& z, const Tensor& b) {
+void try_configs_fused(cudnn_frontend::EngineConfigList& configs, const std::string& opgraph_tag, const CacheKeyFused& key, const cudnnHandle_t handle, const Tensor& x, const Tensor& y, const Tensor& w, const Tensor& z, const Tensor& b) {
   for (auto & config : configs) {
     try {
       auto plan = cudnn_frontend::ExecutionPlanBuilder()
                       .setHandle(handle)
-                      .setEngineConfig(config)
+                      .setEngineConfig(config, opgraph_tag)
                       .build();
+      if (plan_errata_exception(handle, plan.getTag())) {
+        continue;
+      }
       run_conv_plan_fused(handle, x, y, w, z, b, plan);
       benchmark_cache_fused.emplace(key, plan);
       return;
@@ -496,7 +515,6 @@ void run_single_conv(const cudnnBackendDescriptorType_t operation,
                      const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, const int64_t groups,
                      const bool benchmark, const bool deterministic, const bool allow_tf32) {
   cudnnHandle_t handle = getCudnnHandle();
-
   CacheKey key;
   setCacheKey(key, operation, y, x, w, padding, stride, dilation, groups, deterministic, allow_tf32);
   // TODO: is this thread safe if cache is updated? is pointer stale?
@@ -509,13 +527,14 @@ void run_single_conv(const cudnnBackendDescriptorType_t operation,
       cudaGetLastError(); // clear CUDA error
     }
   }
-
   if (!benchmark) {
+    std::string opgraph_tag; // extra data needed for errata filter
     cudnn_frontend::EngineConfigList configs = get_configs_from_heuristics(handle, operation,
+                                                                           opgraph_tag,
                                                                            x, y, w, key,
                                                                            padding, stride, dilation,
                                                                            deterministic, allow_tf32);
-    try_configs(configs, key, handle, x, y, w);
+    try_configs(configs, opgraph_tag, key, handle, x, y, w);
   } else {
     cudnn_frontend::executionPlans_t plans = get_plans_from_find(handle, operation,
                                                                  x, y, w, key,
@@ -544,13 +563,14 @@ void run_fused_conv(const Tensor& x, const Tensor& y, const Tensor& w, const Ten
       cudaGetLastError(); // clear CUDA error
     }
   }
-
   if (!benchmark) {
+    std::string opgraph_tag; // extra data needed for errata filter
     cudnn_frontend::EngineConfigList configs = get_configs_from_heuristics_fused(handle,
+                                                                                 opgraph_tag,
                                                                                  x, y, w, z, b, alpha, key,
                                                                                  padding, stride, dilation,
                                                                                  deterministic, allow_tf32);
-    try_configs_fused(configs, key, handle, x, y, w, z, b);
+    try_configs_fused(configs, opgraph_tag, key, handle, x, y, w, z, b);
   } else {
     cudnn_frontend::executionPlans_t plans = get_plans_from_find_fused(handle,
                                                                        x, y, w, z, b, alpha, key,
```
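Note on the `opgraph_tag` plumbing in the heuristics path: errata rules are matched against the full execution-plan tag, and a plan built from a bare engine config does not carry the operation-graph portion of that tag on its own. Threading the tag captured in `get_configs_from_heuristics`/`get_configs_from_heuristics_fused` into `setEngineConfig(config, opgraph_tag)` presumably keeps the tags of plans built in this path consistent with those produced during benchmarking, so a single errata rule (e.g., one keyed on `ConvBwdData`) filters both code paths.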
