File tree Expand file tree Collapse file tree 1 file changed +21
-1
lines changed Expand file tree Collapse file tree 1 file changed +21
-1
lines changed Original file line number Diff line number Diff line change @@ -344,7 +344,27 @@ def _cutlass_check(x, backend):
344344 def _cudnn_check (x , backend ):
345345 return x .shape [0 ] > 5
346346
347- @backend_requirement ({"cutlass" : _cutlass_check , "cudnn" : _cudnn_check })
347+ # When using an auto backend, some heuristic function must exist
348+ def _heuristic_func (suitable_backends , x , backend ):
349+ candidate_backends = None
350+ if x .shape [0 ] > 5 :
351+ candidate_backends = ["cudnn" , "cutlass" ]
352+ else :
353+ candidate_backends = ["cutlass" , "cudnn" ]
354+
355+ heuristic_backends = []
356+ for backend in candidate_backends :
357+ if backend in suitable_backends :
358+ heuristic_backends .append (backend )
359+ return heuristic_backends
360+
361+ @backend_requirement (
362+ backend_checks = {
363+ "cutlass" : _cutlass_check ,
364+ "cudnn" : _cudnn_check ,
365+ },
366+ heuristic_func = _heuristic_func ,
367+ )
348368 def my_kernel (x , backend = "auto" ):
349369 backends = my_kernel .suitable_auto_backends
350370 if x .shape [0 ] > 5 :
You can’t perform that action at this time.
0 commit comments