Clear output of Torch SDPA for masked pieces
Since Torch 2.1, the Torch memory-efficient SDPA GPU kernel returns NaN
for pieces that are completely masked out. This leads to NaN propagation
in the next attention layer, because masked pieces get an attention of
zero, but zero times NaN is still NaN.
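
As a quick illustration of the failure mode (not part of the commit, and independent of the exact SDPA backend): softmax over a row of logits that is entirely -inf is undefined, and a zero attention weight in the next layer does not neutralize the resulting NaN.

import torch

# A fully masked row of attention logits: every entry is -inf.
logits = torch.full((3,), float("-inf"))
print(torch.softmax(logits, dim=-1))  # tensor([nan, nan, nan])

# Zero attention in the next layer does not clear the NaN:
print(0.0 * float("nan"))  # nan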

We fix this by setting the representations of masked tokens to zero,
clearing out any NaNs.

We currently rely on the query dimension of the mask being singular, but
in the future we should probably redesign the `AttentionMask` class to
account for the differences between attention masks and causal masks.
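
A minimal sketch of the clearing step (outside the diff, using plain tensors rather than the `AttentionMask` class; shapes follow the comment in the patch, with a boolean mask of shape [batch, 1, 1, key_len] whose query dimension is singular):

import torch

# Attention output of shape [batch, heads, seq_len, width]; the second
# position is fully masked and came back from the kernel as NaN.
attn_values = torch.randn(1, 1, 2, 4)
attn_values[:, :, 1] = float("nan")

# Boolean attention mask of shape [batch, 1, 1, key_len].
bool_mask = torch.tensor([[[[True, False]]]])

# Transposing the last two dimensions lets the mask broadcast over the
# sequence dimension of the values; masked positions are zeroed out.
cleared = torch.where(bool_mask.transpose(-1, -2), attn_values, 0.0)
print(cleared[0, 0, 1])  # tensor([0., 0., 0., 0.])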
danieldk committed Feb 8, 2024
1 parent dfe6d96 commit ade1497
Showing 1 changed file with 42 additions and 12 deletions.
54 changes: 42 additions & 12 deletions curated_transformers/layers/attention.py
@@ -646,6 +646,7 @@ def forward(
key: Tensor,
value: Tensor,
attention_mask: AttentionMask,
use_causal_mask: bool,
) -> Tensor:
"""
Apply attention scores to the given key, query and value.
@@ -669,6 +670,8 @@ def forward(
Attention mask. Sequence elements for which the corresponding mask
element is set to ``False`` are ignored in attention.
:param use_causal_mask:
Mask out succeeding sequence elements when ``True``.
:returns:
Attention values.
@@ -712,27 +715,58 @@ def forward(
key: Tensor,
value: Tensor,
attention_mask: AttentionMask,
use_causal_mask: bool,
) -> Tensor:
combined_mask = attention_mask
if use_causal_mask:
causal_mask = create_causal_mask(query, key)
combined_mask = combined_mask.merge_mask(causal_mask)

if _TORCH_SDP.get():
attn_mask = attention_mask.logit_mask(query.dtype)
logit_mask = combined_mask.logit_mask(query.dtype)

# Add AliBi to the logit mask
if self.linear_biases is not None:
biases = self.linear_biases.calculate_biases(key.size(-2)).to(
dtype=query.dtype, device=query.device
)
bool_mask = attention_mask.bool_mask
attn_mask = torch.where(bool_mask, biases, attn_mask)
bool_mask = combined_mask.bool_mask
logit_mask = torch.where(bool_mask, biases, logit_mask)

# We can't pass a bool mask, because it is currently broken:
# https://github.com/pytorch/pytorch/issues/103749
return F.scaled_dot_product_attention(
attn_values = F.scaled_dot_product_attention(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
attn_mask=logit_mask,
dropout_p=self.dropout_prob if self.training else 0.0,
)

# Torch SDP returns NaNs for pieces where everything is masked out.
# These errors propagate, because zero attention times NaN is
# NaN. Since the representations of these tokens don't matter
# anyway, we will just zero them out.
#
# One issue is that values have shape
#
# [batch_len, n_heads, key_len, hidden_size]
#
# whereas masks have the shape
#
# [batch_len, 1, query_len, key_len]
#
# So we can only do this when we have attention masks where
# the query length is not specified:
#
# [batch_len, 1, 1, key_len]
#
# Doing this properly requires a redesign of our AttentionMask
# class.
assert attention_mask.bool_mask.size(-2) == 1
return torch.where(
attention_mask.bool_mask.transpose(-1, -2), attn_values, 0.0
)
else:
width = key.shape[-1]
attn_scores = query @ key.transpose(-2, -1)
@@ -741,7 +775,7 @@ def forward(
if self.linear_biases is not None:
attn_scores = self.linear_biases(attention_scores=attn_scores)

attn_scores = attention_mask.apply_logit_mask(attn_scores)
attn_scores = combined_mask.apply_logit_mask(attn_scores)
attn_weights = attn_scores.softmax(dim=-1)
attn_values = self.dropout(attn_weights @ value)

@@ -903,16 +937,12 @@ def forward(
key = torch.cat([cache_k, key], dim=-2)
value = torch.cat([cache_v, value], dim=-2)

combined_mask = attention_mask
if use_causal_mask:
causal_mask = create_causal_mask(query, key)
combined_mask = combined_mask.merge_mask(causal_mask)

attn = self.attention_scorer(
query=query,
key=key,
value=value,
attention_mask=combined_mask,
attention_mask=attention_mask,
use_causal_mask=use_causal_mask,
)

attn = combine_heads(attn)
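
For context on the mask merging that the scorer now performs internally, here is a conceptual sketch with plain boolean tensors (it does not use the library's `AttentionMask`/`create_causal_mask` API): the causal mask is a lower-triangular matrix over positions, combined with the padding mask by logical AND.

import torch

batch, query_len, key_len = 1, 4, 4

# Padding mask of shape [batch, 1, 1, key_len]; the last piece is padding.
attention_mask = torch.tensor([[[[True, True, True, False]]]])

# Causal mask of shape [1, 1, query_len, key_len]; each position may only
# attend to itself and to preceding positions.
causal_mask = torch.tril(
    torch.ones(query_len, key_len, dtype=torch.bool)
).view(1, 1, query_len, key_len)

# Merged mask, broadcast to [batch, 1, query_len, key_len]: a position is
# attended to only if both masks allow it.
combined_mask = attention_mask & causal_mask
print(combined_mask[0, 0])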
