Skip to content

Commit 26dd7eb

Browse files
Fix ace step nan issue on some hardware/pytorch configs. (Comfy-Org#12289)
1 parent e77b34d commit 26dd7eb

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

comfy/text_encoders/llama.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -651,10 +651,10 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed
         mask = None
         if attention_mask is not None:
             mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1])
-            mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min)
+            mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min / 4)

         if seq_len > 1:
-            causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(torch.finfo(x.dtype).min).triu_(1)
+            causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(torch.finfo(x.dtype).min / 4).triu_(1)
             if mask is not None:
                 mask += causal_mask
             else:

Comments: 0