Commit 2ad1406

support flash-attn at torch backend (#2257)
* support flash-attn at torch backend
* fix
* fix
* fix
* fix conflict
* fix conflict
* fix conflict
* fix conflict
* fix conflict
* fix conflict
* format
1 parent 127492a commit 2ad1406

File tree

3 files changed: +19 −1 lines changed

keras_hub/src/models/gemma/gemma_attention.py

Lines changed: 1 addition & 1 deletion
@@ -152,7 +152,7 @@ def _compute_attention(
                 attention_mask = ops.expand_dims(attention_mask, axis=1)
                 attention_mask = ops.cast(attention_mask, dtype="bool")
             # Only pass soft cap if needed as not all keras versions support.
-            if self.logit_soft_cap:
+            if self.logit_soft_cap is not None:
                 kwargs = {"attn_logits_soft_cap": self.logit_soft_cap}
             else:
                 kwargs = {}
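The functional change here is the guard on the optional soft-cap argument. As a quick illustration (a standalone sketch, not code from the commit; the helper names are made up), a truthiness check and an `is not None` check only disagree when `logit_soft_cap` is set to a falsy value such as `0.0`:

```python
# Sketch only: shows how the two guards treat a falsy-but-set soft cap.

def soft_cap_kwargs_truthy(logit_soft_cap):
    # Old-style check: a soft cap of 0.0 (or 0) is treated as "unset".
    return {"attn_logits_soft_cap": logit_soft_cap} if logit_soft_cap else {}

def soft_cap_kwargs_is_not_none(logit_soft_cap):
    # New-style check: only a literal None is treated as "unset".
    return (
        {"attn_logits_soft_cap": logit_soft_cap}
        if logit_soft_cap is not None
        else {}
    )

print(soft_cap_kwargs_truthy(0.0))       # {}
print(soft_cap_kwargs_is_not_none(0.0))  # {'attn_logits_soft_cap': 0.0}
print(soft_cap_kwargs_truthy(50.0))      # {'attn_logits_soft_cap': 50.0}
```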

keras_hub/src/models/qwen_moe/qwen_moe_attention.py

Lines changed: 1 addition & 0 deletions
@@ -67,6 +67,7 @@ def __init__(
         self.rope_scaling_factor = rope_scaling_factor
         self.use_sliding_window_attention = use_sliding_window_attention
         self.sliding_window_size = sliding_window_size
+        self.logit_soft_cap = None

     def build(self, inputs_shape):
         # Einsum variables:
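`QwenMoeAttention.__init__` now defines `logit_soft_cap` up front, so shared code that inspects the attribute (for example, a soft-cap guard like the Gemma one above) sees `None` instead of raising `AttributeError`. A toy sketch of the pattern, with made-up class names rather than keras_hub code:

```python
# Sketch only: default the attribute to None in __init__ so downstream
# checks can rely on it existing on every attention layer.

class AttentionWithoutDefault:
    def __init__(self):
        pass  # no logit_soft_cap attribute at all

class AttentionWithDefault:
    def __init__(self):
        self.logit_soft_cap = None  # defined, just "unset"

def uses_soft_cap(layer):
    # Shared helper that assumes the attribute exists.
    return layer.logit_soft_cap is not None

print(uses_soft_cap(AttentionWithDefault()))  # False
try:
    uses_soft_cap(AttentionWithoutDefault())
except AttributeError as exc:
    print(f"AttributeError: {exc}")
```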

keras_hub/src/utils/keras_utils.py

Lines changed: 17 additions & 0 deletions
@@ -71,6 +71,23 @@ def fused_attention_op_available():
             )
             return False
         return True
+    elif (
+        hasattr(keras.config, "is_flash_attention_enabled")
+        and keras.config.backend() == "torch"
+    ):
+        try:
+            from torch.backends.cuda import SDPAParams as SDPAParams
+            from torch.backends.cuda import (
+                can_use_flash_attention as can_use_flash_attention,
+            )
+        except ImportError:
+            logging.warning(
+                "Flash attention is not supported in your current PyTorch "
+                "version. Please update it by following the official guide: "
+                "https://pytorch.org/get-started/locally/"
+            )
+            return False
+        return True
     else:
         return False
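With this branch in place, `fused_attention_op_available()` reports support on the torch backend whenever the installed PyTorch exposes its flash-attention SDPA helpers. A hedged sketch of how a caller might use the probe, assuming a Keras 3 install that provides `keras.ops.dot_product_attention`; the `attend` helper and tensor shapes below are invented for illustration, not part of the commit:

```python
# Sketch only: prefer the backend's fused attention op when the probe says
# flash attention is available, otherwise fall back to explicit softmax
# attention written with keras.ops.
import numpy as np
from keras import ops

from keras_hub.src.utils.keras_utils import fused_attention_op_available

def attend(q, k, v, scale):
    if fused_attention_op_available():
        # Fused path: let the backend pick a flash kernel when it has one.
        return ops.dot_product_attention(q, k, v, scale=scale)
    # Fallback path: plain softmax attention over (batch, seq, heads, dim).
    logits = ops.einsum("btnh,bsnh->bnts", q, k) * scale
    probs = ops.softmax(logits, axis=-1)
    return ops.einsum("bnts,bsnh->btnh", probs, v)

q = k = v = ops.convert_to_tensor(
    np.random.rand(2, 16, 4, 32).astype("float32")
)
print(attend(q, k, v, scale=32 ** -0.5).shape)  # (2, 16, 4, 32)
```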
