Commit f0d5d01

Avoid torch compile graphbreak for older pytorch versions (Comfy-Org#9344)
It turns out torch.compile has some gaps in its support for the context-manager decorator syntax. I've sent patches to fix that in PyTorch, but the fix won't be available to the folks running older versions of PyTorch, hence this trivial patch.
1 parent ad19a06 commit f0d5d01
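The patch replaces the decorator form of `sdpa_kernel` with an explicit `with` block inside the function body, which older torch.compile versions can trace without a graph break. As a minimal sketch of why the two forms are behaviorally equivalent, using a stand-in context manager in place of torch's `sdpa_kernel` (all names here are illustrative, not from the commit):

```python
from contextlib import contextmanager

calls = []

@contextmanager
def kernel_ctx(priority):
    # Stand-in for torch.nn.attention.sdpa_kernel: records enter/exit
    # so we can observe that both call forms wrap the body identically.
    calls.append(("enter", priority))
    try:
        yield
    finally:
        calls.append(("exit", priority))

# Decorator form: contextmanager objects double as decorators, but
# older torch.compile versions graph-break on this syntax.
@kernel_ctx(["cudnn", "flash"])
def attn_decorated(x):
    return x * 2

# Explicit `with` form (what the patch switches to): same semantics,
# traced cleanly by torch.compile on pytorch < 2.9.
def attn_with(x):
    with kernel_ctx(["cudnn", "flash"]):
        return x * 2
```

Both functions enter and exit the context manager around the body in the same way; only the syntax differs.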

File tree

1 file changed: 4 additions, 2 deletions


comfy/ops.py

Lines changed: 4 additions & 2 deletions
@@ -41,9 +41,11 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs):
             SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
 
-            @sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True)
             def scaled_dot_product_attention(q, k, v, *args, **kwargs):
-                return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
+                # Use this (rather than the decorator syntax) to eliminate graph
+                # break for pytorch < 2.9
+                with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
+                    return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
         except (ModuleNotFoundError, TypeError):
             logging.warning("Could not set sdpa backend priority.")
4951
