
Commit 9df8792

Make last PR not crash comfy on old pytorch. (Comfy-Org#9324)
1 parent 3da5a07 commit 9df8792

File tree

4 files changed (+27 -17 lines)


comfy/ldm/hunyuan3d/vae.py

Lines changed: 1 addition & 1 deletion

@@ -178,7 +178,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 class CrossAttentionProcessor:
     def __call__(self, attn, q, k, v):
-        out = ops.scaled_dot_product_attention(q, k, v)
+        out = comfy.ops.scaled_dot_product_attention(q, k, v)
         return out

comfy/ldm/modules/attention.py

Lines changed: 2 additions & 2 deletions

@@ -448,7 +448,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
             mask = mask.unsqueeze(1)
 
     if SDP_BATCH_LIMIT >= b:
-        out = ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+        out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
         if not skip_output_reshape:
             out = (
                 out.transpose(1, 2).reshape(b, -1, heads * dim_head)
@@ -461,7 +461,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
                 if mask.shape[0] > 1:
                     m = mask[i : i + SDP_BATCH_LIMIT]
 
-            out[i : i + SDP_BATCH_LIMIT] = ops.scaled_dot_product_attention(
+            out[i : i + SDP_BATCH_LIMIT] = comfy.ops.scaled_dot_product_attention(
                 q[i : i + SDP_BATCH_LIMIT],
                 k[i : i + SDP_BATCH_LIMIT],
                 v[i : i + SDP_BATCH_LIMIT],
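
The second hunk sits inside the batched path of attention_pytorch: when the batch size exceeds SDP_BATCH_LIMIT, the attention call is split into batch chunks and each chunk's result is written into a preallocated output tensor. Below is a minimal standalone sketch of that chunking pattern; the threshold value, the helper name chunked_sdpa, and the direct F.scaled_dot_product_attention calls are illustrative stand-ins, not ComfyUI's actual code.

import torch
import torch.nn.functional as F

SDP_BATCH_LIMIT = 2 ** 15  # illustrative threshold, not ComfyUI's actual value

def chunked_sdpa(q, k, v, mask=None):
    # q, k, v: (batch, heads, tokens, dim_head)
    b = q.shape[0]
    if b <= SDP_BATCH_LIMIT:
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

    out = torch.empty_like(q)
    for i in range(0, b, SDP_BATCH_LIMIT):
        m = mask
        if mask is not None and mask.shape[0] > 1:
            # Per-sample masks are sliced along with the batch chunk.
            m = mask[i : i + SDP_BATCH_LIMIT]
        out[i : i + SDP_BATCH_LIMIT] = F.scaled_dot_product_attention(
            q[i : i + SDP_BATCH_LIMIT],
            k[i : i + SDP_BATCH_LIMIT],
            v[i : i + SDP_BATCH_LIMIT],
            attn_mask=m,
        )
    return out

Chunking trades one large SDPA launch for several smaller ones, keeping the sliced mask and intermediate buffers sized per chunk rather than per full batch.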

comfy/ldm/modules/diffusionmodules/model.py

Lines changed: 1 addition & 1 deletion

@@ -285,7 +285,7 @@ def pytorch_attention(q, k, v):
     )
 
     try:
-        out = ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
+        out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
         out = out.transpose(2, 3).reshape(orig_shape)
     except model_management.OOM_EXCEPTION:
         logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
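
The context lines show the pattern this call site relies on: run the fused SDPA fast path inside a try block and drop to a slower, lower-memory path when it raises an out-of-memory error. A rough sketch of that idea follows, with torch.cuda.OutOfMemoryError standing in for ComfyUI's model_management.OOM_EXCEPTION and a simple per-head loop standing in for its slice-attention fallback.

import logging
import torch
import torch.nn.functional as F

def attention_with_oom_fallback(q, k, v):
    # Fast path: one fused scaled-dot-product attention call over all heads.
    try:
        return F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
    except torch.cuda.OutOfMemoryError:
        logging.warning("scaled_dot_product_attention OOMed: falling back to a sliced path")
        # Lower-memory fallback: attend one head at a time instead of all at once.
        out = torch.empty_like(q)
        for h in range(q.shape[1]):
            out[:, h : h + 1] = F.scaled_dot_product_attention(
                q[:, h : h + 1], k[:, h : h + 1], v[:, h : h + 1]
            )
        return out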

comfy/ops.py

Lines changed: 23 additions & 13 deletions

@@ -23,17 +23,31 @@
 import comfy.float
 import comfy.rmsnorm
 import contextlib
-from torch.nn.attention import SDPBackend, sdpa_kernel
 
-cast_to = comfy.model_management.cast_to #TODO: remove once no more references
 
-SDPA_BACKEND_PRIORITY = [
-    SDPBackend.FLASH_ATTENTION,
-    SDPBackend.EFFICIENT_ATTENTION,
-    SDPBackend.MATH,
-]
-if torch.cuda.is_available():
-    SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
+def scaled_dot_product_attention(q, k, v, *args, **kwargs):
+    return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
+
+
+try:
+    if torch.cuda.is_available():
+        from torch.nn.attention import SDPBackend, sdpa_kernel
+
+        SDPA_BACKEND_PRIORITY = [
+            SDPBackend.FLASH_ATTENTION,
+            SDPBackend.EFFICIENT_ATTENTION,
+            SDPBackend.MATH,
+        ]
+
+        SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
+
+        @sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True)
+        def scaled_dot_product_attention(q, k, v, *args, **kwargs):
+            return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
+except (ModuleNotFoundError, TypeError):
+    logging.warning("Could not set sdpa backend priority.")
+
+cast_to = comfy.model_management.cast_to #TODO: remove once no more references
 
 def cast_to_input(weight, input, non_blocking=False, copy=True):
     return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)

@@ -258,10 +272,6 @@ def conv_nd(s, dims, *args, **kwargs):
         else:
             raise ValueError(f"unsupported dimensions: {dims}")
 
-    @staticmethod
-    @sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True)
-    def scaled_dot_product_attention(q, k, v, *args, **kwargs):
-        return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
 
 class manual_cast(disable_weight_init):
     class Linear(disable_weight_init.Linear):
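
The comfy/ops.py hunks carry the actual fix. Previously, torch.nn.attention was imported unconditionally at module load and scaled_dot_product_attention lived as a decorated @staticmethod using sdpa_kernel(..., set_priority=True); on older PyTorch builds that import raises ModuleNotFoundError (torch.nn.attention does not exist) or the set_priority keyword raises TypeError, which crashed Comfy at startup. The new code defines a plain module-level wrapper first and only then tries to swap in the backend-prioritized version, logging a warning when that fails. A stripped-down sketch of the same guard, independent of ComfyUI (the backend list ordering here is illustrative):

import logging
import torch

def scaled_dot_product_attention(q, k, v, *args, **kwargs):
    # Safe default: available on any reasonably recent PyTorch.
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)

try:
    if torch.cuda.is_available():
        # torch.nn.attention only exists on newer PyTorch builds; on old ones this
        # import raises ModuleNotFoundError and the plain wrapper above is kept.
        from torch.nn.attention import SDPBackend, sdpa_kernel

        backends = [
            SDPBackend.CUDNN_ATTENTION,
            SDPBackend.FLASH_ATTENTION,
            SDPBackend.EFFICIENT_ATTENTION,
            SDPBackend.MATH,
        ]

        # set_priority is a newer keyword argument; older sdpa_kernel signatures
        # reject it with TypeError, which the except below also swallows.
        @sdpa_kernel(backends=backends, set_priority=True)
        def scaled_dot_product_attention(q, k, v, *args, **kwargs):
            return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
except (ModuleNotFoundError, TypeError):
    logging.warning("Could not set sdpa backend priority.")

The three call-site files in this commit reference comfy.ops.scaled_dot_product_attention, so they transparently pick up whichever definition survived the try/except.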
