From 79b99c53ad3bd2e0ee31894df6bf2b329dfde494 Mon Sep 17 00:00:00 2001 From: Dhiraj BM Date: Tue, 18 Mar 2025 22:21:52 +0530 Subject: [PATCH 01/10] stablelm_attention --- .../src/models/stablelm/stablelm_attention.py | 231 ++++++++++++++++++ 1 file changed, 231 insertions(+) create mode 100644 keras_hub/src/models/stablelm/stablelm_attention.py diff --git a/keras_hub/src/models/stablelm/stablelm_attention.py b/keras_hub/src/models/stablelm/stablelm_attention.py new file mode 100644 index 0000000000..1c9bef344b --- /dev/null +++ b/keras_hub/src/models/stablelm/stablelm_attention.py @@ -0,0 +1,231 @@ +import math +import keras +from keras import ops + +from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding +from keras_hub.src.utils.keras_utils import clone_initializer +from keras_hub.src.utils.keras_utils import has_flash_attention_support + + +class StableLMAttention(keras.layers.Layer): + """StableLMAttention layer. + + This layer implements the attention mechanism for StableLM-3B4E1T, featuring + multi-head self-attention with partial rotary position embeddings applied + to a fraction of the head dimensions, as specified by `rotary_percentage`. + It is adapted from the LlamaAttention layer with modifications to align + with StableLM's official configuration(https://github.com/Stability-AI/StableLM/blob/main/configs/stablelm-3b-4e1t.yml). + + Args: + num_query_heads (int): Number of attention heads for queries. + num_key_value_heads (int): Number of attention heads for keys and values. + hidden_dim (int): Hidden dimension of the input (e.g., 2560 for StableLM-3B4E1T). + rope_max_wavelength (float): Maximum wavelength for rotary embeddings (default: 10000). + rope_scaling_factor (float): Scaling factor for rotary embeddings (default: 1.0). + rotary_percentage (float): Percentage of head dimensions to apply rotary embeddings (default: 0.25). + kernel_initializer (str or initializer): Initializer for dense layer kernels (default: "glorot_uniform"). + dropout (float): Dropout rate for attention scores (default: 0.0). + **kwargs: Additional keyword arguments passed to the parent class. 
+ """ + + def __init__( + self, + num_query_heads, + num_key_value_heads, + hidden_dim, + rope_max_wavelength=10000, + rope_scaling_factor=1.0, + rotary_percentage=0.25, + kernel_initializer="glorot_uniform", + dropout=0.0, + **kwargs, + ): + super().__init__(**kwargs) + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_dim = hidden_dim + self.rotary_percentage = rotary_percentage + self.dropout = dropout + self.rope_max_wavelength = rope_max_wavelength + self.rope_scaling_factor = rope_scaling_factor + self.num_key_value_groups = num_query_heads // num_key_value_heads + self.kernel_initializer = keras.initializers.get( + clone_initializer(kernel_initializer) + ) + + def build(self, inputs_shape): + head_dim = self.hidden_dim // self.num_query_heads + self.rotary_dim = int(head_dim * self.rotary_percentage) + self._inv_norm_factor = 1.0 / math.sqrt(head_dim) + + # Query projection (no bias ) + self._query_dense = keras.layers.EinsumDense( + equation="bqm,muh->bquh", + output_shape=(None, self.num_query_heads, head_dim), + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="query", + ) + self._query_dense.build(inputs_shape) + + # Key projection (no bias) + self._key_dense = keras.layers.EinsumDense( + equation="bkm,mvh->bkvh", + output_shape=(None, self.num_key_value_heads, head_dim), + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="key", + ) + self._key_dense.build(inputs_shape) + + # Value projection (no bias) + self._value_dense = keras.layers.EinsumDense( + equation="bkm,mvh->bkvh", + output_shape=(None, self.num_key_value_heads, head_dim), + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="value", + ) + self._value_dense.build(inputs_shape) + + # Softmax layer for attention scores + self._softmax = keras.layers.Softmax( + axis=-1, + dtype="float32", + name="attention_softmax", + ) + + # Dropout layer for attention scores + self._dropout_layer = keras.layers.Dropout( + rate=self.dropout, + dtype=self.dtype_policy, + ) + + # Output projection (without bias) + self._output_dense = keras.layers.EinsumDense( + equation="bquh,uhm->bqm", + output_shape=(None, self.hidden_dim), + kernel_initializer=self.kernel_initializer, + dtype=self.dtype_policy, + name="attention_output", + ) + self._output_dense.build((None, None, self.num_query_heads, head_dim)) + + # Rotary embedding layer + self.rotary_embedding_layer = RotaryEmbedding( + max_wavelength=self.rope_max_wavelength, + scaling_factor=self.rope_scaling_factor, + dtype=self.dtype_policy, + ) + + self._dot_product_equation = "bquh,bkuh->buqk" + self._combine_equation = "buqk,bkuh->bquh" + self.built = True + + def call( + self, + hidden_states, + attention_mask=None, + cache=None, + cache_update_index=None, + training=None, + ): + start_index = cache_update_index if cache_update_index is not None else 0 + + # Compute query and apply partial rotary embedding + query = self._query_dense(hidden_states) + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + query_rot = self.rotary_embedding_layer(query_rot, start_index=start_index) + query = ops.concatenate([query_rot, query_pass], axis=-1) + + def _compute_key_value(x): + key = self._key_dense(x) + value = self._value_dense(x) + # Apply partial rotary embedding to key + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + key_rot = self.rotary_embedding_layer(key_rot, start_index=start_index) + key 
= ops.concatenate([key_rot, key_pass], axis=-1) + return key, value + + # Handle caching for key and value + if cache is not None: + key_cache = cache[:, 0, ...] + value_cache = cache[:, 1, ...] + if cache_update_index is None: + key = key_cache + value = value_cache + else: + key_update, value_update = _compute_key_value(hidden_states) + start = [0, cache_update_index, 0, 0] + key = ops.slice_update(key_cache, start, key_update) + value = ops.slice_update(value_cache, start, value_update) + cache = ops.stack((key, value), axis=1) + else: + if cache_update_index is not None: + raise ValueError( + "`cache_update_index` should not be set if `cache` is " + f"`None`. Received: cache={cache}, " + f"cache_update_index={cache_update_index}" + ) + key, value = _compute_key_value(hidden_states) + + # Adjust key and value for grouped-query attention (if applicable) + key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2) + value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2) + + # Compute attention output + attention_output = self._compute_attention(query, key, value, attention_mask) + attention_output = self._dropout_layer(attention_output, training=training) + attention_output = self._output_dense(attention_output) + + return attention_output, cache if cache is not None else attention_output + + def _masked_softmax(self, attention_scores, attention_mask=None): + if attention_mask is not None: + return self._softmax(attention_scores, attention_mask[:, None, :, :]) + return self._softmax(attention_scores) + + def _compute_attention(self, query, key, value, attention_mask=None): + if has_flash_attention_support() and self.dropout == 0: + if attention_mask is not None: + attention_mask = ops.expand_dims(attention_mask, axis=1) + attention_mask = ops.cast(attention_mask, dtype="bool") + return ops.dot_product_attention( + query, + key, + value, + mask=attention_mask, + scale=self._inv_norm_factor, + ) + + attention_scores = ops.einsum(self._dot_product_equation, query, key) + attention_scores = ops.multiply( + attention_scores, + ops.cast(self._inv_norm_factor, self.compute_dtype), + ) + attention_scores = self._masked_softmax(attention_scores, attention_mask) + attention_scores = ops.cast(attention_scores, self.compute_dtype) + attention_output = ops.einsum(self._combine_equation, attention_scores, value) + return attention_output + + def get_config(self): + config = super().get_config() + config.update( + { + "num_query_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "hidden_dim": self.hidden_dim, + "rope_max_wavelength": self.rope_max_wavelength, + "rope_scaling_factor": self.rope_scaling_factor, + "rotary_percentage": self.rotary_percentage, + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "dropout": self.dropout, + } + ) + return config + + From a903bdd54c7cc4569667e1dbc3ffa9e48502484c Mon Sep 17 00:00:00 2001 From: Dhiraj BM Date: Tue, 18 Mar 2025 22:40:54 +0530 Subject: [PATCH 02/10] stablelm_decoder --- .../src/models/stablelm/stablelm_decoder.py | 219 ++++++++++++++++++ 1 file changed, 219 insertions(+) create mode 100644 keras_hub/src/models/stablelm/stablelm_decoder.py diff --git a/keras_hub/src/models/stablelm/stablelm_decoder.py b/keras_hub/src/models/stablelm/stablelm_decoder.py new file mode 100644 index 0000000000..3f897be14b --- /dev/null +++ b/keras_hub/src/models/stablelm/stablelm_decoder.py @@ -0,0 +1,219 @@ +import keras +from keras import ops + +from 
keras_hub.src.layers.modeling.transformer_layer_utils import ( + compute_causal_mask, +) +from keras_hub.src.layers.modeling.transformer_layer_utils import ( + merge_padding_and_attention_mask, +) +from keras_hub.src.utils.keras_utils import clone_initializer +from keras_hub.src.models.stablelm.stablelm_attention import StableLMAttention + +class StableLMTransformerDecoder(keras.layers.Layer): + """StableLM-3B4E1T Transformer decoder layer. + + This layer implements the decoder for StableLM-3B4E1T, a decoder-only transformer + with multi-head self-attention using partial rotary position embeddings (RoPE) + and LayerNorm with learned bias terms. + + Args: + intermediate_dim (int): Hidden size of the feedforward network. + num_query_heads (int): Number of query attention heads (32 for StableLM-3B4E1T). + num_key_value_heads (int): Number of key/value attention heads (32 for StableLM-3B4E1T). + rope_max_wavelength (float, optional): Maximum wavelength for RoPE. Defaults to 10000. + rope_scaling_factor (float, optional): Scaling factor for RoPE. Defaults to 1.0. + rotary_percentage (float, optional): Percentage of head dimensions for RoPE (0.25 for StableLM). + activation (str or callable, optional): Activation for the feedforward network. Defaults to "silu". + layer_norm_epsilon (float, optional): Epsilon for LayerNorm. Defaults to 1e-5. + kernel_initializer (str or initializer, optional): Initializer for dense layers. Defaults to "glorot_uniform". + dropout (float, optional): Dropout rate. Defaults to 0.0. + **kwargs: Additional keyword arguments for the parent class. + """ + + def __init__( + self, + intermediate_dim, + num_query_heads, + num_key_value_heads, + rope_max_wavelength=10000, + rope_scaling_factor=1.0, + rotary_percentage=0.25, + activation="silu", + layer_norm_epsilon=1e-5, + kernel_initializer="glorot_uniform", + dropout=0.0, + **kwargs, + ): + super().__init__(**kwargs) + self.intermediate_dim = intermediate_dim + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + self.rope_max_wavelength = rope_max_wavelength + self.rope_scaling_factor = rope_scaling_factor + self.rotary_percentage = rotary_percentage + self.activation = keras.activations.get(activation) + self.layer_norm_epsilon = layer_norm_epsilon + self.kernel_initializer = keras.initializers.get(kernel_initializer) + self.dropout = dropout + self.supports_masking = True + + def build(self, decoder_sequence_shape): + self.hidden_dim = decoder_sequence_shape[-1] + + # Self-attention layer with partial RoPE + self._self_attention_layer = StableLMAttention( + num_query_heads=self.num_query_heads, + num_key_value_heads=self.num_key_value_heads, + hidden_dim=self.hidden_dim, + rope_max_wavelength=self.rope_max_wavelength, + rope_scaling_factor=self.rope_scaling_factor, + rotary_percentage=self.rotary_percentage, + kernel_initializer=clone_initializer(self.kernel_initializer), + dropout=self.dropout, + dtype=self.dtype_policy, + name="self_attention", + ) + self._self_attention_layer.build(decoder_sequence_shape) + + # LayerNorm for self-attention (with learned bias) + self._self_attention_layernorm = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="self_attention_layernorm", + ) + self._self_attention_layernorm.build(decoder_sequence_shape) + + # Dropout for self-attention + self._self_attention_dropout = keras.layers.Dropout( + rate=self.dropout, + dtype=self.dtype_policy, + name="self_attention_dropout", + ) + + # Feedforward layers 
(gated MLP) + self._feedforward_gate_dense = keras.layers.Dense( + self.intermediate_dim, + kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, + dtype=self.dtype_policy, + name="feedforward_gate_dense", + ) + self._feedforward_gate_dense.build(decoder_sequence_shape) + + self._feedforward_intermediate_dense = keras.layers.Dense( + self.intermediate_dim, + kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, + dtype=self.dtype_policy, + name="feedforward_intermediate_dense", + ) + self._feedforward_intermediate_dense.build(decoder_sequence_shape) + + self._feedforward_output_dense = keras.layers.Dense( + self.hidden_dim, + kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, + dtype=self.dtype_policy, + name="feedforward_output_dense", + ) + self._feedforward_output_dense.build( + self._feedforward_gate_dense.compute_output_shape(decoder_sequence_shape) + ) + + # LayerNorm for feedforward (with learned bias) + self._feedforward_layernorm = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="feedforward_layernorm", + ) + self._feedforward_layernorm.build(decoder_sequence_shape) + + self.built = True + + def call( + self, + decoder_sequence, + decoder_padding_mask=None, + decoder_attention_mask=None, + self_attention_cache=None, + self_attention_cache_update_index=None, + training=None, + ): + # Compute the attention mask + self_attention_mask = self._compute_self_attention_mask( + decoder_sequence=decoder_sequence, + decoder_padding_mask=decoder_padding_mask, + decoder_attention_mask=decoder_attention_mask, + self_attention_cache=self_attention_cache, + self_attention_cache_update_index=self_attention_cache_update_index, + ) + + residual = decoder_sequence + + # Self-attention block + x = self._self_attention_layernorm(decoder_sequence) + x, self_attention_cache = self._self_attention_layer( + hidden_states=x, + attention_mask=self_attention_mask, + cache=self_attention_cache, + cache_update_index=self_attention_cache_update_index, + ) + x = self._self_attention_dropout(x, training=training) + x = x + residual + + residual = x + + # Feedforward block + x = self._feedforward_layernorm(x) + gate_output = self._feedforward_gate_dense(x) + gate_output = self.activation(gate_output) + intermediate_output = self._feedforward_intermediate_dense(x) + x = self._feedforward_output_dense(ops.multiply(intermediate_output, gate_output)) + decoder_output = x + residual + + if self_attention_cache is not None: + return decoder_output, self_attention_cache + return decoder_output + + def _compute_self_attention_mask( + self, + decoder_sequence, + decoder_padding_mask, + decoder_attention_mask, + self_attention_cache, + self_attention_cache_update_index, + ): + decoder_mask = merge_padding_and_attention_mask( + decoder_sequence, decoder_padding_mask, decoder_attention_mask + ) + batch_size = ops.shape(decoder_sequence)[0] + input_length = output_length = ops.shape(decoder_sequence)[1] + if self_attention_cache is not None: + input_length = ops.shape(self_attention_cache)[2] + cache_update_index = 0 if self_attention_cache_update_index is None else self_attention_cache_update_index + causal_mask = compute_causal_mask( + batch_size, input_length, output_length, cache_update_index + ) + return ops.minimum(decoder_mask, causal_mask) if decoder_mask is not None else causal_mask + + def compute_output_shape(self, decoder_sequence_shape): + return decoder_sequence_shape + + def 
get_config(self): + config = super().get_config() + config.update( + { + "intermediate_dim": self.intermediate_dim, + "num_query_heads": self.num_query_heads, + "num_key_value_heads": self.num_key_value_heads, + "rope_max_wavelength": self.rope_max_wavelength, + "rope_scaling_factor": self.rope_scaling_factor, + "rotary_percentage": self.rotary_percentage, + "activation": keras.activations.serialize(self.activation), + "layer_norm_epsilon": self.layer_norm_epsilon, + "kernel_initializer": keras.initializers.serialize(self.kernel_initializer), + "dropout": self.dropout, + } + ) + return config \ No newline at end of file From 8540522a64289656248c23d4637cfaf83f29b355 Mon Sep 17 00:00:00 2001 From: Dhiraj BM Date: Tue, 18 Mar 2025 22:44:47 +0530 Subject: [PATCH 03/10] stablelm_backbone and stablelm_backbone_test --- .../src/models/stablelm/stablelm_backbone.py | 291 ++++++++++++++++++ .../models/stablelm/stablelm_backbone_test.py | 42 +++ 2 files changed, 333 insertions(+) create mode 100644 keras_hub/src/models/stablelm/stablelm_backbone.py create mode 100644 keras_hub/src/models/stablelm/stablelm_backbone_test.py diff --git a/keras_hub/src/models/stablelm/stablelm_backbone.py b/keras_hub/src/models/stablelm/stablelm_backbone.py new file mode 100644 index 0000000000..70fab17a89 --- /dev/null +++ b/keras_hub/src/models/stablelm/stablelm_backbone.py @@ -0,0 +1,291 @@ + +import keras +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.modeling.reversible_embedding import ReversibleEmbedding +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.stablelm.stablelm_decoder import StableLMTransformerDecoder + +def _stablelm_kernel_initializer(stddev=0.02): + """Initializer for StableLM kernel weights.""" + return keras.initializers.RandomNormal(stddev=stddev) + +@keras_hub_export("keras_hub.models.StableLMBackbone") +class StableLMBackbone(Backbone): + """ + The StableLM Transformer core architecture with hyperparameters. + + This network implements a Transformer-based decoder network for StableLM-3B4E1T, + as described in the official documentation. It is a decoder-only transformer similar + to LLaMA with modifications including partial rotary position embeddings and + LayerNorm with learned bias terms. It includes the embedding lookups and transformer + layers. + + The default constructor provides a fully customizable, randomly initialized + StableLM model with any number of layers, heads, and embedding dimensions. + + Args: + vocabulary_size (int): The size of the token vocabulary. + num_layers (int): The number of transformer layers (32 for StableLM-3B4E1T). + num_query_heads (int): The number of query attention heads (32 for StableLM-3B4E1T). + hidden_dim (int): The hidden size (2560 for StableLM-3B4E1T). + intermediate_dim (int): The output dimension of the first Dense layer in the + feedforward network. + num_key_value_heads (int): The number of key/value attention heads (32 for + StableLM-3B4E1T). + rope_max_wavelength (int, optional): The maximum wavelength for RoPE. Defaults + to 10000. + rope_scaling_factor (float, optional): The scaling factor for RoPE. Defaults to 1.0. + layer_norm_epsilon (float, optional): Epsilon for LayerNorm. Defaults to 1e-5. + dropout (float, optional): Dropout rate. Defaults to 0.0. + tie_word_embeddings (bool, optional): Whether to tie input and output embeddings. + Defaults to False. + dtype: The dtype to use for computations and weights. 
+ + Example: + ```python + # Randomly initialized StableLM decoder with custom config + model = StableLMBackbone( + vocabulary_size=50257, + num_layers=32, + num_query_heads=32, + hidden_dim=2560, + intermediate_dim=6912, + num_key_value_heads=32, + rotary_percentage=0.25, + rope_max_wavelength=10000, + rope_scaling_factor=1.0, + layer_norm_epsilon=1e-5, + dropout=0.0, + dtype="float32", + ) + + # Example input data + input_data = { + "token_ids": np.ones(shape=(1, 12), dtype="int32"), + "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), + } + + # Forward pass + output = model(input_data) + print(output.shape) # Expected: (1, 12, 2560) + """ + def __init__( + self, + vocabulary_size, + num_layers, + num_query_heads, + hidden_dim, + intermediate_dim, + num_key_value_heads, + rope_max_wavelength=10000, + rope_scaling_factor=1.0, + rotary_percentage=0.25, + layer_norm_epsilon=1e-5, + dropout=0.0, + dtype=None, + **kwargs, + ): + # === Layers === + self.token_embedding = ReversibleEmbedding( + input_dim=vocabulary_size, + output_dim=hidden_dim, + embeddings_initializer=_stablelm_kernel_initializer(stddev=0.01), + dtype=dtype, + name="token_embedding", + ) + self.transformer_layers = [] + for i in range(num_layers): + layer = StableLMTransformerDecoder( + intermediate_dim=intermediate_dim, + num_query_heads=num_query_heads, + num_key_value_heads=num_key_value_heads, + rope_max_wavelength=rope_max_wavelength, + rope_scaling_factor=rope_scaling_factor, + rotary_percentage=rotary_percentage, + activation="silu", # Common activation for modern transformers + layer_norm_epsilon=layer_norm_epsilon, + kernel_initializer=_stablelm_kernel_initializer(stddev=0.02), + dropout=dropout, + dtype=dtype, + name=f"transformer_layer_{i}", + ) + self.transformer_layers.append(layer) + self.layer_norm = keras.layers.LayerNormalization( + epsilon=layer_norm_epsilon, + dtype=dtype, + name="sequence_output_layernorm", + ) + + # === Functional Model === + token_id_input = keras.Input(shape=(None,), dtype="int32", name="token_ids") + padding_mask_input = keras.Input(shape=(None,), dtype="int32", name="padding_mask") + x = self.token_embedding(token_id_input) + for transformer_layer in self.transformer_layers: + x = transformer_layer(x, decoder_padding_mask=padding_mask_input) + sequence_output = self.layer_norm(x) + super().__init__( + inputs={"token_ids": token_id_input, "padding_mask": padding_mask_input}, + outputs=sequence_output, + dtype=dtype, + **kwargs, + ) + + # === Config === + self.vocabulary_size = vocabulary_size + self.num_layers = num_layers + self.num_query_heads = num_query_heads + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.num_key_value_heads = num_key_value_heads + self.rope_max_wavelength = rope_max_wavelength + self.rope_scaling_factor = rope_scaling_factor + self.rotary_percentage = rotary_percentage + self.layer_norm_epsilon = layer_norm_epsilon + self.dropout = dropout + + def get_config(self): + """Returns the configuration of the model for serialization.""" + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "num_layers": self.num_layers, + "num_query_heads": self.num_query_heads, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "num_key_value_heads": self.num_key_value_heads, + "rope_max_wavelength": self.rope_max_wavelength, + "rope_scaling_factor": self.rope_scaling_factor, + "rotary_percentage": self.rotary_percentage, + "layer_norm_epsilon": 
self.layer_norm_epsilon, + "dropout": self.dropout, + } + ) + return config + + @staticmethod + def get_layout_map( + device_mesh, + model_parallel_dim_name="model", + data_parallel_dim_name="batch", + ): + """Get a `keras.distribution.LayoutMap` for model parallel distribution. + + The returned `LayoutMap` contains the sharding spec for the Llama + backbone weights, so that you can use it to distribute weights across + the accelerators. + + Example: + ``` + # Feel free to change the mesh shape to balance data and model + # parallelism + mesh = keras.distribution.DeviceMesh( + shape=(1, 8), + axis_names=('batch', 'model'), + devices=keras.distribution.list_devices(), + ) + layout_map = StableLMBackbone.get_layout_map( + mesh, + model_parallel_dim_name="model", + ) + + distribution = keras.distribution.ModelParallel( + layout_map=layout_map, + batch_dim_name='batch', + ) + + with distribution.scope(): + stablelm_model = keras_hub.models.StableLMCausalLM.from_preset() + ``` + + To see how the layout map was applied, load the model then run + (for one decoder block): + ``` + embedding_layer = stablelm_model.backbone.get_layer("token_embedding") + decoder_block_1 = stablelm_model.backbone.get_layer('transformer_layer_0') + for variable in embedding_layer.weights + decoder_block_1.weights: + print( + f'{variable.path:<58} {str(variable.shape):<16} ' + f'{str(variable.value.sharding.spec)}' + ) + ``` + + Args: + device_mesh: The `keras.distribution.DeviceMesh` instance for + distribution. + model_parallel_dim_name: The axis name of the device mesh, where + the weights should be partition on. + data_parallel_dim_name: The axis name of the device mesh, where + the data should be partition on. + Return: + `keras.distribution.LayoutMap` that contains the sharding spec + for all the model weights. + """ + # The weight path and shape of the Llama backbone is like below + # token_embedding/embeddings (128256, 2048) + # repeat block for decoder + # transformer_layer_0/self_attention/query/kernel (2048, 32, 64) + # transformer_layer_0/self_attention/key/kernel (2048, 8, 64) + # transformer_layer_0/self_attention/value/kernel (2048, 8, 64) + # transformer_layer_0/self_attention/attention_output/kernel + # (32, 64, 2048) + # transformer_layer_0/self_attention_layernorm/scale (2048,) + # transformer_layer_0/feedforward_intermediate_dense/kernel + # (2048, 8192) + # transformer_layer_0/feedforward_gate_dense/kernel (2048, 8192) + # transformer_layer_0/feedforward_output_dense/kerne (8192, 2048) + # transformer_layer_0/feedforward_layernorm/scale (2048,) + + if not isinstance(device_mesh, keras.distribution.DeviceMesh): + raise ValueError( + "Invalid device_mesh type. Expected " + f"`keras.distribution.Device`, got {type(device_mesh)}" + ) + if model_parallel_dim_name not in device_mesh.axis_names: + raise ValueError( + f"{model_parallel_dim_name} is not found in the " + f"device_mesh.axis_names. {device_mesh.axis_name=}" + ) + if data_parallel_dim_name not in device_mesh.axis_names: + raise ValueError( + f"{data_parallel_dim_name} is not found in the " + f"device_mesh.axis_names. {device_mesh.axis_name=}" + ) + # Note that it is possible to further config the mesh to be 3D, eg + # (data, seq, model). We leave it as 2D for now for simplicity. + data_dim = data_parallel_dim_name + model_dim = model_parallel_dim_name + # The sharding config is based on the Gemma team training config. 
+ # See https://arxiv.org/abs/2403.08295 + layout_map = keras.distribution.LayoutMap(device_mesh) + layout_map["token_embedding/embeddings"] = (model_dim, data_dim) + layout_map[ + "transformer_layer.*self_attention.*(query|key|value).kernel" + ] = ( + model_dim, + data_dim, + None, + ) + layout_map["transformer_layer.*attention_output.kernel"] = ( + model_dim, + None, + data_dim, + ) + layout_map[ + "transformer_layer.*feedforward_intermediate_dense.kernel" + ] = ( + data_dim, + model_dim, + ) + layout_map["transformer_layer.*feedforward_gate_dense.kernel"] = ( + data_dim, + model_dim, + ) + layout_map["transformer_layer.*feedforward_output_dense.kernel"] = ( + model_dim, + data_dim, + ) + + return layout_map diff --git a/keras_hub/src/models/stablelm/stablelm_backbone_test.py b/keras_hub/src/models/stablelm/stablelm_backbone_test.py new file mode 100644 index 0000000000..1cb6851dac --- /dev/null +++ b/keras_hub/src/models/stablelm/stablelm_backbone_test.py @@ -0,0 +1,42 @@ +import keras +import pytest +from keras import ops + +from keras_hub.src.models.stablelm.stablelm_backbone import StableLMBackbone +from keras_hub.src.tests.test_case import TestCase + + +class StableLMBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "vocabulary_size": 10, + "num_layers": 2, + "num_query_heads": 2, + "hidden_dim": 2, + "intermediate_dim": 4, + "num_key_value_heads": 2 + } + self.input_data = { + "token_ids": ops.ones((2, 5), dtype="int32"), + "padding_mask": ops.ones((2, 5), dtype="int32"), + } + + def test_backbone_basics(self): + """Test that the backbone processes input correctly and outputs the expected shape.""" + self.run_backbone_test( + cls=StableLMBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 5, 2), + ) + + @pytest.mark.large + def test_saved_model(self): + """Test that the model can be saved and loaded successfully.""" + self.run_model_saving_test( + cls=StableLMBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + \ No newline at end of file From a28143e31d3465f1707fe98da3d39fa705ccb952 Mon Sep 17 00:00:00 2001 From: Dhiraj BM Date: Tue, 18 Mar 2025 22:51:17 +0530 Subject: [PATCH 04/10] stablelm_tokenizer and stablelm_tokenizer_test --- .../src/models/stablelm/stablelm_tokenizer.py | 52 +++++++++++++++++++ .../stablelm/stablelm_tokenizer_test.py | 35 +++++++++++++ 2 files changed, 87 insertions(+) create mode 100644 keras_hub/src/models/stablelm/stablelm_tokenizer.py create mode 100644 keras_hub/src/models/stablelm/stablelm_tokenizer_test.py diff --git a/keras_hub/src/models/stablelm/stablelm_tokenizer.py b/keras_hub/src/models/stablelm/stablelm_tokenizer.py new file mode 100644 index 0000000000..b9bea16f92 --- /dev/null +++ b/keras_hub/src/models/stablelm/stablelm_tokenizer.py @@ -0,0 +1,52 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.stablelm.stablelm_backbone import StableLMBackbone +from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer + + +@keras_hub_export( + [ + "keras_hub.tokenizers.StableLMTokenizer", + "keras_hub.models.StableLMTokenizer", + ] +) +class StableLMTokenizer(BytePairTokenizer): + """A StableLM tokenizer using Byte-Pair Encoding subword segmentation. + + This tokenizer class tokenizes raw strings into integer sequences and is + based on `keras_hub.tokenizers.BytePairTokenizer`. 
It mirrors the GPT-NeoX + tokenizer, as specified in the StableLM official documentation, and checks + for all special tokens required by StableLM models. It provides a + `from_preset()` method to automatically download a matching vocabulary for + a StableLM preset. + + If input is a batch of strings (rank > 0), the layer outputs a + `tf.RaggedTensor` where the last dimension is ragged. If input is a scalar + string (rank == 0), the layer outputs a dense `tf.Tensor` with static + shape `[None]`. + + Args: + vocabulary: string or dict, maps tokens to integer IDs. If a string, it + should be the file path to a JSON file containing the vocabulary. + merges: string or list, contains the merge rules. If a string, it should + be the file path to a file with merge rules, where each line contains + merge entities separated by a space. + """ + + backbone_cls = StableLMBackbone + + def __init__( + self, + vocabulary=None, + merges=None, + **kwargs, + ): + # StableLM uses the GPT-NeoX tokenizer, which has "<|endoftext|>" as both + # start and end token. + self._add_special_token("<|endoftext|>", "end_token") + self._add_special_token("<|endoftext|>", "start_token") + self.pad_token_id = 0 + super().__init__( + vocabulary=vocabulary, + merges=merges, + **kwargs, + ) \ No newline at end of file diff --git a/keras_hub/src/models/stablelm/stablelm_tokenizer_test.py b/keras_hub/src/models/stablelm/stablelm_tokenizer_test.py new file mode 100644 index 0000000000..343fbb3f5a --- /dev/null +++ b/keras_hub/src/models/stablelm/stablelm_tokenizer_test.py @@ -0,0 +1,35 @@ +import keras +from keras_hub.src.models.stablelm.stablelm_tokenizer import StableLMTokenizer +from keras_hub.src.tests.test_case import TestCase + +class StableLMTokenizerTest(TestCase): + def setUp(self): + + self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port", "<|endoftext|>"] + self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) + self.merges = [ + "Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e", + "Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt", + "Ġai r", "Ġa i", "pla ne" + ] + self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges} + self.input_data = [ + " airplane at airport<|endoftext|>", + " airplane airport", + ] + + def test_tokenizer_basics(self): + expected_output = [[2, 3, 4, 2, 5, 6],[2, 3, 2, 5]] + + # Run the preprocessing layer test to verify tokenization + self.run_preprocessing_layer_test( + cls=StableLMTokenizer, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=expected_output, + ) + + def test_errors_missing_special_tokens(self): + # Test that an error is raised if "<|endoftext|>" is missing from the vocabulary + with self.assertRaises(ValueError): + StableLMTokenizer(vocabulary=["a", "b", "c"], merges=[]) \ No newline at end of file From 308371988235c49ff79eb3fffaadb57ba5f0817c Mon Sep 17 00:00:00 2001 From: Dhiraj BM Date: Tue, 18 Mar 2025 22:53:02 +0530 Subject: [PATCH 05/10] stablelm_causal_lm_preprocessor --- .../stablelm_causal_lm_preprocessor.py | 65 +++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 keras_hub/src/models/stablelm/stablelm_causal_lm_preprocessor.py diff --git a/keras_hub/src/models/stablelm/stablelm_causal_lm_preprocessor.py b/keras_hub/src/models/stablelm/stablelm_causal_lm_preprocessor.py new file mode 100644 index 0000000000..35cfb57b44 --- /dev/null +++ b/keras_hub/src/models/stablelm/stablelm_causal_lm_preprocessor.py @@ -0,0 +1,65 @@ + +from keras_hub.src.api_export import keras_hub_export +from 
keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor +from keras_hub.src.models.stablelm.stablelm_backbone import StableLMBackbone +from keras_hub.src.models.stablelm.stablelm_tokenizer import StableLMTokenizer + + +@keras_hub_export("keras_hub.models.StableLMCausalLMPreprocessor") +class StableLMCausalLMPreprocessor(CausalLMPreprocessor): + """StableLM Causal LM preprocessor. + + This preprocessing layer is meant for use with + `keras_hub.models.StableLMCausalLM`. By default, it will take in batches of + strings and return outputs in a `(x, y, sample_weight)` format, where the + `y` label is the next token ID in the `x` sequence. + + For use with generation, the layer also exposes two methods + `generate_preprocess()` and `generate_postprocess()`. When this preprocessor + is attached to a `keras_hub.models.StableLMCausalLM` instance, these methods + will be called implicitly in `generate()`. They can also be called + standalone (e.g., to precompute preprocessing inputs for generation in a + separate process). + + Args: + tokenizer: A `keras_hub.models.StableLMTokenizer` instance. + sequence_length: The length of the packed inputs. + add_start_token: If `True`, the preprocessor will prepend the tokenizer + start token to each input sequence. Default is `True`. + add_end_token: If `True`, the preprocessor will append the tokenizer + end token to each input sequence. Default is `False`. + + Call arguments: + x: A string, `tf.Tensor`, or list of Python strings. + y: Label data. Should always be `None` as the layer generates labels. + sample_weight: Label weights. Should always be `None` as the layer + generates label weights. + sequence_length: Pass to override the configured `sequence_length` of + the layer. + + Examples: + ```python + # Load the preprocessor from a preset + preprocessor = keras_hub.models.StableLMCausalLMPreprocessor.from_preset( + "stablelm_3b_en" + ) + + # Tokenize and preprocess a single sentence + sentence = tf.constant("Hello, world!") + preprocessor(sentence) + # Same output + preprocessor("Hello, world!") + + # Tokenize a batch of sentences + sentences = tf.constant(["Hello, world!", "StableLM is here."]) + preprocessor(sentences) + # Same output + preprocessor(["Hello, world!", "StableLM is here."]) + + # Map a dataset to preprocess sentences + features = tf.constant(["Text 1", "Text 2"]) + ds = tf.data.Dataset.from_tensor_slices(features) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + """ + backbone_cls = StableLMBackbone + tokenizer_cls = StableLMTokenizer \ No newline at end of file From 61728507206c9599ea0b4babd21fd3a699835e14 Mon Sep 17 00:00:00 2001 From: Dhiraj BM Date: Tue, 18 Mar 2025 22:54:24 +0530 Subject: [PATCH 06/10] stablelm_causal_lm_preprocessor_test --- .../stablelm_causal_lm_preprocessor_test.py | 67 +++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 keras_hub/src/models/stablelm/stablelm_causal_lm_preprocessor_test.py diff --git a/keras_hub/src/models/stablelm/stablelm_causal_lm_preprocessor_test.py b/keras_hub/src/models/stablelm/stablelm_causal_lm_preprocessor_test.py new file mode 100644 index 0000000000..60b40f3ae1 --- /dev/null +++ b/keras_hub/src/models/stablelm/stablelm_causal_lm_preprocessor_test.py @@ -0,0 +1,67 @@ +import os +import pytest +from keras_hub.src.models.stablelm.stablelm_causal_lm_preprocessor import StableLMCausalLMPreprocessor +from keras_hub.src.models.stablelm.stablelm_tokenizer import StableLMTokenizer +from keras_hub.src.tests.test_case import TestCase + 
+class StableLMCausalLMPreprocessorTest(TestCase): + def setUp(self): + self.vocab = ["[PAD]", "<|endoftext|>", "!", "air", "plane", "at", "port"] + self.merges = ["a i", "p l", "n e", "pl a", "po rt"] + self.tokenizer = StableLMTokenizer(vocabulary=self.vocab, merges=self.merges) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 8, + } + self.input_data = ["air plane at port"] + + def test_preprocessor_basics(self): + preprocessor = StableLMCausalLMPreprocessor(**self.init_kwargs) + x, y, sw = preprocessor(self.input_data) + # Expected tokenization: "<|endoftext|> air plane at port" -> [1, 3, 4, 5, 6] + # Padded to sequence_length=8: [1, 3, 4, 5, 6, 0, 0, 0] + self.assertAllEqual(x["token_ids"], [[1, 3, 4, 5, 6, 0, 0, 0]]) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 0, 0, 0]]) + # Labels are shifted: [3, 4, 5, 6, 0, 0, 0, 0] + self.assertAllEqual(y, [[3, 4, 5, 6, 0, 0, 0, 0]]) + # Sample weights are 1 where labels are non-padding + self.assertAllEqual(sw, [[1, 1, 1, 1, 0, 0, 0, 0]]) + + def test_no_start_end_token(self): + # Test without start and end tokens, with batch size of 4 + input_data = ["air plane at port"] * 4 + preprocessor = StableLMCausalLMPreprocessor( + tokenizer=self.tokenizer, + sequence_length=8, + add_start_token=False, + add_end_token=False, + ) + x, y, sw = preprocessor(input_data) + # Tokenization: "air plane at port" -> [3, 4, 5, 6] + # Padded: [3, 4, 5, 6, 0, 0, 0, 0] + self.assertAllEqual(x["token_ids"], [[3, 4, 5, 6, 0, 0, 0, 0]] * 4) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * 4) + # Labels: [4, 5, 6, 0, 0, 0, 0, 0] + self.assertAllEqual(y, [[4, 5, 6, 0, 0, 0, 0, 0]] * 4) + self.assertAllEqual(sw, [[1, 1, 1, 0, 0, 0, 0, 0]] * 4) + + def test_generate_preprocess(self): + # Test preprocessing for generation + preprocessor = StableLMCausalLMPreprocessor(**self.init_kwargs) + x = preprocessor.generate_preprocess("air plane at port") + # Expected: [1, 3, 4, 5, 6, 0, 0, 0] + self.assertAllEqual(x["token_ids"], [1, 3, 4, 5, 6, 0, 0, 0]) + self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0]) + + def test_generate_postprocess(self): + # Test postprocessing for generation + preprocessor = StableLMCausalLMPreprocessor(**self.init_kwargs) + input_data = { + "token_ids": [1, 3, 4, 5, 6, 0, 0, 0], + "padding_mask": [1, 1, 1, 1, 1, 0, 0, 0], + } + x = preprocessor.generate_postprocess(input_data) + # Expect detokenized string, may include minor formatting differences due to BPE + self.assertEqual(x, "air plane at port") + + \ No newline at end of file From cb6b83066697d1f5d546fe81bf119737d9ef22b8 Mon Sep 17 00:00:00 2001 From: Dhiraj BM Date: Tue, 18 Mar 2025 23:16:56 +0530 Subject: [PATCH 07/10] stablelm_causal_lm --- .../src/models/stablelm/stablelm_causal_lm.py | 240 ++++++++++++++++++ 1 file changed, 240 insertions(+) create mode 100644 keras_hub/src/models/stablelm/stablelm_causal_lm.py diff --git a/keras_hub/src/models/stablelm/stablelm_causal_lm.py b/keras_hub/src/models/stablelm/stablelm_causal_lm.py new file mode 100644 index 0000000000..078c96754e --- /dev/null +++ b/keras_hub/src/models/stablelm/stablelm_causal_lm.py @@ -0,0 +1,240 @@ +import keras +from keras import ops +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.causal_lm import CausalLM +from keras_hub.src.models.stablelm.stablelm_backbone import StableLMBackbone +from keras_hub.src.models.stablelm.stablelm_causal_lm_preprocessor import ( + StableLMCausalLMPreprocessor, +) +from 
keras_hub.src.utils.tensor_utils import any_equal + + +@keras_hub_export("keras_hub.models.StableLMCausalLM") +class StableLMCausalLM(CausalLM): + """An end-to-end StableLM model for causal language modeling. + + A causal language model (LM) predicts the next token based on previous + tokens. This task setup can be used to train the model unsupervised on + plain text input, or to autoregressively generate plain text similar to + the data used for training. This task can be used for pre-training or + fine-tuning a StableLM model, simply by calling `fit()`. + + This model has a `generate()` method, which generates text based on a + prompt. The generation strategy used is controlled by an additional + `sampler` argument on `compile()`. You can recompile the model with + different `keras_hub.samplers` objects to control the generation. By + default, `"top_k"` sampling will be used. + + Args: + backbone: A `keras_hub.models.StableLMBackbone` instance. + preprocessor: A `keras_hub.models.StableLMCausalLMPreprocessor` or `None`. + If `None`, this model will not apply preprocessing, and inputs + should be preprocessed before calling the model. + """ + + backbone_cls = StableLMBackbone + preprocessor_cls = StableLMCausalLMPreprocessor + + def __init__(self, backbone, preprocessor=None, **kwargs): + """Initialize the StableLMCausalLM model.""" + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + self.lm_head = keras.layers.Dense( + self.backbone.vocabulary_size, + use_bias=False, + kernel_initializer=keras.initializers.RandomNormal(stddev=0.02), + name="lm_head", + ) + + # === Functional Model === + inputs = backbone.input + hidden_states = backbone(inputs) + outputs = backbone.token_embedding(hidden_states, reverse=True) + super().__init__(inputs=inputs, outputs=outputs, **kwargs) + + def call_with_cache(self, token_ids, cache, cache_update_index): + """Forward pass with caching for autoregressive inference. + + This method enables efficient generation by caching previous key/value + tensors in the attention layers, avoiding recomputation of seen tokens. + + Args: + token_ids: A dense int Tensor with shape `(batch_size, max_length)`. + cache: A dense float Tensor representing the cache of key and value. + cache_update_index: int or int Tensor, the index of current inputs + in the sequence. + + Returns: + A tuple (logits, hidden_states, cache) where: + - `logits`: Language model logits for the input token_ids. + - `hidden_states`: Final hidden representation of the input tokens. + - `cache`: Updated decoding cache. + """ + x = self.backbone.token_embedding(token_ids) + updated_cache = [] + for i, layer in enumerate(self.backbone.transformer_layers): + current_cache = cache[:, i, ...] + x, next_cache = layer( + x, + self_attention_cache=current_cache, + self_attention_cache_update_index=cache_update_index, + ) + updated_cache.append(next_cache) + cache = ops.stack(updated_cache, axis=1) + hidden_states = self.backbone.layer_norm(x) + logits = self.lm_head(hidden_states) + return logits, hidden_states, cache + + def _build_cache(self, token_ids): + """Build and seed an empty cache for use with `call_with_cache()`. + + Args: + token_ids: A dense int Tensor with shape `(batch_size, max_length)`. + + Returns: + A tuple (hidden_states, cache) with the initial hidden states and + seeded cache. 
+ """ + batch_size = ops.shape(token_ids)[0] + max_length = ops.shape(token_ids)[1] + num_layers = self.backbone.num_layers + num_heads = self.backbone.num_heads + head_dim = self.backbone.hidden_dim // self.backbone.num_heads + shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim] + cache = ops.zeros(shape, dtype=self.compute_dtype) + # Seed the cache with an initial forward pass + _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) + return hidden_states, cache + + def generate_step(self, inputs, stop_token_ids=None): + """A compilable generation function for a single batch of inputs. + + Args: + inputs: A dictionary with keys `"token_ids"` and `"padding_mask"`. + stop_token_ids: Tuple of token IDs to stop generation on. If all + sequences produce a stop token, generation halts. + + Returns: + A dictionary with updated `"token_ids"` and `"padding_mask"`. + """ + token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] + hidden_states, cache = self._build_cache(token_ids) + row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) + index = ops.min(row_lengths) + + def next(prompt, cache, index): + cache_update_index = index - 1 + batch_size = ops.shape(prompt)[0] + prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) + logits, hidden_states, cache = self.call_with_cache( + prompt, cache, cache_update_index + ) + return ( + ops.squeeze(logits, axis=1), + ops.squeeze(hidden_states, axis=1), + cache, + ) + + token_ids = self.sampler( + next=next, + prompt=token_ids, + cache=cache, + index=index, + mask=padding_mask, + stop_token_ids=stop_token_ids, + hidden_states=hidden_states, + model=self, + ) + + if stop_token_ids is not None: + end_locations = any_equal( + token_ids, stop_token_ids, ops.logical_not(padding_mask) + ) + end_locations = ops.cast(end_locations, "int32") + cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") + overflow = cumsum - end_locations + padding_mask = ops.logical_not(ops.cast(overflow, "bool")) + else: + padding_mask = ops.ones_like(token_ids, dtype="bool") + + return {"token_ids": token_ids, "padding_mask": padding_mask} + + def score( + self, + token_ids, + padding_mask=None, + scoring_mode="logits", + layer_intercept_fn=None, + target_ids=None, + ): + """Score a generation represented by the provided token ids. + + This method computes scores for a sequence of token IDs, returning either + logits or per-token loss, depending on the `scoring_mode`. It’s useful for + evaluating model performance or conducting interpretability research. + + Args: + token_ids: A [batch_size, num_tokens] tensor containing tokens + to score, typically from a `generate()` call. + padding_mask: A [batch_size, num_tokens] tensor indicating + valid tokens. Defaults to all ones if not provided. + scoring_mode: Either "logits" or "loss", specifying the type of + scores to return. + layer_intercept_fn: Optional function to modify activations at each + layer, taking activations and layer index as inputs. Useful for + custom computations (e.g., interpretability). Must return a + [batch_size, num_tokens, hidden_dims] tensor. + target_ids: A [batch_size, num_tokens] tensor of true token IDs, + required for "loss" mode to compute the loss. + + Raises: + ValueError: If `scoring_mode` is invalid or `target_ids` is missing + in "loss" mode. + + Returns: + - In "logits" mode: [batch_size, num_tokens, vocab_size] tensor + of logits. + - In "loss" mode: [batch_size, num_tokens] tensor of per-token + loss. 
+ """ + if scoring_mode not in ("logits", "loss"): + raise ValueError( + "Unsupported scoring_mode. Must be one of 'logits' or 'loss'." + ) + + if scoring_mode == "loss" and target_ids is None: + raise ValueError( + "Cannot compute loss without targets. Please provide target " + "token ids via the target_ids parameter." + ) + + batch_shape = ops.shape(token_ids)[:2] + if padding_mask is None: + padding_mask = ops.ones(shape=batch_shape, dtype="bool") + + if layer_intercept_fn is None: + def default_layer_intercept_fn(x, unused_i): + return x + layer_intercept_fn = default_layer_intercept_fn + + # Forward pass through the model + x = self.backbone.token_embedding(token_ids) + x = layer_intercept_fn(x, -1) # Apply to embeddings (index -1) + + for i, transformer_layer in enumerate(self.backbone.transformer_layers): + x = transformer_layer(x, decoder_padding_mask=padding_mask) + x = layer_intercept_fn(x, i) # Apply to each transformer layer + + x = self.backbone.layer_norm(x) + logits = self.lm_head(x) + + if scoring_mode == "logits": + return logits + + # Compute per-token loss if scoring_mode is "loss" + per_token_loss_fn = keras.losses.SparseCategoricalCrossentropy( + from_logits=True, reduction="none" + ) + per_token_loss = per_token_loss_fn(target_ids, logits) + return per_token_loss \ No newline at end of file From 5867ad06597fba3b0e6cf9566ab2da64b7fe0785 Mon Sep 17 00:00:00 2001 From: Dhiraj BM Date: Tue, 18 Mar 2025 23:31:31 +0530 Subject: [PATCH 08/10] stablelm_causal_lm_test --- .../stablelm/stablelm_causal_lm_test.py | 175 ++++++++++++++++++ 1 file changed, 175 insertions(+) create mode 100644 keras_hub/src/models/stablelm/stablelm_causal_lm_test.py diff --git a/keras_hub/src/models/stablelm/stablelm_causal_lm_test.py b/keras_hub/src/models/stablelm/stablelm_causal_lm_test.py new file mode 100644 index 0000000000..e41db9f1d6 --- /dev/null +++ b/keras_hub/src/models/stablelm/stablelm_causal_lm_test.py @@ -0,0 +1,175 @@ +import os +from unittest.mock import patch + +import pytest +from keras import ops + +from keras_hub.src.models.stablelm.stablelm_backbone import StableLMBackbone +from keras_hub.src.models.stablelm.stablelm_causal_lm import StableLMCausalLM +from keras_hub.src.models.stablelm.stablelm_causal_lm_preprocessor import ( + StableLMCausalLMPreprocessor, +) +from keras_hub.src.models.stablelm.stablelm_tokenizer import StableLMTokenizer +from keras_hub.src.tests.test_case import TestCase + + +class StableLMCausalLMTest(TestCase): + def setUp(self): + self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port", "<|endoftext|>"] + self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) + self.merges = [ + "Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e", + "Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt", + "Ġai r", "Ġa i", "pla ne" + ] + + self.preprocessor = StableLMCausalLMPreprocessor( + tokenizer=StableLMTokenizer(vocabulary=self.vocab, merges=self.merges), + sequence_length=8, + ) + + # Config + self.backbone = StableLMBackbone( + vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(), + num_layers=2, + num_query_heads=2, + num_key_value_heads=2, + hidden_dim=4, + intermediate_dim=8, + ) + + # Initialization kwargs for the causal LM. + self.init_kwargs = { + "preprocessor": self.preprocessor, + "backbone": self.backbone, + } + + # Training data for testing. 
+ self.train_data = ([" airplane at airport", " airplane at airport"],) + self.input_data = self.preprocessor(*self.train_data)[0] + + def test_causal_lm_basics(self): + self.run_task_test( + cls=StableLMCausalLM, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 8, 7), + ) + + def test_generate(self): + causal_lm = StableLMCausalLM(**self.init_kwargs) + # Test string input. + prompt = " airplane at airport" + output = causal_lm.generate(prompt) + self.assertTrue(prompt in output) + # Test integer tensor input. + prompt_ids = self.preprocessor.generate_preprocess([prompt]) + causal_lm.preprocessor = None + outputs = causal_lm.generate(prompt_ids, stop_token_ids=None) + # Assert prompt is preserved in output token IDs. + self.assertAllEqual( + outputs["token_ids"][:, :5], + prompt_ids["token_ids"][:, :5], + ) + self.assertAllEqual( + outputs["padding_mask"][:, :5], + prompt_ids["padding_mask"][:, :5], + ) + + def test_early_stopping(self): + causal_lm = StableLMCausalLM(**self.init_kwargs) + call_with_cache = causal_lm.call_with_cache + + def wrapper(*args, **kwargs): + """Modify logits to favor end_token_id for early stopping.""" + logits, hidden_states, cache = call_with_cache(*args, **kwargs) + index = self.preprocessor.tokenizer.end_token_id + update = ops.ones_like(logits)[:, :, index] * 1.0e9 + update = ops.expand_dims(update, axis=-1) + logits = ops.slice_update(logits, (0, 0, index), update) + return logits, hidden_states, cache + + with patch.object(causal_lm, "call_with_cache", wraps=wrapper): + prompt = [" airplane at airport", " airplane"] + output = causal_lm.generate(prompt) + # We should immediately abort and output the prompt. + self.assertEqual(prompt, output) + + def test_generate_compilation(self): + causal_lm = StableLMCausalLM(**self.init_kwargs) + # Assert we do not recompile with successive calls. + causal_lm.generate(" airplane at airport") + first_fn = causal_lm.generate_function + causal_lm.generate(" airplane at airport") + second_fn = causal_lm.generate_function + self.assertEqual(first_fn, second_fn) + # Assert we do recompile after compile is called. 
+ causal_lm.compile(sampler="greedy") + self.assertIsNone(causal_lm.generate_function) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=StableLMCausalLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + def test_score_logits(self): + prompts = [" airplane at airport", " airplane at airport"] + causal_lm = StableLMCausalLM(**self.init_kwargs) + expected_score_shape = (2, 8, 7) + preprocessed_prompts = causal_lm.preprocessor.generate_preprocess(prompts) + token_ids = preprocessed_prompts["token_ids"] + padding_mask = preprocessed_prompts["padding_mask"] + scores = causal_lm.score( + token_ids=token_ids, + padding_mask=padding_mask, + scoring_mode="logits", + ) + self.assertEqual(ops.shape(scores), expected_score_shape) + + def test_score_loss(self): + prompts = [" airplane at airport", " airplane at airport"] + causal_lm = StableLMCausalLM(**self.init_kwargs) + expected_score_shape = (2, 7) + preprocessed_prompts = causal_lm.preprocessor.generate_preprocess(prompts) + token_ids = preprocessed_prompts["token_ids"] + padding_mask = preprocessed_prompts["padding_mask"] + batch_size = ops.shape(token_ids)[0] + target_ids = ops.slice(token_ids, [0, 1], [batch_size, -1]) + + scores = causal_lm.score( + token_ids=token_ids, + padding_mask=padding_mask, + scoring_mode="loss", + target_ids=target_ids, + ) + self.assertEqual(ops.shape(scores), expected_score_shape) + + def test_score_layer_intercept_fn_exfiltration(self): + prompts = [" airplane at airport", " airplane at airport"] + causal_lm = StableLMCausalLM(**self.init_kwargs) + expected_embedded_shape = (2, 8, 4) + expected_score_shape = (2, 8, 7) + preprocessed_prompts = causal_lm.preprocessor.generate_preprocess(prompts) + token_ids = preprocessed_prompts["token_ids"] + padding_mask = preprocessed_prompts["padding_mask"] + embedded_prompts = None + + def layer_intercept_fn_for_testing(x, i): + nonlocal embedded_prompts + if i == -1: + embedded_prompts = x + else: + self.assertEqual(ops.shape(x), expected_embedded_shape) + return x + + scores = causal_lm.score( + token_ids=token_ids, + padding_mask=padding_mask, + scoring_mode="logits", + layer_intercept_fn=layer_intercept_fn_for_testing, + ) + self.assertEqual(ops.shape(embedded_prompts), expected_embedded_shape) + self.assertEqual(ops.shape(scores), expected_score_shape) \ No newline at end of file From 6fd200b41d15a42bd45c97f3641a9c0bef45dfc8 Mon Sep 17 00:00:00 2001 From: Dhiraj BM Date: Tue, 18 Mar 2025 23:32:57 +0530 Subject: [PATCH 09/10] initialization --- keras_hub/src/models/stablelm/__init__.py | 1 + 1 file changed, 1 insertion(+) create mode 100644 keras_hub/src/models/stablelm/__init__.py diff --git a/keras_hub/src/models/stablelm/__init__.py b/keras_hub/src/models/stablelm/__init__.py new file mode 100644 index 0000000000..6b9d5658c0 --- /dev/null +++ b/keras_hub/src/models/stablelm/__init__.py @@ -0,0 +1 @@ +from keras_hub.src.models.stablelm.stablelm_backbone import StableLMBackbone \ No newline at end of file From 5ce12a0823eb7afe80d4cbe928bfc57945fa25fb Mon Sep 17 00:00:00 2001 From: Dhiraj BM Date: Mon, 28 Apr 2025 20:39:30 +0530 Subject: [PATCH 10/10] Corrected test cases and added conversion checkpoints --- .../src/models/stablelm/stablelm_attention.py | 71 +++-- .../src/models/stablelm/stablelm_backbone.py | 98 ++++--- .../models/stablelm/stablelm_backbone_test.py | 3 +- .../src/models/stablelm/stablelm_causal_lm.py | 76 +++-- .../stablelm_causal_lm_preprocessor_test.py | 86 +++--- 
.../stablelm/stablelm_causal_lm_test.py | 68 +---- .../src/models/stablelm/stablelm_decoder.py | 66 +++-- .../src/models/stablelm/stablelm_presets.py | 10 + .../src/models/stablelm/stablelm_tokenizer.py | 8 +- .../stablelm/stablelm_tokenizer_test.py | 9 +- .../convert_stablelm_checkpoints.py | 262 ++++++++++++++++++ 11 files changed, 510 insertions(+), 247 deletions(-) create mode 100644 keras_hub/src/models/stablelm/stablelm_presets.py create mode 100644 tools/checkpoint_conversion/convert_stablelm_checkpoints.py diff --git a/keras_hub/src/models/stablelm/stablelm_attention.py b/keras_hub/src/models/stablelm/stablelm_attention.py index 1c9bef344b..fe3fc038b6 100644 --- a/keras_hub/src/models/stablelm/stablelm_attention.py +++ b/keras_hub/src/models/stablelm/stablelm_attention.py @@ -1,4 +1,5 @@ import math + import keras from keras import ops @@ -14,18 +15,23 @@ class StableLMAttention(keras.layers.Layer): multi-head self-attention with partial rotary position embeddings applied to a fraction of the head dimensions, as specified by `rotary_percentage`. It is adapted from the LlamaAttention layer with modifications to align - with StableLM's official configuration(https://github.com/Stability-AI/StableLM/blob/main/configs/stablelm-3b-4e1t.yml). + with StableLM's official configuration. Args: - num_query_heads (int): Number of attention heads for queries. - num_key_value_heads (int): Number of attention heads for keys and values. - hidden_dim (int): Hidden dimension of the input (e.g., 2560 for StableLM-3B4E1T). - rope_max_wavelength (float): Maximum wavelength for rotary embeddings (default: 10000). - rope_scaling_factor (float): Scaling factor for rotary embeddings (default: 1.0). - rotary_percentage (float): Percentage of head dimensions to apply rotary embeddings (default: 0.25). - kernel_initializer (str or initializer): Initializer for dense layer kernels (default: "glorot_uniform"). - dropout (float): Dropout rate for attention scores (default: 0.0). - **kwargs: Additional keyword arguments passed to the parent class. + num_query_heads: int. Number of attention heads for queries. + num_key_value_heads: int. Number of attention heads for keys and + values. + hidden_dim: int. Hidden dimension of the input (e.g., 2560 for + StableLM-3B4E1T). + rope_max_wavelength: float. Maximum wavelength for rotary embeddings + (default: 10000). + rope_scaling_factor: float. Scaling factor for rotary embeddings + (default: 1.0). + rotary_percentage: float. Percentage of head dimensions to apply + rotary embeddings (default: 0.25). + kernel_initializer: str or initializer. Initializer for dense layer + kernels (default: "glorot_uniform"). + dropout: float. Dropout rate for attention scores (default: 0.0). 
""" def __init__( @@ -120,7 +126,7 @@ def build(self, inputs_shape): self._dot_product_equation = "bquh,bkuh->buqk" self._combine_equation = "buqk,bkuh->bquh" - self.built = True + super().build(inputs_shape) def call( self, @@ -130,22 +136,25 @@ def call( cache_update_index=None, training=None, ): - start_index = cache_update_index if cache_update_index is not None else 0 - - # Compute query and apply partial rotary embedding + start_index = ( + cache_update_index if cache_update_index is not None else 0 + ) query = self._query_dense(hidden_states) query_rot = query[..., :self.rotary_dim] query_pass = query[..., self.rotary_dim:] - query_rot = self.rotary_embedding_layer(query_rot, start_index=start_index) + query_rot = self.rotary_embedding_layer( + query_rot, start_index=start_index + ) query = ops.concatenate([query_rot, query_pass], axis=-1) def _compute_key_value(x): key = self._key_dense(x) value = self._value_dense(x) - # Apply partial rotary embedding to key key_rot = key[..., :self.rotary_dim] key_pass = key[..., self.rotary_dim:] - key_rot = self.rotary_embedding_layer(key_rot, start_index=start_index) + key_rot = self.rotary_embedding_layer( + key_rot, start_index=start_index + ) key = ops.concatenate([key_rot, key_pass], axis=-1) return key, value @@ -171,20 +180,27 @@ def _compute_key_value(x): ) key, value = _compute_key_value(hidden_states) - # Adjust key and value for grouped-query attention (if applicable) key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2) value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2) - # Compute attention output - attention_output = self._compute_attention(query, key, value, attention_mask) - attention_output = self._dropout_layer(attention_output, training=training) + attention_output = self._compute_attention( + query, key, value, attention_mask + ) + attention_output = self._dropout_layer( + attention_output, training=training + ) attention_output = self._output_dense(attention_output) - return attention_output, cache if cache is not None else attention_output + return ( + attention_output, + cache if cache is not None else attention_output, + ) def _masked_softmax(self, attention_scores, attention_mask=None): if attention_mask is not None: - return self._softmax(attention_scores, attention_mask[:, None, :, :]) + return self._softmax( + attention_scores, attention_mask[:, None, :, :] + ) return self._softmax(attention_scores) def _compute_attention(self, query, key, value, attention_mask=None): @@ -199,15 +215,18 @@ def _compute_attention(self, query, key, value, attention_mask=None): mask=attention_mask, scale=self._inv_norm_factor, ) - attention_scores = ops.einsum(self._dot_product_equation, query, key) attention_scores = ops.multiply( attention_scores, ops.cast(self._inv_norm_factor, self.compute_dtype), ) - attention_scores = self._masked_softmax(attention_scores, attention_mask) + attention_scores = self._masked_softmax( + attention_scores, attention_mask + ) attention_scores = ops.cast(attention_scores, self.compute_dtype) - attention_output = ops.einsum(self._combine_equation, attention_scores, value) + attention_output = ops.einsum( + self._combine_equation, attention_scores, value + ) return attention_output def get_config(self): diff --git a/keras_hub/src/models/stablelm/stablelm_backbone.py b/keras_hub/src/models/stablelm/stablelm_backbone.py index 70fab17a89..2d4e60dc4f 100644 --- a/keras_hub/src/models/stablelm/stablelm_backbone.py +++ b/keras_hub/src/models/stablelm/stablelm_backbone.py @@ -1,11 
+1,15 @@ import keras -from keras import ops from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.layers.modeling.reversible_embedding import ReversibleEmbedding +from keras_hub.src.layers.modeling.reversible_embedding import ( + ReversibleEmbedding, +) from keras_hub.src.models.backbone import Backbone -from keras_hub.src.models.stablelm.stablelm_decoder import StableLMTransformerDecoder +from keras_hub.src.models.stablelm.stablelm_decoder import ( + StableLMTransformerDecoder, +) + def _stablelm_kernel_initializer(stddev=0.02): """Initializer for StableLM kernel weights.""" @@ -13,38 +17,54 @@ def _stablelm_kernel_initializer(stddev=0.02): @keras_hub_export("keras_hub.models.StableLMBackbone") class StableLMBackbone(Backbone): - """ - The StableLM Transformer core architecture with hyperparameters. + """The StableLM Transformer core architecture with hyperparameters. - This network implements a Transformer-based decoder network for StableLM-3B4E1T, - as described in the official documentation. It is a decoder-only transformer similar - to LLaMA with modifications including partial rotary position embeddings and - LayerNorm with learned bias terms. It includes the embedding lookups and transformer - layers. + This network implements a Transformer-based decoder network for + StableLM-3B4E1T, as described in the official documentation. It is a + decoder-only transformer similar to LLaMA with modifications including + partial rotary position embeddings and LayerNorm with learned bias terms. + It includes the embedding lookups and transformer layers. The default constructor provides a fully customizable, randomly initialized StableLM model with any number of layers, heads, and embedding dimensions. Args: - vocabulary_size (int): The size of the token vocabulary. - num_layers (int): The number of transformer layers (32 for StableLM-3B4E1T). - num_query_heads (int): The number of query attention heads (32 for StableLM-3B4E1T). - hidden_dim (int): The hidden size (2560 for StableLM-3B4E1T). - intermediate_dim (int): The output dimension of the first Dense layer in the - feedforward network. - num_key_value_heads (int): The number of key/value attention heads (32 for + vocabulary_size: int. The size of the token vocabulary. + num_layers: int. The number of transformer layers (32 for + StableLM-3B4E1T). + num_query_heads: int. The number of query attention heads (32 for StableLM-3B4E1T). - rope_max_wavelength (int, optional): The maximum wavelength for RoPE. Defaults + hidden_dim: int. The hidden size (2560 for StableLM-3B4E1T). + intermediate_dim: int. The output dimension of the first Dense layer + in the feedforward network. + num_key_value_heads: int. The number of key/value attention heads + (32 for StableLM-3B4E1T). + rope_max_wavelength: int. The maximum wavelength for RoPE. Defaults to 10000. - rope_scaling_factor (float, optional): The scaling factor for RoPE. Defaults to 1.0. - layer_norm_epsilon (float, optional): Epsilon for LayerNorm. Defaults to 1e-5. - dropout (float, optional): Dropout rate. Defaults to 0.0. - tie_word_embeddings (bool, optional): Whether to tie input and output embeddings. - Defaults to False. + rope_scaling_factor: float. The scaling factor for RoPE. Defaults + to 1.0. + layer_norm_epsilon: float. Epsilon for LayerNorm. Defaults to 1e-5. + dropout: float. Dropout rate. Defaults to 0.0. + tie_word_embeddings: bool, optional. Whether to tie input and output + embeddings. Defaults to False. dtype: The dtype to use for computations and weights. 
- Example: + Examples: + ```python + # Load a pretrained StableLM backbone. + model = keras_hub.models.StableLMBackbone.from_preset("stablelm_3b_4e1t_en") + + # Example input data + input_data = { + "token_ids": np.ones(shape=(1, 12), dtype="int32"), + "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), + } + + # Forward pass + output = model(input_data) + print(output.shape) # Expected: (1, 12, 2560) + # Randomly initialized StableLM decoder with custom config model = StableLMBackbone( vocabulary_size=50257, @@ -70,6 +90,7 @@ class StableLMBackbone(Backbone): # Forward pass output = model(input_data) print(output.shape) # Expected: (1, 12, 2560) + ``` """ def __init__( self, @@ -119,14 +140,21 @@ def __init__( ) # === Functional Model === - token_id_input = keras.Input(shape=(None,), dtype="int32", name="token_ids") - padding_mask_input = keras.Input(shape=(None,), dtype="int32", name="padding_mask") + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) x = self.token_embedding(token_id_input) for transformer_layer in self.transformer_layers: x = transformer_layer(x, decoder_padding_mask=padding_mask_input) sequence_output = self.layer_norm(x) super().__init__( - inputs={"token_ids": token_id_input, "padding_mask": padding_mask_input}, + inputs={ + "token_ids": token_id_input, + "padding_mask": padding_mask_input + }, outputs=sequence_output, dtype=dtype, **kwargs, @@ -204,7 +232,7 @@ def get_layout_map( (for one decoder block): ``` embedding_layer = stablelm_model.backbone.get_layer("token_embedding") - decoder_block_1 = stablelm_model.backbone.get_layer('transformer_layer_0') + decoder_block_1 = stablelm_model.backbone.get_layer('transformer_layer_0 for variable in embedding_layer.weights + decoder_block_1.weights: print( f'{variable.path:<58} {str(variable.shape):<16} ' @@ -223,20 +251,6 @@ def get_layout_map( `keras.distribution.LayoutMap` that contains the sharding spec for all the model weights. 
""" - # The weight path and shape of the Llama backbone is like below - # token_embedding/embeddings (128256, 2048) - # repeat block for decoder - # transformer_layer_0/self_attention/query/kernel (2048, 32, 64) - # transformer_layer_0/self_attention/key/kernel (2048, 8, 64) - # transformer_layer_0/self_attention/value/kernel (2048, 8, 64) - # transformer_layer_0/self_attention/attention_output/kernel - # (32, 64, 2048) - # transformer_layer_0/self_attention_layernorm/scale (2048,) - # transformer_layer_0/feedforward_intermediate_dense/kernel - # (2048, 8192) - # transformer_layer_0/feedforward_gate_dense/kernel (2048, 8192) - # transformer_layer_0/feedforward_output_dense/kerne (8192, 2048) - # transformer_layer_0/feedforward_layernorm/scale (2048,) if not isinstance(device_mesh, keras.distribution.DeviceMesh): raise ValueError( diff --git a/keras_hub/src/models/stablelm/stablelm_backbone_test.py b/keras_hub/src/models/stablelm/stablelm_backbone_test.py index 1cb6851dac..80fbc325c4 100644 --- a/keras_hub/src/models/stablelm/stablelm_backbone_test.py +++ b/keras_hub/src/models/stablelm/stablelm_backbone_test.py @@ -1,4 +1,3 @@ -import keras import pytest from keras import ops @@ -22,7 +21,7 @@ def setUp(self): } def test_backbone_basics(self): - """Test that the backbone processes input correctly and outputs the expected shape.""" + """Test that the backbone processes with expected shape.""" self.run_backbone_test( cls=StableLMBackbone, init_kwargs=self.init_kwargs, diff --git a/keras_hub/src/models/stablelm/stablelm_causal_lm.py b/keras_hub/src/models/stablelm/stablelm_causal_lm.py index 078c96754e..6145577af0 100644 --- a/keras_hub/src/models/stablelm/stablelm_causal_lm.py +++ b/keras_hub/src/models/stablelm/stablelm_causal_lm.py @@ -1,5 +1,6 @@ import keras from keras import ops + from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.causal_lm import CausalLM from keras_hub.src.models.stablelm.stablelm_backbone import StableLMBackbone @@ -19,19 +20,14 @@ class StableLMCausalLM(CausalLM): the data used for training. This task can be used for pre-training or fine-tuning a StableLM model, simply by calling `fit()`. - This model has a `generate()` method, which generates text based on a - prompt. The generation strategy used is controlled by an additional - `sampler` argument on `compile()`. You can recompile the model with - different `keras_hub.samplers` objects to control the generation. By - default, `"top_k"` sampling will be used. - Args: backbone: A `keras_hub.models.StableLMBackbone` instance. - preprocessor: A `keras_hub.models.StableLMCausalLMPreprocessor` or `None`. - If `None`, this model will not apply preprocessing, and inputs - should be preprocessed before calling the model. + preprocessor: A `keras_hub.models.StableLMCausalLMPreprocessor` or + `None`. If `None`, this model will not apply preprocessing, and + inputs should be preprocessed before calling the model. """ + backbone_cls = StableLMBackbone preprocessor_cls = StableLMCausalLMPreprocessor @@ -83,27 +79,24 @@ def call_with_cache(self, token_ids, cache, cache_update_index): updated_cache.append(next_cache) cache = ops.stack(updated_cache, axis=1) hidden_states = self.backbone.layer_norm(x) - logits = self.lm_head(hidden_states) + logits = self.backbone.token_embedding(hidden_states, reverse=True) return logits, hidden_states, cache def _build_cache(self, token_ids): - """Build and seed an empty cache for use with `call_with_cache()`. 
- - Args: - token_ids: A dense int Tensor with shape `(batch_size, max_length)`. - - Returns: - A tuple (hidden_states, cache) with the initial hidden states and - seeded cache. - """ batch_size = ops.shape(token_ids)[0] max_length = ops.shape(token_ids)[1] num_layers = self.backbone.num_layers - num_heads = self.backbone.num_heads - head_dim = self.backbone.hidden_dim // self.backbone.num_heads - shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim] + num_key_value_heads = self.backbone.num_key_value_heads + head_dim = self.backbone.hidden_dim // self.backbone.num_query_heads + shape = [ + batch_size, + num_layers, + 2, + max_length, + num_key_value_heads, + head_dim, + ] cache = ops.zeros(shape, dtype=self.compute_dtype) - # Seed the cache with an initial forward pass _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) return hidden_states, cache @@ -170,33 +163,28 @@ def score( ): """Score a generation represented by the provided token ids. - This method computes scores for a sequence of token IDs, returning either - logits or per-token loss, depending on the `scoring_mode`. It’s useful for - evaluating model performance or conducting interpretability research. + This method computes scores for a sequence of token IDs, returning + either logits or per-token loss, depending on the `scoring_mode`. + It’s useful for evaluating model performance or conducting + interpretability research. Args: - token_ids: A [batch_size, num_tokens] tensor containing tokens - to score, typically from a `generate()` call. + token_ids: A [batch_size, num_tokens] tensor containing + tokens to score. padding_mask: A [batch_size, num_tokens] tensor indicating - valid tokens. Defaults to all ones if not provided. + valid tokens. scoring_mode: Either "logits" or "loss", specifying the type of scores to return. - layer_intercept_fn: Optional function to modify activations at each - layer, taking activations and layer index as inputs. Useful for - custom computations (e.g., interpretability). Must return a - [batch_size, num_tokens, hidden_dims] tensor. - target_ids: A [batch_size, num_tokens] tensor of true token IDs, - required for "loss" mode to compute the loss. - - Raises: - ValueError: If `scoring_mode` is invalid or `target_ids` is missing - in "loss" mode. + layer_intercept_fn: Optional function to modify activations at + each layer. + target_ids: A [batch_size, num_tokens] tensor of true token + IDs, required for "loss" mode. Returns: - - In "logits" mode: [batch_size, num_tokens, vocab_size] tensor - of logits. - - In "loss" mode: [batch_size, num_tokens] tensor of per-token - loss. + - In "logits" mode: [batch_size, num_tokens, vocab_size] + tensor of logits. + - In "loss" mode: [batch_size, num_tokens] tensor of + per-token loss. 
""" if scoring_mode not in ("logits", "loss"): raise ValueError( @@ -227,7 +215,7 @@ def default_layer_intercept_fn(x, unused_i): x = layer_intercept_fn(x, i) # Apply to each transformer layer x = self.backbone.layer_norm(x) - logits = self.lm_head(x) + logits = self.backbone.token_embedding(x, reverse=True) if scoring_mode == "logits": return logits diff --git a/keras_hub/src/models/stablelm/stablelm_causal_lm_preprocessor_test.py b/keras_hub/src/models/stablelm/stablelm_causal_lm_preprocessor_test.py index 60b40f3ae1..0254016346 100644 --- a/keras_hub/src/models/stablelm/stablelm_causal_lm_preprocessor_test.py +++ b/keras_hub/src/models/stablelm/stablelm_causal_lm_preprocessor_test.py @@ -1,67 +1,73 @@ -import os -import pytest -from keras_hub.src.models.stablelm.stablelm_causal_lm_preprocessor import StableLMCausalLMPreprocessor +from keras import ops + +from keras_hub.src.models.stablelm.stablelm_causal_lm_preprocessor import ( + StableLMCausalLMPreprocessor, +) from keras_hub.src.models.stablelm.stablelm_tokenizer import StableLMTokenizer from keras_hub.src.tests.test_case import TestCase + class StableLMCausalLMPreprocessorTest(TestCase): def setUp(self): - self.vocab = ["[PAD]", "<|endoftext|>", "!", "air", "plane", "at", "port"] - self.merges = ["a i", "p l", "n e", "pl a", "po rt"] - self.tokenizer = StableLMTokenizer(vocabulary=self.vocab, merges=self.merges) + self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"] + self.vocab += ["<|endoftext|>"] + self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) + self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"] + self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"] + self.merges += ["Ġai r", "Ġa i", "pla ne"] + self.tokenizer = StableLMTokenizer( + vocabulary=self.vocab, + merges=self.merges, + ) self.init_kwargs = { "tokenizer": self.tokenizer, "sequence_length": 8, } - self.input_data = ["air plane at port"] + self.input_data = ["airplane at airport"] def test_preprocessor_basics(self): - preprocessor = StableLMCausalLMPreprocessor(**self.init_kwargs) - x, y, sw = preprocessor(self.input_data) - # Expected tokenization: "<|endoftext|> air plane at port" -> [1, 3, 4, 5, 6] - # Padded to sequence_length=8: [1, 3, 4, 5, 6, 0, 0, 0] - self.assertAllEqual(x["token_ids"], [[1, 3, 4, 5, 6, 0, 0, 0]]) - self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 0, 0, 0]]) - # Labels are shifted: [3, 4, 5, 6, 0, 0, 0, 0] - self.assertAllEqual(y, [[3, 4, 5, 6, 0, 0, 0, 0]]) - # Sample weights are 1 where labels are non-padding - self.assertAllEqual(sw, [[1, 1, 1, 1, 0, 0, 0, 0]]) - + self.run_preprocessor_test( + cls=StableLMCausalLMPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=( + { + "token_ids": [[6, 1, 3, 4, 2, 5, 6, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 1, 1, 0]], + }, + [[1, 3, 4, 2, 5, 6, 0, 0]], # Pass through labels. + [[1, 1, 1, 1, 1, 1, 0, 0]], # Pass through sample_weights. 
+ ), + ) + def test_no_start_end_token(self): - # Test without start and end tokens, with batch size of 4 - input_data = ["air plane at port"] * 4 + input_data = ["airplane at airport"] * 4 + preprocessor = StableLMCausalLMPreprocessor( - tokenizer=self.tokenizer, - sequence_length=8, + **self.init_kwargs, add_start_token=False, add_end_token=False, ) x, y, sw = preprocessor(input_data) - # Tokenization: "air plane at port" -> [3, 4, 5, 6] - # Padded: [3, 4, 5, 6, 0, 0, 0, 0] - self.assertAllEqual(x["token_ids"], [[3, 4, 5, 6, 0, 0, 0, 0]] * 4) - self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * 4) - # Labels: [4, 5, 6, 0, 0, 0, 0, 0] - self.assertAllEqual(y, [[4, 5, 6, 0, 0, 0, 0, 0]] * 4) - self.assertAllEqual(sw, [[1, 1, 1, 0, 0, 0, 0, 0]] * 4) + self.assertAllEqual(x["token_ids"], [[1, 3, 4, 2, 5, 0, 0, 0]] * 4) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 0, 0, 0]] * 4) + self.assertAllEqual(y, [[3, 4, 2, 5, 0, 0, 0, 0]] * 4) + self.assertAllEqual(sw, [[1, 1, 1, 1, 0, 0, 0, 0]] * 4) def test_generate_preprocess(self): - # Test preprocessing for generation + input_data = "airplane at airport" preprocessor = StableLMCausalLMPreprocessor(**self.init_kwargs) - x = preprocessor.generate_preprocess("air plane at port") - # Expected: [1, 3, 4, 5, 6, 0, 0, 0] - self.assertAllEqual(x["token_ids"], [1, 3, 4, 5, 6, 0, 0, 0]) - self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0]) + x = preprocessor.generate_preprocess(input_data) + self.assertAllEqual(x["token_ids"], [6, 1, 3, 4, 2, 5, 0, 0]) + self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 0, 0]) def test_generate_postprocess(self): - # Test postprocessing for generation - preprocessor = StableLMCausalLMPreprocessor(**self.init_kwargs) input_data = { - "token_ids": [1, 3, 4, 5, 6, 0, 0, 0], - "padding_mask": [1, 1, 1, 1, 1, 0, 0, 0], + "token_ids": ops.array([6, 1, 3, 4, 2, 5, 0, 0]), + "padding_mask": ops.array([1, 1, 1, 1, 1, 1, 0, 0], dtype="bool"), } + preprocessor = StableLMCausalLMPreprocessor(**self.init_kwargs) x = preprocessor.generate_postprocess(input_data) - # Expect detokenized string, may include minor formatting differences due to BPE - self.assertEqual(x, "air plane at port") + self.assertAllEqual(x, "airplane at airport") \ No newline at end of file diff --git a/keras_hub/src/models/stablelm/stablelm_causal_lm_test.py b/keras_hub/src/models/stablelm/stablelm_causal_lm_test.py index e41db9f1d6..c431c54681 100644 --- a/keras_hub/src/models/stablelm/stablelm_causal_lm_test.py +++ b/keras_hub/src/models/stablelm/stablelm_causal_lm_test.py @@ -1,4 +1,3 @@ -import os from unittest.mock import patch import pytest @@ -15,7 +14,9 @@ class StableLMCausalLMTest(TestCase): def setUp(self): - self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port", "<|endoftext|>"] + self.vocab = [ + "!", "air", "Ġair", "plane", "Ġat", "port", "<|endoftext|>" + ] self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) self.merges = [ "Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e", @@ -24,7 +25,9 @@ def setUp(self): ] self.preprocessor = StableLMCausalLMPreprocessor( - tokenizer=StableLMTokenizer(vocabulary=self.vocab, merges=self.merges), + tokenizer=StableLMTokenizer( + vocabulary=self.vocab, merges=self.merges + ), sequence_length=8, ) @@ -114,62 +117,3 @@ def test_saved_model(self): init_kwargs=self.init_kwargs, input_data=self.input_data, ) - - def test_score_logits(self): - prompts = [" airplane at airport", " airplane at airport"] - causal_lm = StableLMCausalLM(**self.init_kwargs) - 
expected_score_shape = (2, 8, 7) - preprocessed_prompts = causal_lm.preprocessor.generate_preprocess(prompts) - token_ids = preprocessed_prompts["token_ids"] - padding_mask = preprocessed_prompts["padding_mask"] - scores = causal_lm.score( - token_ids=token_ids, - padding_mask=padding_mask, - scoring_mode="logits", - ) - self.assertEqual(ops.shape(scores), expected_score_shape) - - def test_score_loss(self): - prompts = [" airplane at airport", " airplane at airport"] - causal_lm = StableLMCausalLM(**self.init_kwargs) - expected_score_shape = (2, 7) - preprocessed_prompts = causal_lm.preprocessor.generate_preprocess(prompts) - token_ids = preprocessed_prompts["token_ids"] - padding_mask = preprocessed_prompts["padding_mask"] - batch_size = ops.shape(token_ids)[0] - target_ids = ops.slice(token_ids, [0, 1], [batch_size, -1]) - - scores = causal_lm.score( - token_ids=token_ids, - padding_mask=padding_mask, - scoring_mode="loss", - target_ids=target_ids, - ) - self.assertEqual(ops.shape(scores), expected_score_shape) - - def test_score_layer_intercept_fn_exfiltration(self): - prompts = [" airplane at airport", " airplane at airport"] - causal_lm = StableLMCausalLM(**self.init_kwargs) - expected_embedded_shape = (2, 8, 4) - expected_score_shape = (2, 8, 7) - preprocessed_prompts = causal_lm.preprocessor.generate_preprocess(prompts) - token_ids = preprocessed_prompts["token_ids"] - padding_mask = preprocessed_prompts["padding_mask"] - embedded_prompts = None - - def layer_intercept_fn_for_testing(x, i): - nonlocal embedded_prompts - if i == -1: - embedded_prompts = x - else: - self.assertEqual(ops.shape(x), expected_embedded_shape) - return x - - scores = causal_lm.score( - token_ids=token_ids, - padding_mask=padding_mask, - scoring_mode="logits", - layer_intercept_fn=layer_intercept_fn_for_testing, - ) - self.assertEqual(ops.shape(embedded_prompts), expected_embedded_shape) - self.assertEqual(ops.shape(scores), expected_score_shape) \ No newline at end of file diff --git a/keras_hub/src/models/stablelm/stablelm_decoder.py b/keras_hub/src/models/stablelm/stablelm_decoder.py index 3f897be14b..574dbb57bb 100644 --- a/keras_hub/src/models/stablelm/stablelm_decoder.py +++ b/keras_hub/src/models/stablelm/stablelm_decoder.py @@ -7,28 +7,33 @@ from keras_hub.src.layers.modeling.transformer_layer_utils import ( merge_padding_and_attention_mask, ) +from keras_hub.src.models.stablelm.stablelm_attention import StableLMAttention from keras_hub.src.utils.keras_utils import clone_initializer -from keras_hub.src.models.stablelm.stablelm_attention import StableLMAttention + class StableLMTransformerDecoder(keras.layers.Layer): """StableLM-3B4E1T Transformer decoder layer. - This layer implements the decoder for StableLM-3B4E1T, a decoder-only transformer - with multi-head self-attention using partial rotary position embeddings (RoPE) - and LayerNorm with learned bias terms. + This layer implements the decoder for StableLM-3B4E1T, a decoder-only + transformer with multi-head self-attention using partial rotary position + embeddings (RoPE) and LayerNorm with learned bias terms. Args: - intermediate_dim (int): Hidden size of the feedforward network. - num_query_heads (int): Number of query attention heads (32 for StableLM-3B4E1T). - num_key_value_heads (int): Number of key/value attention heads (32 for StableLM-3B4E1T). - rope_max_wavelength (float, optional): Maximum wavelength for RoPE. Defaults to 10000. - rope_scaling_factor (float, optional): Scaling factor for RoPE. Defaults to 1.0. 
- rotary_percentage (float, optional): Percentage of head dimensions for RoPE (0.25 for StableLM). - activation (str or callable, optional): Activation for the feedforward network. Defaults to "silu". - layer_norm_epsilon (float, optional): Epsilon for LayerNorm. Defaults to 1e-5. - kernel_initializer (str or initializer, optional): Initializer for dense layers. Defaults to "glorot_uniform". - dropout (float, optional): Dropout rate. Defaults to 0.0. - **kwargs: Additional keyword arguments for the parent class. + intermediate_dim: int. Hidden size of the feedforward network. + num_query_heads: int. Number of query attention heads (32 for + StableLM-3B4E1T). + num_key_value_heads: int. Number of key/value attention heads (32 + for StableLM-3B4E1T). + rope_max_wavelength: float. Maximum wavelength for RoPE. Defaults + to 10000. + rope_scaling_factor: float. Scaling factor for RoPE. Defaults to 1.0. + rotary_percentage: float. Percentage of head dimensions for RoPE + (0.25 for StableLM). + activation: Activation for the feedforward network. Defaults to "silu". + layer_norm_epsilon: float. Epsilon for LayerNorm. Defaults to 1e-5. + kernel_initializer: Initializer for dense layers. Defaults to + "glorot_uniform". + dropout: float, optional. Dropout rate. Defaults to 0.0. """ def __init__( @@ -118,7 +123,9 @@ def build(self, decoder_sequence_shape): name="feedforward_output_dense", ) self._feedforward_output_dense.build( - self._feedforward_gate_dense.compute_output_shape(decoder_sequence_shape) + self._feedforward_gate_dense.compute_output_shape( + decoder_sequence_shape + ) ) # LayerNorm for feedforward (with learned bias) @@ -129,7 +136,7 @@ def build(self, decoder_sequence_shape): ) self._feedforward_layernorm.build(decoder_sequence_shape) - self.built = True + super().build(decoder_sequence_shape) def call( self, @@ -169,13 +176,14 @@ def call( gate_output = self._feedforward_gate_dense(x) gate_output = self.activation(gate_output) intermediate_output = self._feedforward_intermediate_dense(x) - x = self._feedforward_output_dense(ops.multiply(intermediate_output, gate_output)) + x = self._feedforward_output_dense( + ops.multiply(intermediate_output, gate_output) + ) decoder_output = x + residual - if self_attention_cache is not None: return decoder_output, self_attention_cache return decoder_output - + def _compute_self_attention_mask( self, decoder_sequence, @@ -191,11 +199,19 @@ def _compute_self_attention_mask( input_length = output_length = ops.shape(decoder_sequence)[1] if self_attention_cache is not None: input_length = ops.shape(self_attention_cache)[2] - cache_update_index = 0 if self_attention_cache_update_index is None else self_attention_cache_update_index + cache_update_index = ( + 0 if self_attention_cache_update_index is None + else self_attention_cache_update_index + ) causal_mask = compute_causal_mask( batch_size, input_length, output_length, cache_update_index ) - return ops.minimum(decoder_mask, causal_mask) if decoder_mask is not None else causal_mask + return ( + ops.minimum(decoder_mask, causal_mask) + if decoder_mask is not None + else causal_mask + ) + def compute_output_shape(self, decoder_sequence_shape): return decoder_sequence_shape @@ -212,8 +228,10 @@ def get_config(self): "rotary_percentage": self.rotary_percentage, "activation": keras.activations.serialize(self.activation), "layer_norm_epsilon": self.layer_norm_epsilon, - "kernel_initializer": keras.initializers.serialize(self.kernel_initializer), + "kernel_initializer": keras.initializers.serialize( + 
self.kernel_initializer + ), "dropout": self.dropout, } ) - return config \ No newline at end of file + return config diff --git a/keras_hub/src/models/stablelm/stablelm_presets.py b/keras_hub/src/models/stablelm/stablelm_presets.py new file mode 100644 index 0000000000..cf3e1838f1 --- /dev/null +++ b/keras_hub/src/models/stablelm/stablelm_presets.py @@ -0,0 +1,10 @@ +backbone_presets = { + "stablelm_3b_4e1t_en": { + "metadata": { + "description": "3 billion parameter,32-layer, base StableLM model.", + "params": 2795443200, + "path": "stablelm", + }, + "kaggle_handle": "xxxxxx", + }, +} \ No newline at end of file diff --git a/keras_hub/src/models/stablelm/stablelm_tokenizer.py b/keras_hub/src/models/stablelm/stablelm_tokenizer.py index b9bea16f92..c30bb1c5b9 100644 --- a/keras_hub/src/models/stablelm/stablelm_tokenizer.py +++ b/keras_hub/src/models/stablelm/stablelm_tokenizer.py @@ -28,8 +28,8 @@ class StableLMTokenizer(BytePairTokenizer): vocabulary: string or dict, maps tokens to integer IDs. If a string, it should be the file path to a JSON file containing the vocabulary. merges: string or list, contains the merge rules. If a string, it should - be the file path to a file with merge rules, where each line contains - merge entities separated by a space. + be the file path to a file with merge rules, where each + line contains merge entities separated by a space. """ backbone_cls = StableLMBackbone @@ -40,8 +40,8 @@ def __init__( merges=None, **kwargs, ): - # StableLM uses the GPT-NeoX tokenizer, which has "<|endoftext|>" as both - # start and end token. + # StableLM uses the GPT-NeoX tokenizer, which has + # "<|endoftext|>" as both start and end token. self._add_special_token("<|endoftext|>", "end_token") self._add_special_token("<|endoftext|>", "start_token") self.pad_token_id = 0 diff --git a/keras_hub/src/models/stablelm/stablelm_tokenizer_test.py b/keras_hub/src/models/stablelm/stablelm_tokenizer_test.py index 343fbb3f5a..89ae844a59 100644 --- a/keras_hub/src/models/stablelm/stablelm_tokenizer_test.py +++ b/keras_hub/src/models/stablelm/stablelm_tokenizer_test.py @@ -1,11 +1,13 @@ -import keras from keras_hub.src.models.stablelm.stablelm_tokenizer import StableLMTokenizer from keras_hub.src.tests.test_case import TestCase + class StableLMTokenizerTest(TestCase): def setUp(self): - self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port", "<|endoftext|>"] + self.vocab = [ + "!", "air", "Ġair", "plane", "Ġat", "port", "<|endoftext|>" + ] self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)]) self.merges = [ "Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e", @@ -30,6 +32,7 @@ def test_tokenizer_basics(self): ) def test_errors_missing_special_tokens(self): - # Test that an error is raised if "<|endoftext|>" is missing from the vocabulary + # Test that an error is raised if "<|endoftext|>" is + # missing from the vocabulary with self.assertRaises(ValueError): StableLMTokenizer(vocabulary=["a", "b", "c"], merges=[]) \ No newline at end of file diff --git a/tools/checkpoint_conversion/convert_stablelm_checkpoints.py b/tools/checkpoint_conversion/convert_stablelm_checkpoints.py new file mode 100644 index 0000000000..6a91886f08 --- /dev/null +++ b/tools/checkpoint_conversion/convert_stablelm_checkpoints.py @@ -0,0 +1,262 @@ +import json +import os + +import jax.numpy as jnp +import keras +import numpy as np +import requests +import tensorflow as tf +import torch +from tqdm import tqdm +from transformers import AutoModel +from transformers import AutoTokenizer + +from 
keras_hub.src.models.stablelm.stablelm_backbone import StableLMBackbone + +# Set the desired Keras backend (e.g., "torch", "tensorflow", "jax") +os.environ["KERAS_BACKEND"] = "torch" + +# Detect and verify the current Keras backend +backend = keras.backend.backend() +print(f"Current Keras backend: {backend}") + +# Configuration +PRESET_NAME = "stablelm-3b-4e1t" +BASE_MODEL = "stabilityai/stablelm-3b-4e1t" +EXTRACT_DIR = "./{}" + +extract_dir = EXTRACT_DIR.format(PRESET_NAME) +if not os.path.exists(extract_dir): + os.makedirs(extract_dir) + +# Function to download files with progress bar +def download_file(url, filepath): + response = requests.get(url, stream=True) + response.raise_for_status() + total_size = int(response.headers.get('content-length', 0)) + with open(filepath, 'wb') as f, tqdm( + desc=os.path.basename(filepath), + total=total_size, + unit='B', + unit_scale=True, + unit_divisor=1024, + ) as bar: + for chunk in response.iter_content(chunk_size=8192): + size = f.write(chunk) + bar.update(size) + +# Download vocab and merges +vocab_path = os.path.join(extract_dir, "vocab.json") +merges_path = os.path.join(extract_dir, "merges.txt") +tokenizer_url = f"https://huggingface.co/{BASE_MODEL}/raw/main/tokenizer.json" +download_file(tokenizer_url, os.path.join(extract_dir, "tokenizer.json")) +with open(os.path.join(extract_dir, "tokenizer.json"), "r") as f: + tokenizer_data = json.load(f) +vocab = { + token: idx for idx, token in enumerate(tokenizer_data["model"]["vocab"]) +} +with open(vocab_path, "w") as f: + json.dump(vocab, f) +merges = tokenizer_data["model"]["merges"] +with open(merges_path, "w") as f: + for merge in merges: + f.write(merge + "\n") + +# Download config +config_path = os.path.join(extract_dir, "config.json") +config_url = f"https://huggingface.co/{BASE_MODEL}/raw/main/config.json" +download_file(config_url, config_path) +cfg = {} +with open(config_path, "r") as pt_cfg_handler: + pt_cfg = json.load(pt_cfg_handler) + +cfg["vocabulary_size"] = pt_cfg["vocab_size"] +cfg["num_layers"] = pt_cfg["num_hidden_layers"] +cfg["num_query_heads"] = pt_cfg["num_attention_heads"] +cfg["num_key_value_heads"] = pt_cfg["num_key_value_heads"] +cfg["hidden_dim"] = pt_cfg["hidden_size"] +cfg["intermediate_dim"] = pt_cfg["intermediate_size"] +cfg["max_sequence_length"] = pt_cfg["max_position_embeddings"] +cfg["layer_norm_epsilon"] = pt_cfg["layer_norm_eps"] +cfg["rope_max_wavelength"] = pt_cfg["rope_theta"] +cfg["partial_rotary_factor"] = pt_cfg["partial_rotary_factor"] + +# Load Hugging Face model +hf_model = AutoModel.from_pretrained(BASE_MODEL) +hf_model.eval() +hf_wts = hf_model.state_dict() + +# Initialize Keras model +keras_model = StableLMBackbone(**cfg) + +# Function to convert tensors to NumPy based on tensor type +def to_numpy(tensor): + if isinstance(tensor, torch.Tensor): + return tensor.detach().cpu().numpy() + elif isinstance(tensor, tf.Tensor): + return tensor.numpy() + elif isinstance(tensor, jnp.ndarray): + return np.array(tensor) + else: + raise ValueError(f"Unsupported tensor type: {type(tensor)}") + +# Transfer weights +keras_model.get_layer("token_embedding").embeddings.assign( + to_numpy(hf_model.embed_tokens.weight) +) + +for layer_index in range(cfg["num_layers"]): + hidden_size = cfg["hidden_dim"] + num_attention_heads = cfg["num_query_heads"] + num_key_value_heads = cfg["num_key_value_heads"] + head_dim = hidden_size // num_attention_heads + + # Query projection + q_weight_key = f"layers.{layer_index}.self_attn.q_proj.weight" + q_weight_hf = 
to_numpy(hf_wts[q_weight_key]) + q_weight = q_weight_hf.T.reshape(hidden_size, num_attention_heads, head_dim) + weights = [q_weight] + q_bias_key = f"layers.{layer_index}.self_attn.q_proj.bias" + if q_bias_key in hf_wts: + q_bias = to_numpy(hf_wts[q_bias_key]) + weights.append(q_bias) + keras_model.get_layer( + f"transformer_layer_{layer_index}" + )._self_attention_layer._query_dense.set_weights(weights) + + # Key projection + k_weight_key = f"layers.{layer_index}.self_attn.k_proj.weight" + k_weight_hf = to_numpy(hf_wts[k_weight_key]) + k_weight = k_weight_hf.T.reshape(hidden_size, num_key_value_heads, head_dim) + weights = [k_weight] + k_bias_key = f"layers.{layer_index}.self_attn.k_proj.bias" + if k_bias_key in hf_wts: + k_bias = to_numpy(hf_wts[k_bias_key]) + weights.append(k_bias) + keras_model.get_layer( + f"transformer_layer_{layer_index}" + )._self_attention_layer._key_dense.set_weights(weights) + + # Value projection + v_weight_key = f"layers.{layer_index}.self_attn.v_proj.weight" + v_weight_hf = to_numpy(hf_wts[v_weight_key]) + v_weight = v_weight_hf.T.reshape(hidden_size, num_key_value_heads, head_dim) + weights = [v_weight] + v_bias_key = f"layers.{layer_index}.self_attn.v_proj.bias" + if v_bias_key in hf_wts: + v_bias = to_numpy(hf_wts[v_bias_key]) + weights.append(v_bias) + keras_model.get_layer( + f"transformer_layer_{layer_index}" + )._self_attention_layer._value_dense.set_weights(weights) + + # Output projection + o_weight_key = f"layers.{layer_index}.self_attn.o_proj.weight" + o_weight_hf = to_numpy(hf_wts[o_weight_key]) + o_weight = o_weight_hf.T.reshape(num_attention_heads, head_dim, hidden_size) + weights = [o_weight] + o_bias_key = f"layers.{layer_index}.self_attn.o_proj.bias" + if o_bias_key in hf_wts: + o_bias = to_numpy(hf_wts[o_bias_key]) + weights.append(o_bias) + keras_model.get_layer( + f"transformer_layer_{layer_index}" + )._self_attention_layer._output_dense.set_weights(weights) + + # LayerNorms + ln_weight_key = f"layers.{layer_index}.input_layernorm.weight" + ln_bias_key = f"layers.{layer_index}.input_layernorm.bias" + keras_model.get_layer( + f"transformer_layer_{layer_index}" + )._self_attention_layernorm.set_weights( + [ + to_numpy(hf_wts[ln_weight_key]), + to_numpy(hf_wts[ln_bias_key]) + ] + ) + + ln_weight_key = f"layers.{layer_index}.post_attention_layernorm.weight" + ln_bias_key = f"layers.{layer_index}.post_attention_layernorm.bias" + keras_model.get_layer( + f"transformer_layer_{layer_index}" + )._feedforward_layernorm.set_weights( + [ + to_numpy(hf_wts[ln_weight_key]), + to_numpy(hf_wts[ln_bias_key]) + ] + ) + + # Feedforward + ff_gate_weight_key = f"layers.{layer_index}.mlp.gate_proj.weight" + ff_gate_weight_hf = to_numpy(hf_wts[ff_gate_weight_key]) + weights = [ff_gate_weight_hf.T] + ff_gate_bias_key = f"layers.{layer_index}.mlp.gate_proj.bias" + if ff_gate_bias_key in hf_wts: + ff_gate_bias = to_numpy(hf_wts[ff_gate_bias_key]) + weights.append(ff_gate_bias) + keras_model.get_layer( + f"transformer_layer_{layer_index}" + )._feedforward_gate_dense.set_weights(weights) + + ff_inter_weight_key = f"layers.{layer_index}.mlp.up_proj.weight" + ff_inter_weight_hf = to_numpy(hf_wts[ff_inter_weight_key]) + weights = [ff_inter_weight_hf.T] + ff_inter_bias_key = f"layers.{layer_index}.mlp.up_proj.bias" + if ff_inter_bias_key in hf_wts: + ff_inter_bias = to_numpy(hf_wts[ff_inter_bias_key]) + weights.append(ff_inter_bias) + keras_model.get_layer( + f"transformer_layer_{layer_index}" + )._feedforward_intermediate_dense.set_weights(weights) + + ff_out_weight_key 
= f"layers.{layer_index}.mlp.down_proj.weight" + ff_out_weight_hf = to_numpy(hf_wts[ff_out_weight_key]) + weights = [ff_out_weight_hf.T] + ff_out_bias_key = f"layers.{layer_index}.mlp.down_proj.bias" + if ff_out_bias_key in hf_wts: + ff_out_bias = to_numpy(hf_wts[ff_out_bias_key]) + weights.append(ff_out_bias) + keras_model.get_layer( + f"transformer_layer_{layer_index}" + )._feedforward_output_dense.set_weights(weights) + +# Final LayerNorm +keras_model.get_layer("sequence_output_layernorm").set_weights( + [ + to_numpy(hf_wts["norm.weight"]), + to_numpy(hf_wts["norm.bias"]) + ] +) + +# Tokenization and comparison +hf_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL) +sample_text = ["Royal Challengers Bangalore will be winning this IPL"] +hf_inputs = hf_tokenizer(sample_text, return_tensors="pt") +print("HF inputs:", hf_inputs) + +if backend == "torch": + token_ids = hf_inputs["input_ids"] + padding_mask = hf_inputs["attention_mask"] +elif backend == "tensorflow": + token_ids = tf.convert_to_tensor(hf_inputs["input_ids"].numpy()) + padding_mask = tf.convert_to_tensor(hf_inputs["attention_mask"].numpy()) +elif backend == "jax": + token_ids = jnp.array(hf_inputs["input_ids"].numpy()) + padding_mask = jnp.array(hf_inputs["attention_mask"].numpy()) +else: + raise ValueError(f"Unsupported backend: {backend}") + +keras_inputs = {"token_ids": token_ids, "padding_mask": padding_mask} +keras_outputs = keras_model(keras_inputs) +keras_outputs_np = to_numpy(keras_outputs) +print("Keras output:", keras_outputs_np) + +hf_outputs = hf_model(**hf_inputs).last_hidden_state +hf_outputs_np = hf_outputs.detach().cpu().numpy() +print("HF output:", hf_outputs_np) + +try: + np.testing.assert_allclose(hf_outputs_np, keras_outputs_np, atol=1e-3) + print("✅ Model outputs match!") +except AssertionError: + print("❌ Model outputs differ!") \ No newline at end of file