support flash-attn at torch backend #2257

Open. Wants to merge 11 commits into base: master.
keras_hub/src/models/gemma/gemma_attention.py (2 changes: 1 addition & 1 deletion)
@@ -152,7 +152,7 @@ def _compute_attention(
attention_mask = ops.expand_dims(attention_mask, axis=1)
attention_mask = ops.cast(attention_mask, dtype="bool")
# Only pass soft cap if needed as not all keras versions support.
if self.logit_soft_cap:
if self.logit_soft_cap is not None:
kwargs = {"attn_logits_soft_cap": self.logit_soft_cap}
else:
kwargs = {}
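
Note on the Gemma change above: switching from a truthiness test to `is not None` means a soft cap explicitly set to `0.0` is no longer silently dropped; only an unset (`None`) soft cap skips the kwarg. A minimal, standalone sketch of the difference (the helper name is hypothetical, not part of the PR):

```python
def soft_cap_kwargs(logit_soft_cap):
    # Truthiness check: an explicit soft cap of 0.0 is silently dropped.
    truthy = {"attn_logits_soft_cap": logit_soft_cap} if logit_soft_cap else {}
    # Identity check: only a missing (None) soft cap skips the kwarg.
    explicit = (
        {"attn_logits_soft_cap": logit_soft_cap}
        if logit_soft_cap is not None
        else {}
    )
    return truthy, explicit

print(soft_cap_kwargs(0.0))   # ({}, {'attn_logits_soft_cap': 0.0})
print(soft_cap_kwargs(None))  # ({}, {})
print(soft_cap_kwargs(50.0))  # both include the kwarg
```
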
keras_hub/src/models/mixtral/mixtral_attention.py (73 changes: 31 additions & 42 deletions)
@@ -27,19 +27,19 @@ def __init__(
**kwargs,
):
super().__init__(**kwargs)
self._num_query_heads = num_query_heads
self._num_key_value_heads = num_key_value_heads
self._sliding_window = sliding_window
self._dropout = dropout
self.num_query_heads = num_query_heads
self.num_key_value_heads = num_key_value_heads
self.sliding_window = sliding_window
self.dropout = dropout

self._num_key_value_groups = num_query_heads // num_key_value_heads
self._rope_max_wavelength = rope_max_wavelength
self.num_key_value_groups = num_query_heads // num_key_value_heads
self.rope_max_wavelength = rope_max_wavelength

self._kernel_initializer = keras.initializers.get(
clone_initializer(kernel_initializer)
)

self._rope_scaling_factor = rope_scaling_factor
self.rope_scaling_factor = rope_scaling_factor

def build(self, inputs_shape):
# Einsum variables:
@@ -51,12 +51,12 @@ def build(self, inputs_shape):
# v = num key/value heads
# h = head dim
self._hidden_dim = inputs_shape[-1]
self._head_dim = self._hidden_dim // self._num_query_heads
self._head_dim = self._hidden_dim // self.num_query_heads
self._inv_norm_factor = 1.0 / math.sqrt(self._head_dim)

self.query_dense = keras.layers.EinsumDense(
equation="bqm,muh->bquh",
output_shape=(None, self._num_query_heads, self._head_dim),
output_shape=(None, self.num_query_heads, self._head_dim),
kernel_initializer=self._kernel_initializer,
dtype=self.dtype_policy,
name="query",
@@ -67,7 +67,7 @@
equation="bkm,mvh->bkvh",
output_shape=(
None,
self._num_key_value_heads,
self.num_key_value_heads,
self._head_dim,
),
kernel_initializer=self._kernel_initializer,
@@ -80,7 +80,7 @@
equation="bkm,mvh->bkvh",
output_shape=(
None,
self._num_key_value_heads,
self.num_key_value_heads,
self._head_dim,
),
kernel_initializer=self._kernel_initializer,
@@ -89,31 +89,31 @@
)
self.value_dense.build(inputs_shape)

self._softmax = keras.layers.Softmax(
self.softmax = keras.layers.Softmax(
axis=-1,
dtype="float32",
name="attention_softmax",
)

self._dropout_layer = keras.layers.Dropout(
rate=self._dropout,
self.dropout_layer = keras.layers.Dropout(
rate=self.dropout,
dtype=self.dtype_policy,
)

self._output_dense = keras.layers.EinsumDense(
self.output_dense = keras.layers.EinsumDense(
equation="bquh,uhm->bqm",
output_shape=(None, self._hidden_dim),
kernel_initializer=self._kernel_initializer,
dtype=self.dtype_policy,
name="attention_output",
)
self._output_dense.build(
(None, None, self._num_query_heads, self._head_dim)
self.output_dense.build(
(None, None, self.num_query_heads, self._head_dim)
)

self.rotary_embedding_layer = RotaryEmbedding(
max_wavelength=self._rope_max_wavelength,
scaling_factor=self._rope_scaling_factor,
max_wavelength=self.rope_max_wavelength,
scaling_factor=self.rope_scaling_factor,
dtype=self.dtype_policy,
)

@@ -168,39 +168,34 @@ def _compute_key_value(x):

# [batch_shape, seq_len, num_key_value_heads, head_dim]
# -> [batch_shape, seq_len, num_heads, head_dim]
key = ops.repeat(key, repeats=self._num_key_value_groups, axis=2)
value = ops.repeat(value, repeats=self._num_key_value_groups, axis=2)
key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)

attention_output = self._compute_attention(
query, key, value, attention_mask
)

attention_output = self._dropout_layer(
attention_output = self.dropout_layer(
attention_output, training=training
)

attention_output = self._output_dense(attention_output)
attention_output = self.output_dense(attention_output)

if cache is not None:
return attention_output, cache
return attention_output

def _masked_softmax(self, attention_scores, attention_mask=None):
if attention_mask is not None:
return self._softmax(
attention_scores, attention_mask[:, None, :, :]
)
return self._softmax(attention_scores)
return self.softmax(attention_scores, attention_mask[:, None, :, :])
return self.softmax(attention_scores)

def _use_fused_attention_op(self):
if not fused_attention_op_available():
return False
if self.dropout > 0.0:
return False
if running_on_gpu():
# GPU never supports softcap in the fused op.
if self.logit_soft_cap is not None:
return False
return gpu_supports_fused_attention_op()
elif running_on_tpu():
# TPU supports softcap with on keras >= 3.10.
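
For context on the key/value repeat earlier in this hunk: Mixtral uses grouped-query attention, so each key/value head is repeated `num_key_value_groups` times to line up with its group of query heads before the attention op. A minimal sketch with toy, hypothetical shapes:

```python
import numpy as np
from keras import ops

# Toy grouped-query attention shapes (hypothetical numbers):
# 8 query heads share 2 key/value heads, i.e. 4 query heads per KV head.
batch, seq_len, head_dim = 1, 5, 16
num_query_heads, num_key_value_heads = 8, 2
num_key_value_groups = num_query_heads // num_key_value_heads  # 4

key = ops.convert_to_tensor(
    np.random.rand(batch, seq_len, num_key_value_heads, head_dim).astype(
        "float32"
    )
)
# Repeat each KV head so it lines up with its group of query heads,
# as the layer does before calling the (fused) attention op.
key = ops.repeat(key, repeats=num_key_value_groups, axis=2)
print(ops.shape(key))  # (1, 5, 8, 16)
```
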
@@ -215,18 +210,12 @@ def _compute_attention(self, query, key, value, attention_mask=None):
attention_mask = ops.expand_dims(attention_mask, axis=1)
attention_mask = ops.cast(attention_mask, dtype="bool")

if self.logit_soft_cap:
kwargs = {"attn_logits_soft_cap": self.logit_soft_cap}
else:
kwargs = {}

attention_output = ops.dot_product_attention(
query,
key,
value,
mask=attention_mask,
scale=self._inv_norm_factor,
**kwargs,
)
return attention_output

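The `_compute_attention` hunk above drops the soft-cap kwargs entirely for Mixtral, which has no logit soft cap, and relies on the boolean mask plus the `1/sqrt(head_dim)` scale. A minimal sketch of that call pattern against the public `keras.ops` API (shapes and data are hypothetical, not the PR's code):

```python
import keras
from keras import ops

# Hypothetical shapes: (batch, seq_len, num_heads, head_dim).
batch, seq_len, num_heads, head_dim = 2, 8, 4, 16
query = keras.random.uniform((batch, seq_len, num_heads, head_dim))
key = keras.random.uniform((batch, seq_len, num_heads, head_dim))
value = keras.random.uniform((batch, seq_len, num_heads, head_dim))

# A (batch, q_len, kv_len) causal mask, as the layer receives it.
causal = ops.tril(ops.ones((seq_len, seq_len), dtype="int32"))
attention_mask = ops.broadcast_to(causal, (batch, seq_len, seq_len))

# Mirror the hunk: add a head axis, cast to bool, pass the scale explicitly.
attention_mask = ops.expand_dims(attention_mask, axis=1)
attention_mask = ops.cast(attention_mask, dtype="bool")
output = ops.dot_product_attention(
    query, key, value, mask=attention_mask, scale=1.0 / head_dim**0.5
)
print(ops.shape(output))  # (2, 8, 4, 16)
```
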
@@ -249,15 +238,15 @@ def get_config(self):
config = super().get_config()
config.update(
{
"num_query_heads": self._num_query_heads,
"num_key_value_heads": self._num_key_value_heads,
"rope_max_wavelength": self._rope_max_wavelength,
"rope_scaling_factor": self._rope_scaling_factor,
"num_query_heads": self.num_query_heads,
"num_key_value_heads": self.num_key_value_heads,
"rope_max_wavelength": self.rope_max_wavelength,
"rope_scaling_factor": self.rope_scaling_factor,
"kernel_initializer": keras.initializers.serialize(
self._kernel_initializer
),
"sliding_window": self._sliding_window,
"dropout": self._dropout,
"sliding_window": self.sliding_window,
"dropout": self.dropout,
}
)
return config
keras_hub/src/models/qwen_moe/qwen_moe_attention.py (1 change: 1 addition & 0 deletions)
@@ -67,6 +67,7 @@ def __init__(
self.rope_scaling_factor = rope_scaling_factor
self.use_sliding_window_attention = use_sliding_window_attention
self.sliding_window_size = sliding_window_size
self.logit_soft_cap = None

def build(self, inputs_shape):
# Einsum variables:
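
The one-line Qwen-MoE change gives the layer an explicit `logit_soft_cap = None` default, presumably so that checks of the form `self.logit_soft_cap is not None` on the fused-attention code path can run without an `AttributeError` even though the model never sets a soft cap. A tiny, self-contained illustration (class and function names hypothetical):

```python
class WithDefault:
    def __init__(self):
        self.logit_soft_cap = None  # explicit "no soft cap"


class WithoutDefault:
    pass


def wants_soft_cap(layer):
    # Mirrors checks like `if self.logit_soft_cap is not None:` used on the
    # fused-attention code paths.
    return getattr(layer, "logit_soft_cap", None) is not None


print(wants_soft_cap(WithDefault()))     # False, and no AttributeError
print(wants_soft_cap(WithoutDefault()))  # also False, but only thanks to getattr
```
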
keras_hub/src/utils/keras_utils.py (17 changes: 17 additions & 0 deletions)
@@ -71,6 +71,23 @@ def fused_attention_op_available():
)
return False
return True
elif (
hasattr(keras.config, "is_flash_attention_enabled")
and keras.config.backend() == "torch"
):
try:
from torch.backends.cuda import SDPAParams as SDPAParams
from torch.backends.cuda import (
can_use_flash_attention as can_use_flash_attention,
)
except ImportError:
logging.warning(
"Flash attention is not supported in your current PyTorch "
"version. Please update it by following the official guide: "
"https://pytorch.org/get-started/locally/"
)
return False
return True
else:
return False

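Taken together with the model changes, the new branch in `fused_attention_op_available()` lets the torch backend report fused/flash attention support whenever `torch.backends.cuda` exposes the SDPA helpers. A hedged usage sketch (toy shapes; the unfused fallback is illustrative, not the PR's code):

```python
import keras
from keras import ops

from keras_hub.src.utils.keras_utils import fused_attention_op_available

# Hypothetical toy shapes: (batch, seq_len, num_heads, head_dim).
query = keras.random.uniform((2, 16, 4, 32))
key = keras.random.uniform((2, 16, 4, 32))
value = keras.random.uniform((2, 16, 4, 32))

if fused_attention_op_available():
    # On the torch backend the fused path goes through torch SDPA, which can
    # dispatch to flash attention when hardware, dtype, and shapes allow it.
    output = ops.dot_product_attention(query, key, value, scale=32**-0.5)
else:
    # Unfused reference path: explicit scores, softmax, and weighted sum.
    scores = ops.einsum("bqhd,bkhd->bhqk", query, key) * 32**-0.5
    probs = ops.softmax(scores, axis=-1)
    output = ops.einsum("bhqk,bkhd->bqhd", probs, value)

print(ops.shape(output))  # (2, 16, 4, 32)
```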