Add StableLM-3B 4E1T to Keras Hub #2151

Open · wants to merge 20 commits into base: master
12 changes: 12 additions & 0 deletions keras_hub/api/models/__init__.py
@@ -579,6 +579,18 @@
from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_text_to_image_preprocessor import (
StableDiffusion3TextToImagePreprocessor as StableDiffusion3TextToImagePreprocessor,
)
from keras_hub.src.models.stablelm.stablelm_backbone import (
StableLMBackbone as StableLMBackbone,
)
from keras_hub.src.models.stablelm.stablelm_causal_lm import (
StableLMCausalLM as StableLMCausalLM,
)
from keras_hub.src.models.stablelm.stablelm_causal_lm_preprocessor import (
StableLMCausalLMPreprocessor as StableLMCausalLMPreprocessor,
)
from keras_hub.src.models.stablelm.stablelm_tokenizer import (
StableLMTokenizer as StableLMTokenizer,
)
from keras_hub.src.models.t5.t5_backbone import T5Backbone as T5Backbone
from keras_hub.src.models.t5.t5_preprocessor import (
T5Preprocessor as T5Preprocessor,
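These additions register the StableLM classes with the public `keras_hub.models` API. Below is a minimal sketch of how the new exports would be used once a preset is published; the preset name is hypothetical and shown for illustration only.

```python
import keras_hub

# NOTE: "stablelm_3b_4e1t_en" is a hypothetical preset name used for
# illustration; the real preset IDs depend on what this PR registers.
causal_lm = keras_hub.models.StableLMCausalLM.from_preset("stablelm_3b_4e1t_en")
print(causal_lm.generate("StableLM is", max_length=30))

# The backbone and preprocessor are also exported individually.
backbone = keras_hub.models.StableLMBackbone.from_preset("stablelm_3b_4e1t_en")
```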
3 changes: 3 additions & 0 deletions keras_hub/api/tokenizers/__init__.py
@@ -86,6 +86,9 @@
from keras_hub.src.models.siglip.siglip_tokenizer import (
SigLIPTokenizer as SigLIPTokenizer,
)
from keras_hub.src.models.stablelm.stablelm_tokenizer import (
StableLMTokenizer as StableLMTokenizer,
)
from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer as T5Tokenizer
from keras_hub.src.models.whisper.whisper_tokenizer import (
WhisperTokenizer as WhisperTokenizer,
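The tokenizer is likewise exposed under `keras_hub.tokenizers`. A short usage sketch, again with a hypothetical preset name:

```python
import keras_hub

# Hypothetical preset name, for illustration only.
tokenizer = keras_hub.tokenizers.StableLMTokenizer.from_preset(
    "stablelm_3b_4e1t_en"
)
token_ids = tokenizer("StableLM-3B 4E1T in Keras Hub")
print(tokenizer.detokenize(token_ids))
```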
247 changes: 247 additions & 0 deletions keras_hub/src/models/stablelm/stablelm_attention.py
@@ -0,0 +1,247 @@
import math

import keras
from keras import ops

from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_hub.src.utils.keras_utils import clone_initializer
from keras_hub.src.utils.keras_utils import fused_attention_op_available


class StableLMAttention(keras.layers.Layer):
"""StableLMAttention layer.

    This layer implements the attention mechanism for StableLM-3B 4E1T:
    multi-head self-attention with rotary position embeddings applied only
    to a leading fraction of each head's dimensions, as specified by
    `rotary_percentage`. It is adapted from the `LlamaAttention` layer, with
    modifications to match StableLM's official configuration.

Args:
num_query_heads: int. Number of attention heads for queries.
num_key_value_heads: int. Number of attention heads for keys and
values.
hidden_dim: int. Hidden dimension of the input.
rope_max_wavelength: float. Maximum wavelength for rotary embeddings
(default: 10000).
rope_scaling_factor: float. Scaling factor for rotary embeddings
(default: 1.0).
        rotary_percentage: float. Fraction of each head's dimensions to
            which rotary embeddings are applied (default: 0.25).
kernel_initializer: str or initializer. Initializer for dense layer
kernels (default: "glorot_uniform").
dropout: float. Dropout rate for attention scores (default: 0.0).
"""

def __init__(
self,
num_query_heads,
num_key_value_heads,
hidden_dim,
rope_max_wavelength=10000,
rope_scaling_factor=1.0,
rotary_percentage=0.25,
kernel_initializer="glorot_uniform",
dropout=0.0,
**kwargs,
):
super().__init__(**kwargs)
self.num_query_heads = num_query_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_dim = hidden_dim
self.rotary_percentage = rotary_percentage
self.dropout = dropout
self.rope_max_wavelength = rope_max_wavelength
self.rope_scaling_factor = rope_scaling_factor
self.num_key_value_groups = num_query_heads // num_key_value_heads
self.kernel_initializer = keras.initializers.get(
clone_initializer(kernel_initializer)
)

def build(self, inputs_shape):
head_dim = self.hidden_dim // self.num_query_heads
self.rotary_dim = int(head_dim * self.rotary_percentage)
self._inv_norm_factor = 1.0 / math.sqrt(head_dim)

# Query projection (no bias)
self.query_dense = keras.layers.EinsumDense(
equation="bqm,muh->bquh",
output_shape=(None, self.num_query_heads, head_dim),
kernel_initializer=self.kernel_initializer,
dtype=self.dtype_policy,
name="query",
)
self.query_dense.build(inputs_shape)

# Key projection (no bias)
self.key_dense = keras.layers.EinsumDense(
equation="bkm,mvh->bkvh",
output_shape=(None, self.num_key_value_heads, head_dim),
kernel_initializer=self.kernel_initializer,
dtype=self.dtype_policy,
name="key",
)
self.key_dense.build(inputs_shape)

# Value projection (no bias)
self.value_dense = keras.layers.EinsumDense(
equation="bkm,mvh->bkvh",
output_shape=(None, self.num_key_value_heads, head_dim),
kernel_initializer=self.kernel_initializer,
dtype=self.dtype_policy,
name="value",
)
self.value_dense.build(inputs_shape)

# Softmax layer for attention scores
self._softmax = keras.layers.Softmax(
axis=-1,
dtype="float32",
name="attention_softmax",
)

# Dropout layer for attention scores
self.dropout_layer = keras.layers.Dropout(
rate=self.dropout,
dtype=self.dtype_policy,
)

# Output projection (without bias)
self.output_dense = keras.layers.EinsumDense(
equation="bquh,uhm->bqm",
output_shape=(None, self.hidden_dim),
kernel_initializer=self.kernel_initializer,
dtype=self.dtype_policy,
name="attention_output",
)
self.output_dense.build((None, None, self.num_query_heads, head_dim))

# Rotary embedding layer
self.rotary_embedding_layer = RotaryEmbedding(
max_wavelength=self.rope_max_wavelength,
scaling_factor=self.rope_scaling_factor,
dtype=self.dtype_policy,
)

self._dot_product_equation = "bquh,bkuh->buqk"
self._combine_equation = "buqk,bkuh->bquh"
super().build(inputs_shape)

def call(
self,
hidden_states,
attention_mask=None,
cache=None,
cache_update_index=None,
training=None,
):
start_index = (
cache_update_index if cache_update_index is not None else 0
)
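        # Project the queries, then apply rotary embeddings only to the first
        # `rotary_dim` channels of each head (partial RoPE); the remaining
        # channels pass through unchanged.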
query = self.query_dense(hidden_states)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = self.rotary_embedding_layer(
query_rot, start_index=start_index
)
query = ops.concatenate([query_rot, query_pass], axis=-1)

def _compute_key_value(x):
key = self.key_dense(x)
value = self.value_dense(x)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = self.rotary_embedding_layer(
key_rot, start_index=start_index
)
key = ops.concatenate([key_rot, key_pass], axis=-1)
return key, value

# Handle caching for key and value
if cache is not None:
key_cache = cache[:, 0, ...]
value_cache = cache[:, 1, ...]
if cache_update_index is None:
key = key_cache
value = value_cache
else:
key_update, value_update = _compute_key_value(hidden_states)
start = [0, cache_update_index, 0, 0]
key = ops.slice_update(key_cache, start, key_update)
value = ops.slice_update(value_cache, start, value_update)
cache = ops.stack((key, value), axis=1)
else:
if cache_update_index is not None:
raise ValueError(
"`cache_update_index` should not be set if `cache` is "
f"`None`. Received: cache={cache}, "
f"cache_update_index={cache_update_index}"
)
key, value = _compute_key_value(hidden_states)

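        # Repeat keys/values along the head axis so each group of query heads
        # shares its key/value head (grouped-query attention); this is a
        # no-op when the two head counts are equal.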
key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2)
value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2)

attention_output = self._compute_attention(
query, key, value, attention_mask
)
attention_output = self.dropout_layer(
attention_output, training=training
)
attention_output = self.output_dense(attention_output)

return (
attention_output,
cache if cache is not None else attention_output,
)

def _masked_softmax(self, attention_scores, attention_mask=None):
if attention_mask is not None:
return self._softmax(
attention_scores, attention_mask[:, None, :, :]
)
return self._softmax(attention_scores)

def _compute_attention(self, query, key, value, attention_mask=None):
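        # Use the backend's fused attention op when it is available and no
        # attention dropout is requested; otherwise fall back to the explicit
        # einsum implementation below.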
if fused_attention_op_available() and self.dropout == 0:
if attention_mask is not None:
attention_mask = ops.expand_dims(attention_mask, axis=1)
attention_mask = ops.cast(attention_mask, dtype="bool")
return ops.dot_product_attention(
query,
key,
value,
mask=attention_mask,
scale=self._inv_norm_factor,
)
attention_scores = ops.einsum(self._dot_product_equation, query, key)
attention_scores = ops.multiply(
attention_scores,
ops.cast(self._inv_norm_factor, self.compute_dtype),
)
attention_scores = self._masked_softmax(
attention_scores, attention_mask
)
attention_scores = ops.cast(attention_scores, self.compute_dtype)
attention_output = ops.einsum(
self._combine_equation, attention_scores, value
)
return attention_output

def get_config(self):
config = super().get_config()
config.update(
{
"num_query_heads": self.num_query_heads,
"num_key_value_heads": self.num_key_value_heads,
"hidden_dim": self.hidden_dim,
"rope_max_wavelength": self.rope_max_wavelength,
"rope_scaling_factor": self.rope_scaling_factor,
"rotary_percentage": self.rotary_percentage,
"kernel_initializer": keras.initializers.serialize(
self.kernel_initializer
),
"dropout": self.dropout,
}
)
return config
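For reviewers, here is a minimal sketch that exercises `StableLMAttention` in isolation. The sizes are illustrative toy values, not the official StableLM-3B configuration, and the import path assumes this PR's file layout.

```python
import numpy as np

from keras_hub.src.models.stablelm.stablelm_attention import StableLMAttention

# Toy sizes for a quick shape check (not the 3B config).
layer = StableLMAttention(
    num_query_heads=8,
    num_key_value_heads=8,
    hidden_dim=256,
    rotary_percentage=0.25,
)

batch, seq_len = 2, 16
hidden_states = np.random.uniform(size=(batch, seq_len, 256)).astype("float32")
# Causal mask of shape (batch, query_len, key_len).
attention_mask = np.tril(np.ones((seq_len, seq_len), dtype="int32"))[None, ...]
attention_mask = np.repeat(attention_mask, batch, axis=0)

outputs, _ = layer(hidden_states, attention_mask=attention_mask)
print(outputs.shape)  # (2, 16, 256)
```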