@@ -305,9 +305,20 @@ size_t get_available_workspace() {
305
305
return max_block_size;
306
306
}
307
307
308
+ static nlohmann::json errata_json_handle;
309
+
310
+ bool plan_errata_exception (const cudnnHandle_t handle, const std::string & executionPlanTag) {
311
+ static bool has_json = cudnn_frontend::load_from_config (errata_json_handle, " " );
312
+ if (!has_json) {
313
+ return false ;
314
+ } else {
315
+ return cudnn_frontend::check_errata (errata_json_handle, executionPlanTag, handle, [](){return true ;});
316
+ }
317
+ }
318
+
308
319
void generate_and_filter_plans (const cudnnHandle_t handle, cudnn_frontend::OperationGraph& opGraph, cudnn_frontend::EngineConfigGenerator& generator, const Tensor& x, cudnn_frontend::executionPlans_t& valid_plans, at::DataPtr& workspace_ptr, unsigned int max_plans = 0 ) {
309
320
auto initial_predicate_function = [&](cudnn_frontend::ExecutionPlan const & plan) -> bool {
310
- return false ;
321
+ return plan_errata_exception (handle, plan. getTag ()) ;
311
322
};
312
323
auto plans = generator.cudnnGetPlan (handle, opGraph, initial_predicate_function);
313
324
size_t max_block_size = get_available_workspace ();
@@ -407,8 +418,9 @@ auto get_plans_from_find_fused(const cudnnHandle_t handle,
407
418
408
419
409
420
// We only get configs from this stage to avoid building unnecessary plans that are never executed
410
- auto get_configs_from_heuristics (const cudnnHandle_t handle, const cudnnBackendDescriptorType_t desc, const Tensor& x, const Tensor& y, const Tensor& w, const CacheKey& key, const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, const bool deterministic, const bool allow_tf32) {
421
+ auto get_configs_from_heuristics (const cudnnHandle_t handle, const cudnnBackendDescriptorType_t desc, std::string& opgraph_tag, const Tensor& x, const Tensor& y, const Tensor& w, const CacheKey& key, const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, const bool deterministic, const bool allow_tf32) {
411
422
auto opGraph = build_opgraph (handle, desc, x, y, w, key, padding, stride, dilation);
423
+ opgraph_tag = opGraph.getTag ();
412
424
auto heuristic_mode = at::native::cudnnv8_use_heur_mode_b () ? CUDNN_HEUR_MODE_B : CUDNN_HEUR_MODE_INSTANT;
413
425
auto sources = get_generator_sources (desc, x, deterministic, allow_tf32, heuristic_mode);
414
426
@@ -417,8 +429,9 @@ auto get_configs_from_heuristics(const cudnnHandle_t handle, const cudnnBackendD
417
429
return configs;
418
430
}
419
431
420
- auto get_configs_from_heuristics_fused (const cudnnHandle_t handle, const Tensor& x, const Tensor& y, const Tensor& w, const Tensor& z, const Tensor& b, const float alpha, const CacheKeyFused& key, const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, const bool deterministic, const bool allow_tf32) {
432
+ auto get_configs_from_heuristics_fused (const cudnnHandle_t handle, std::string& opgraph_tag, const Tensor& x, const Tensor& y, const Tensor& w, const Tensor& z, const Tensor& b, const float alpha, const CacheKeyFused& key, const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, const bool deterministic, const bool allow_tf32) {
421
433
auto opGraph = build_opgraph_fused (handle, x, y, w, z, b, alpha, key, padding, stride, dilation);
434
+ opgraph_tag = opGraph.getTag ();
422
435
auto heuristic_mode = at::native::cudnnv8_use_heur_mode_b () ? CUDNN_HEUR_MODE_B : CUDNN_HEUR_MODE_INSTANT;
423
436
auto sources = get_generator_sources (CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR, x, deterministic, allow_tf32, heuristic_mode);
424
437
@@ -455,13 +468,16 @@ void try_plans_fused(cudnn_frontend::executionPlans_t& plans, const CacheKeyFuse
455
468
TORCH_CHECK (false , " FIND was unable to find an engine to execute this computation" );
456
469
}
457
470
458
- void try_configs (cudnn_frontend::EngineConfigList& configs, const CacheKey& key, const cudnnHandle_t handle, const Tensor& x, const Tensor& y, const Tensor& w) {
471
+ void try_configs (cudnn_frontend::EngineConfigList& configs, const std::string& opgraph_tag, const CacheKey& key, const cudnnHandle_t handle, const Tensor& x, const Tensor& y, const Tensor& w) {
459
472
for (auto & config : configs) {
460
473
try {
461
474
auto plan = cudnn_frontend::ExecutionPlanBuilder ()
462
475
.setHandle (handle)
463
- .setEngineConfig (config)
476
+ .setEngineConfig (config, opgraph_tag )
464
477
.build ();
478
+ if (plan_errata_exception (handle, plan.getTag ())) {
479
+ continue ;
480
+ }
465
481
run_conv_plan (handle, x, y, w, plan);
466
482
benchmark_cache.emplace (key, plan);
467
483
return ;
@@ -473,13 +489,16 @@ void try_configs(cudnn_frontend::EngineConfigList& configs, const CacheKey& key,
473
489
TORCH_CHECK (false , " GET was unable to find an engine to execute this computation" );
474
490
}
475
491
476
- void try_configs_fused (cudnn_frontend::EngineConfigList& configs, const CacheKeyFused& key, const cudnnHandle_t handle, const Tensor& x, const Tensor& y, const Tensor& w, const Tensor& z, const Tensor& b) {
492
+ void try_configs_fused (cudnn_frontend::EngineConfigList& configs, const std::string& opgraph_tag, const CacheKeyFused& key, const cudnnHandle_t handle, const Tensor& x, const Tensor& y, const Tensor& w, const Tensor& z, const Tensor& b) {
477
493
for (auto & config : configs) {
478
494
try {
479
495
auto plan = cudnn_frontend::ExecutionPlanBuilder ()
480
496
.setHandle (handle)
481
- .setEngineConfig (config)
497
+ .setEngineConfig (config, opgraph_tag )
482
498
.build ();
499
+ if (plan_errata_exception (handle, plan.getTag ())) {
500
+ continue ;
501
+ }
483
502
run_conv_plan_fused (handle, x, y, w, z, b, plan);
484
503
benchmark_cache_fused.emplace (key, plan);
485
504
return ;
@@ -496,7 +515,6 @@ void run_single_conv(const cudnnBackendDescriptorType_t operation,
496
515
const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, const int64_t groups,
497
516
const bool benchmark, const bool deterministic, const bool allow_tf32) {
498
517
cudnnHandle_t handle = getCudnnHandle ();
499
-
500
518
CacheKey key;
501
519
setCacheKey (key, operation, y, x, w, padding, stride, dilation, groups, deterministic, allow_tf32);
502
520
// TODO: is this thread safe if cache is updated? is pointer stale?
@@ -509,13 +527,14 @@ void run_single_conv(const cudnnBackendDescriptorType_t operation,
509
527
cudaGetLastError (); // clear CUDA error
510
528
}
511
529
}
512
-
513
530
if (!benchmark) {
531
+ std::string opgraph_tag; // extra data needed for errata filter
514
532
cudnn_frontend::EngineConfigList configs = get_configs_from_heuristics (handle, operation,
533
+ opgraph_tag,
515
534
x, y, w, key,
516
535
padding, stride, dilation,
517
536
deterministic, allow_tf32);
518
- try_configs (configs, key, handle, x, y, w);
537
+ try_configs (configs, opgraph_tag, key, handle, x, y, w);
519
538
} else {
520
539
cudnn_frontend::executionPlans_t plans = get_plans_from_find (handle, operation,
521
540
x, y, w, key,
@@ -544,13 +563,14 @@ void run_fused_conv(const Tensor& x, const Tensor& y, const Tensor& w, const Ten
544
563
cudaGetLastError (); // clear CUDA error
545
564
}
546
565
}
547
-
548
566
if (!benchmark) {
567
+ std::string opgraph_tag; // extra data needed for errata filter
549
568
cudnn_frontend::EngineConfigList configs = get_configs_from_heuristics_fused (handle,
569
+ opgraph_tag,
550
570
x, y, w, z, b, alpha, key,
551
571
padding, stride, dilation,
552
572
deterministic, allow_tf32);
553
- try_configs_fused (configs, key, handle, x, y, w, z, b);
573
+ try_configs_fused (configs, opgraph_tag, key, handle, x, y, w, z, b);
554
574
} else {
555
575
cudnn_frontend::executionPlans_t plans = get_plans_from_find_fused (handle,
556
576
x, y, w, z, b, alpha, key,
0 commit comments