Skip to content

Commit a50c32d

Browse files
Disable sage attention on ace step 1.5 (Comfy-Org#12297)
1 parent 6125b80 commit a50c32d

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

comfy/ldm/ace/ace_step15.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def forward(
183183
else:
184184
attn_bias = window_bias
185185

186-
attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True)
186+
attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True, low_precision_attention=False)
187187
attn_output = self.o_proj(attn_output)
188188

189189
return attn_output

comfy/ldm/modules/attention.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,9 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
524524

525525
@wrap_attn
526526
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
527+
if kwargs.get("low_precision_attention", True) is False:
528+
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
529+
527530
exception_fallback = False
528531
if skip_reshape:
529532
b, _, _, dim_head = q.shape

0 commit comments

Comments
 (0)