diff --git a/keras_hub/src/models/gemma/gemma_attention.py b/keras_hub/src/models/gemma/gemma_attention.py
index 4f8c414eb8..f66a4506ce 100644
--- a/keras_hub/src/models/gemma/gemma_attention.py
+++ b/keras_hub/src/models/gemma/gemma_attention.py
@@ -152,7 +152,7 @@ def _compute_attention(
                 attention_mask = ops.expand_dims(attention_mask, axis=1)
                 attention_mask = ops.cast(attention_mask, dtype="bool")
             # Only pass soft cap if needed as not all keras versions support.
-            if self.logit_soft_cap:
+            if self.logit_soft_cap is not None:
                 kwargs = {"attn_logits_soft_cap": self.logit_soft_cap}
             else:
                 kwargs = {}
diff --git a/keras_hub/src/models/mixtral/mixtral_attention.py b/keras_hub/src/models/mixtral/mixtral_attention.py
index 080be18047..0cae75a21c 100644
--- a/keras_hub/src/models/mixtral/mixtral_attention.py
+++ b/keras_hub/src/models/mixtral/mixtral_attention.py
@@ -27,19 +27,19 @@ def __init__(
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self._num_query_heads = num_query_heads
-        self._num_key_value_heads = num_key_value_heads
-        self._sliding_window = sliding_window
-        self._dropout = dropout
+        self.num_query_heads = num_query_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.sliding_window = sliding_window
+        self.dropout = dropout
 
-        self._num_key_value_groups = num_query_heads // num_key_value_heads
-        self._rope_max_wavelength = rope_max_wavelength
+        self.num_key_value_groups = num_query_heads // num_key_value_heads
+        self.rope_max_wavelength = rope_max_wavelength
 
         self._kernel_initializer = keras.initializers.get(
             clone_initializer(kernel_initializer)
         )
 
-        self._rope_scaling_factor = rope_scaling_factor
+        self.rope_scaling_factor = rope_scaling_factor
 
     def build(self, inputs_shape):
         # Einsum variables:
@@ -51,12 +51,12 @@ def build(self, inputs_shape):
         # v = num key/value heads
         # h = head dim
         self._hidden_dim = inputs_shape[-1]
-        self._head_dim = self._hidden_dim // self._num_query_heads
+        self._head_dim = self._hidden_dim // self.num_query_heads
         self._inv_norm_factor = 1.0 / math.sqrt(self._head_dim)
 
         self.query_dense = keras.layers.EinsumDense(
             equation="bqm,muh->bquh",
-            output_shape=(None, self._num_query_heads, self._head_dim),
+            output_shape=(None, self.num_query_heads, self._head_dim),
             kernel_initializer=self._kernel_initializer,
             dtype=self.dtype_policy,
             name="query",
@@ -67,7 +67,7 @@ def build(self, inputs_shape):
             equation="bkm,mvh->bkvh",
             output_shape=(
                 None,
-                self._num_key_value_heads,
+                self.num_key_value_heads,
                 self._head_dim,
             ),
             kernel_initializer=self._kernel_initializer,
@@ -80,7 +80,7 @@ def build(self, inputs_shape):
             equation="bkm,mvh->bkvh",
             output_shape=(
                 None,
-                self._num_key_value_heads,
+                self.num_key_value_heads,
                 self._head_dim,
             ),
             kernel_initializer=self._kernel_initializer,
@@ -89,31 +89,31 @@ def build(self, inputs_shape):
         )
         self.value_dense.build(inputs_shape)
 
-        self._softmax = keras.layers.Softmax(
+        self.softmax = keras.layers.Softmax(
             axis=-1,
             dtype="float32",
             name="attention_softmax",
         )
 
-        self._dropout_layer = keras.layers.Dropout(
-            rate=self._dropout,
+        self.dropout_layer = keras.layers.Dropout(
+            rate=self.dropout,
             dtype=self.dtype_policy,
         )
 
-        self._output_dense = keras.layers.EinsumDense(
+        self.output_dense = keras.layers.EinsumDense(
             equation="bquh,uhm->bqm",
             output_shape=(None, self._hidden_dim),
             kernel_initializer=self._kernel_initializer,
             dtype=self.dtype_policy,
             name="attention_output",
         )
-        self._output_dense.build(
-            (None, None, self._num_query_heads, self._head_dim)
-        )
+        self.output_dense.build(
+            (None, None, self.num_query_heads, self._head_dim)
+        )
 
         self.rotary_embedding_layer = RotaryEmbedding(
-            max_wavelength=self._rope_max_wavelength,
-            scaling_factor=self._rope_scaling_factor,
+            max_wavelength=self.rope_max_wavelength,
+            scaling_factor=self.rope_scaling_factor,
             dtype=self.dtype_policy,
         )
 
@@ -168,18 +168,18 @@ def _compute_key_value(x):
 
         # [batch_shape, seq_len, num_key_value_heads, head_dim]
         # -> [batch_shape, seq_len, num_heads, head_dim]
-        key = ops.repeat(key, repeats=self._num_key_value_groups, axis=2)
-        value = ops.repeat(value, repeats=self._num_key_value_groups, axis=2)
+        key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
+        value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)
 
         attention_output = self._compute_attention(
             query, key, value, attention_mask
         )
 
-        attention_output = self._dropout_layer(
+        attention_output = self.dropout_layer(
             attention_output, training=training
         )
 
-        attention_output = self._output_dense(attention_output)
+        attention_output = self.output_dense(attention_output)
 
         if cache is not None:
             return attention_output, cache
@@ -187,10 +187,8 @@ def _compute_key_value(x):
 
     def _masked_softmax(self, attention_scores, attention_mask=None):
         if attention_mask is not None:
-            return self._softmax(
-                attention_scores, attention_mask[:, None, :, :]
-            )
-        return self._softmax(attention_scores)
+            return self.softmax(attention_scores, attention_mask[:, None, :, :])
+        return self.softmax(attention_scores)
 
     def _use_fused_attention_op(self):
         if not fused_attention_op_available():
@@ -198,9 +196,6 @@ def _compute_key_value(x):
         if self.dropout > 0.0:
             return False
         if running_on_gpu():
-            # GPU never supports softcap in the fused op.
-            if self.logit_soft_cap is not None:
-                return False
             return gpu_supports_fused_attention_op()
         elif running_on_tpu():
             # TPU supports softcap with on keras >= 3.10.
@@ -215,18 +210,12 @@ def _compute_attention(self, query, key, value, attention_mask=None):
                 attention_mask = ops.expand_dims(attention_mask, axis=1)
                 attention_mask = ops.cast(attention_mask, dtype="bool")
 
-            if self.logit_soft_cap:
-                kwargs = {"attn_logits_soft_cap": self.logit_soft_cap}
-            else:
-                kwargs = {}
-
             attention_output = ops.dot_product_attention(
                 query,
                 key,
                 value,
                 mask=attention_mask,
                 scale=self._inv_norm_factor,
-                **kwargs,
             )
             return attention_output
 
@@ -249,15 +238,15 @@ def get_config(self):
         config = super().get_config()
         config.update(
             {
-                "num_query_heads": self._num_query_heads,
-                "num_key_value_heads": self._num_key_value_heads,
-                "rope_max_wavelength": self._rope_max_wavelength,
-                "rope_scaling_factor": self._rope_scaling_factor,
+                "num_query_heads": self.num_query_heads,
+                "num_key_value_heads": self.num_key_value_heads,
+                "rope_max_wavelength": self.rope_max_wavelength,
+                "rope_scaling_factor": self.rope_scaling_factor,
                 "kernel_initializer": keras.initializers.serialize(
                     self._kernel_initializer
                 ),
-                "sliding_window": self._sliding_window,
-                "dropout": self._dropout,
+                "sliding_window": self.sliding_window,
+                "dropout": self.dropout,
             }
         )
         return config
diff --git a/keras_hub/src/models/qwen_moe/qwen_moe_attention.py b/keras_hub/src/models/qwen_moe/qwen_moe_attention.py
index 1f270d032d..f55c00cfad 100644
--- a/keras_hub/src/models/qwen_moe/qwen_moe_attention.py
+++ b/keras_hub/src/models/qwen_moe/qwen_moe_attention.py
@@ -67,6 +67,7 @@ def __init__(
         self.rope_scaling_factor = rope_scaling_factor
         self.use_sliding_window_attention = use_sliding_window_attention
         self.sliding_window_size = sliding_window_size
+        self.logit_soft_cap = None
 
     def build(self, inputs_shape):
         # Einsum variables:
diff --git a/keras_hub/src/utils/keras_utils.py b/keras_hub/src/utils/keras_utils.py
index 21607ffccb..6b5a7ad55c 100644
--- a/keras_hub/src/utils/keras_utils.py
+++ b/keras_hub/src/utils/keras_utils.py
@@ -71,6 +71,23 @@ def fused_attention_op_available():
             )
             return False
         return True
+    elif (
+        hasattr(keras.config, "is_flash_attention_enabled")
+        and keras.config.backend() == "torch"
+    ):
+        try:
+            from torch.backends.cuda import SDPAParams as SDPAParams
+            from torch.backends.cuda import (
+                can_use_flash_attention as can_use_flash_attention,
+            )
+        except ImportError:
+            logging.warning(
+                "Flash attention is not supported in your current PyTorch "
+                "version. Please update it by following the official guide: "
+                "https://pytorch.org/get-started/locally/"
+            )
+            return False
+        return True
     else:
         return False
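Reviewer note (not part of the patch): below is a minimal standalone sketch of the torch-backend probe this change adds to fused_attention_op_available(). The helper name probe_torch_flash_attention and the __main__ harness are illustrative only, and the sketch assumes Keras 3 with the PyTorch backend installed.

# Sketch mirroring the new torch branch in keras_hub/src/utils/keras_utils.py.
# Names here are illustrative; only the imports checked by the patch are real.
import logging

import keras


def probe_torch_flash_attention():
    """Return True if the fused flash-attention path could be taken on torch."""
    if (
        hasattr(keras.config, "is_flash_attention_enabled")
        and keras.config.backend() == "torch"
    ):
        try:
            # Same imports the patch checks for; present in recent
            # CUDA-enabled PyTorch builds.
            from torch.backends.cuda import SDPAParams  # noqa: F401
            from torch.backends.cuda import can_use_flash_attention  # noqa: F401
        except ImportError:
            logging.warning(
                "Flash attention is not supported by this PyTorch build."
            )
            return False
        return True
    return False


if __name__ == "__main__":
    print("torch fused attention available:", probe_torch_flash_attention())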