@@ -32,20 +32,21 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs):
 try:
     if torch.cuda.is_available():
         from torch.nn.attention import SDPBackend, sdpa_kernel
-
-        SDPA_BACKEND_PRIORITY = [
-            SDPBackend.FLASH_ATTENTION,
-            SDPBackend.EFFICIENT_ATTENTION,
-            SDPBackend.MATH,
-        ]
-
-        SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
-
-        def scaled_dot_product_attention(q, k, v, *args, **kwargs):
-            # Use this (rather than the decorator syntax) to eliminate graph
-            # break for pytorch < 2.9
-            with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
-                return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
+        import inspect
+        if "set_priority" in inspect.signature(sdpa_kernel).parameters:
+            SDPA_BACKEND_PRIORITY = [
+                SDPBackend.FLASH_ATTENTION,
+                SDPBackend.EFFICIENT_ATTENTION,
+                SDPBackend.MATH,
+            ]
+
+            SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
+
+            def scaled_dot_product_attention(q, k, v, *args, **kwargs):
+                with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
+                    return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
+        else:
+            logging.warning("Torch version too old to set sdpa backend priority.")
 except (ModuleNotFoundError, TypeError):
     logging.warning("Could not set sdpa backend priority.")
 
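For context, a minimal sketch of the feature-detection pattern this change relies on: inspect.signature is used to check whether the installed torch build exposes the set_priority keyword on sdpa_kernel before passing it, since that keyword only exists in newer PyTorch releases. The helper name supports_kwarg below is hypothetical and not part of the PR.

    import inspect

    def supports_kwarg(fn, name):
        # Hypothetical helper (not in the PR): True when callable `fn`
        # accepts a keyword argument called `name`.
        try:
            return name in inspect.signature(fn).parameters
        except (ValueError, TypeError):
            # Some C-implemented callables expose no introspectable signature.
            return False

    # Usage mirroring the guarded call in the diff:
    # if supports_kwarg(sdpa_kernel, "set_priority"):
    #     with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
    #         ...
    # else:
    #     logging.warning("Torch version too old to set sdpa backend priority.")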