Commit 4e5c230

Fix last commit not working on older pytorch. (Comfy-Org#9346)
1 parent: f0d5d01

File tree

1 file changed (+15 -14 lines)


comfy/ops.py

Lines changed: 15 additions & 14 deletions
@@ -32,20 +32,21 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs):
 try:
     if torch.cuda.is_available():
         from torch.nn.attention import SDPBackend, sdpa_kernel
-
-        SDPA_BACKEND_PRIORITY = [
-            SDPBackend.FLASH_ATTENTION,
-            SDPBackend.EFFICIENT_ATTENTION,
-            SDPBackend.MATH,
-        ]
-
-        SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
-
-        def scaled_dot_product_attention(q, k, v, *args, **kwargs):
-            # Use this (rather than the decorator syntax) to eliminate graph
-            # break for pytorch < 2.9
-            with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
-                return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
+        import inspect
+        if "set_priority" in inspect.signature(sdpa_kernel).parameters:
+            SDPA_BACKEND_PRIORITY = [
+                SDPBackend.FLASH_ATTENTION,
+                SDPBackend.EFFICIENT_ATTENTION,
+                SDPBackend.MATH,
+            ]
+
+            SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
+
+            def scaled_dot_product_attention(q, k, v, *args, **kwargs):
+                with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
+                    return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
+        else:
+            logging.warning("Torch version too old to set sdpa backend priority.")
 except (ModuleNotFoundError, TypeError):
     logging.warning("Could not set sdpa backend priority.")
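The fix feature-detects rather than parses version strings: it checks whether the installed torch's sdpa_kernel accepts the set_priority keyword (only newer PyTorch builds do) before installing the priority-ordered wrapper. Below is a minimal standalone sketch of that pattern, mirroring the commit's logic; the make_sdpa helper name and the returned sdpa wrapper are hypothetical, not part of the commit.

import inspect
import logging

import torch

def make_sdpa():
    # Hypothetical helper: fall back to the stock implementation unless
    # the running torch supports backend priority ordering.
    sdpa = torch.nn.functional.scaled_dot_product_attention
    try:
        if torch.cuda.is_available():
            from torch.nn.attention import SDPBackend, sdpa_kernel

            # Feature-detect instead of comparing version strings: only
            # newer torch builds accept the `set_priority` keyword.
            if "set_priority" in inspect.signature(sdpa_kernel).parameters:
                priority = [
                    SDPBackend.CUDNN_ATTENTION,
                    SDPBackend.FLASH_ATTENTION,
                    SDPBackend.EFFICIENT_ATTENTION,
                    SDPBackend.MATH,
                ]

                def sdpa(q, k, v, *args, **kwargs):
                    # With set_priority=True, backends earlier in the
                    # list are preferred first.
                    with sdpa_kernel(priority, set_priority=True):
                        return torch.nn.functional.scaled_dot_product_attention(
                            q, k, v, *args, **kwargs
                        )
            else:
                logging.warning("Torch version too old to set sdpa backend priority.")
    except (ModuleNotFoundError, TypeError):
        logging.warning("Could not set sdpa backend priority.")
    return sdpa

Called once at import time, make_sdpa() yields a drop-in replacement for torch.nn.functional.scaled_dot_product_attention that prefers cuDNN attention on torch builds supporting backend priorities and degrades to the plain call otherwise. On older torch versions, sdpa_kernel treats its backend list as an unordered allow-list, which is why the priority wrapper is skipped there rather than called with an unsupported keyword.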
