Skip to content

Commit fe2070b

Browse files
committed
Updated test_decorators to include some heuristic backend
1 parent 2f15fc4 commit fe2070b

File tree

1 file changed

+21
-1
lines changed

1 file changed

+21
-1
lines changed

tests/utils/test_decorators.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,27 @@ def _cutlass_check(x, backend):
344344
def _cudnn_check(x, backend):
345345
return x.shape[0] > 5
346346

347-
@backend_requirement({"cutlass": _cutlass_check, "cudnn": _cudnn_check})
347+
# When using an auto backend, some heuristic function must exist
348+
def _heuristic_func(suitable_backends, x, backend):
349+
candidate_backends = None
350+
if x.shape[0] > 5:
351+
candidate_backends = ["cudnn", "cutlass"]
352+
else:
353+
candidate_backends = ["cutlass", "cudnn"]
354+
355+
heuristic_backends = []
356+
for backend in candidate_backends:
357+
if backend in suitable_backends:
358+
heuristic_backends.append(backend)
359+
return heuristic_backends
360+
361+
@backend_requirement(
362+
backend_checks={
363+
"cutlass": _cutlass_check,
364+
"cudnn": _cudnn_check,
365+
},
366+
heuristic_func=_heuristic_func,
367+
)
348368
def my_kernel(x, backend="auto"):
349369
backends = my_kernel.suitable_auto_backends
350370
if x.shape[0] > 5:

0 commit comments

Comments
 (0)