From a94ed20fc44009c7075179d8cbd05f1b99c18991 Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali Date: Fri, 19 Sep 2025 20:30:17 +0530 Subject: [PATCH 1/9] Add STU layer --- keras_rs/src/layers/common.py | 62 +++++ keras_rs/src/layers/hstu_compute_output.py | 125 ++++++++++ keras_rs/src/layers/hstu_mha_attention.py | 109 +++++++++ .../src/layers/hstu_preprocess_attention.py | 44 ++++ keras_rs/src/layers/hstu_uqvk_output.py | 81 +++++++ keras_rs/src/layers/jagged_tensors.py | 112 +++++++++ keras_rs/src/layers/stu.py | 222 ++++++++++++++++++ 7 files changed, 755 insertions(+) create mode 100644 keras_rs/src/layers/common.py create mode 100644 keras_rs/src/layers/hstu_compute_output.py create mode 100644 keras_rs/src/layers/hstu_mha_attention.py create mode 100644 keras_rs/src/layers/hstu_preprocess_attention.py create mode 100644 keras_rs/src/layers/hstu_uqvk_output.py create mode 100644 keras_rs/src/layers/jagged_tensors.py create mode 100644 keras_rs/src/layers/stu.py diff --git a/keras_rs/src/layers/common.py b/keras_rs/src/layers/common.py new file mode 100644 index 0000000..4e5b9f5 --- /dev/null +++ b/keras_rs/src/layers/common.py @@ -0,0 +1,62 @@ +import keras +from keras import ops +from typing import List, Optional, Tuple + +def fx_unwrap_optional_tensor(optional: Optional[keras.KerasTensor]) -> keras.KerasTensor: + """Helper to unwrap optional tensors, returning a zero-tensor for uninitialized cache.""" + if optional is None: + # Returning a zero-tensor is necessary for graph tracing when the cache is uninitialized. + return ops.zeros((0,), dtype='float32') + return optional + +def get_valid_attn_mask_keras( + causal: bool, + N: int, + seq_lengths: keras.KerasTensor, + num_targets: Optional[keras.KerasTensor] = None, + max_attn_len: int = 0, + contextual_seq_len: int = 0, + min_full_attn_seq_len: int = 0, +): + """ + Keras implementation of the valid attention mask generation, combining + causality, sequence lengths, and target awareness. 
+ """ + ids = ops.reshape(ops.arange(0, N, dtype="int32"), (1, N)) + max_ids = ops.reshape(seq_lengths, (-1, 1, 1)) + B = ops.shape(seq_lengths)[0] + + if contextual_seq_len > 0: + ids = ids - contextual_seq_len + 1 + ids = ops.maximum(ids, 0) + max_ids = max_ids - contextual_seq_len + 1 + + if num_targets is not None: + max_ids = max_ids - ops.reshape(num_targets, (-1, 1, 1)) + ids = ops.minimum(ids, max_ids) + row_ids = ops.broadcast_to(ops.reshape(ids, (-1, N, 1)), (B, N, N)) + col_ids = ops.broadcast_to(ops.reshape(ids, (-1, 1, N)), (B, N, N)) + else: + row_ids = ops.broadcast_to(ops.reshape(ids, (N, 1)), (N, N)) + col_ids = ops.transpose(row_ids) + row_ids = ops.reshape(row_ids, (1, N, N)) + col_ids = ops.reshape(col_ids, (1, N, N)) + max_ids = None + + row_col_dist = row_ids - col_ids + valid_attn_mask = ops.reshape(ops.eye(N, dtype="bool"), (1, N, N)) + + if not causal: + row_col_dist = ops.where(row_col_dist > 0, row_col_dist, -row_col_dist) + + valid_attn_mask = ops.logical_or(valid_attn_mask, row_col_dist > 0) + + if max_attn_len > 0: + valid_attn_mask = ops.logical_and(valid_attn_mask, row_col_dist <= max_attn_len) + + if contextual_seq_len > 0 and max_ids is not None: + valid_attn_mask = ops.logical_or( + valid_attn_mask, ops.logical_and(row_ids == 0, col_ids < max_ids) + ) + + return valid_attn_mask diff --git a/keras_rs/src/layers/hstu_compute_output.py b/keras_rs/src/layers/hstu_compute_output.py new file mode 100644 index 0000000..9d35959 --- /dev/null +++ b/keras_rs/src/layers/hstu_compute_output.py @@ -0,0 +1,125 @@ +import keras +from keras import ops +from typing import List, Optional, Tuple + +def keras_norm_mul_dropout( + x: keras.KerasTensor, + u: keras.KerasTensor, + weight: keras.KerasTensor, + bias: keras.KerasTensor, + eps: float, + dropout_ratio: float, + training: bool, + silu_u: bool = False, + concat_ux: bool = False, + group_norm: bool = False, + num_heads: int = 1, + linear_dim: int = -1, +) -> keras.KerasTensor: + """ + Keras 3 equivalent of pytorch_norm_mul_dropout. + Applies normalization, element-wise multiplication with u, and dropout. + Assumes keras_layer_norm is available (though the logic is inlined here). 
+ """ + x = ops.convert_to_tensor(x, dtype='float32') + u = ops.convert_to_tensor(u, dtype='float32') + + if silu_u: + u = ops.silu(u) + + if group_norm: + raise NotImplementedError("Group Norm path not suitable for simple Keras ops conversion.") + else: + # Functional Layer Normalization (Simulated keras_layer_norm) + mean = ops.mean(x, axis=-1, keepdims=True) + variance = ops.mean(ops.square(x - mean), axis=-1, keepdims=True) + x_norm = (x - mean) / ops.sqrt(variance + eps) + + # Apply weight and bias (Gamma * x_norm + Beta) + y_norm = x_norm * weight + bias + + # Apply u multiplication (Element-wise gating) + y = u * y_norm + + if concat_ux: + y = ops.concatenate([u, x, y], axis=1) + + # Dropout (using Keras layer for correct training=True/False behavior) + y = keras.layers.Dropout(dropout_ratio)(y, training=training) + + return ops.cast(y, dtype=x.dtype) + +def keras_hstu_compute_output( + attn: keras.KerasTensor, + u: keras.KerasTensor, + x: keras.KerasTensor, + norm_weight: keras.KerasTensor, + norm_bias: keras.KerasTensor, + output_weight: keras.KerasTensor, + eps: float, + dropout_ratio: float, + training: bool, + silu_u: bool = False, + concat_ux: bool = False, + group_norm: bool = False, + num_heads: int = 1, + linear_dim: int = -1, +) -> keras.KerasTensor: + """ + Core kernel for the final residual block calculation (Attn Output -> Norm/Dropout -> MatMul -> Residual Add). + """ + y = keras_norm_mul_dropout( + x=attn, + u=u, + weight=norm_weight, + bias=norm_bias, + eps=eps, + dropout_ratio=dropout_ratio, + training=training, + silu_u=silu_u, + concat_ux=concat_ux, + group_norm=group_norm, + num_heads=num_heads, + linear_dim=linear_dim, + ) + + # Final output: Residual addition of input (x) and transformed attention output (y @ output_weight) + output = ops.add(x, ops.matmul(y, output_weight)) + + return output + +def hstu_compute_output( + attn: keras.KerasTensor, + u: keras.KerasTensor, + x: keras.KerasTensor, + norm_weight: keras.KerasTensor, + norm_bias: keras.KerasTensor, + norm_eps: float, + output_weight: keras.KerasTensor, + num_heads: int, + linear_dim: int, + dropout_ratio: float, + training: bool, + concat_ux: bool, + group_norm: bool, + recompute_y_in_backward: bool, +) -> keras.KerasTensor: + """ + Top-level wrapper for the output computation, delegates to the core Keras kernel. + """ + return keras_hstu_compute_output( + attn=attn, + u=u, + x=x, + norm_weight=norm_weight, + norm_bias=norm_bias, + output_weight=output_weight, + eps=norm_eps, + dropout_ratio=dropout_ratio, + training=training, + silu_u=False, + concat_ux=concat_ux, + group_norm=group_norm, + num_heads=num_heads, + linear_dim=linear_dim, + ) diff --git a/keras_rs/src/layers/hstu_mha_attention.py b/keras_rs/src/layers/hstu_mha_attention.py new file mode 100644 index 0000000..d7b146d --- /dev/null +++ b/keras_rs/src/layers/hstu_mha_attention.py @@ -0,0 +1,109 @@ +import keras +from keras import ops +from typing import Tuple, Optional +from keras import layers + +# --- Assumed Imports --- +# Assumes keras_jagged_to_padded_dense, keras_dense_to_jagged, and get_valid_attn_mask_keras are available from other modules. + +def keras_pad_qkv( + q: keras.KerasTensor, k: keras.KerasTensor, v: keras.KerasTensor, seq_offsets: keras.KerasTensor, N: int, +) -> Tuple[keras.KerasTensor, keras.KerasTensor, keras.KerasTensor]: + """ + Helper to pad Q, K, V from jagged to dense format for MHA. + Assumes keras_jagged_to_padded_dense is available globally. 
+ """ + L, H, D = ops.shape(q); V_dim = ops.shape(v)[2] + values_q_k = ops.reshape(q, [L, H * D]); values_v = ops.reshape(v, [L, H * V_dim]) + + # Pad Q, K, V + padded_q_k = keras_jagged_to_padded_dense(values=values_q_k, offsets=[seq_offsets], max_lengths=[N], padding_value=0.0) + padded_v = keras_jagged_to_padded_dense(values=values_v, offsets=[seq_offsets], max_lengths=[N], padding_value=0.0) + + B = ops.shape(padded_q_k)[0]; padded_q_k = ops.reshape(padded_q_k, [B, N, H, D]); padded_v = ops.reshape(padded_v, [B, N, H, V_dim]) + padded_q = ops.transpose(padded_q_k, [0, 2, 1, 3]); padded_k = ops.transpose(padded_q_k, [0, 2, 1, 3]) + padded_v = ops.transpose(padded_v, [0, 2, 1, 3]) + return padded_q, padded_k, padded_v + + +def keras_hstu_mha( + max_seq_len: int, alpha: float, q: keras.KerasTensor, k: keras.KerasTensor, v: keras.KerasTensor, seq_offsets: keras.KerasTensor, causal: bool = True, dropout_pr: float = 0.0, training: bool = True, attn_scale: Optional[keras.KerasTensor] = None, **kwargs +) -> keras.KerasTensor: + """Core Keras implementation of the full Multi-Head Attention kernel (Non-Cached).""" + L, H, _ = ops.shape(q); V_dim = ops.shape(v)[2] + q, k, v = keras_pad_qkv(q, k, v, seq_offsets, max_seq_len) + qk_attn = ops.einsum("bhxa,bhya->bhxy", q, k) * alpha + + # Activation and Scaling + if attn_scale is not None: + if ops.ndim(attn_scale) > 0: + attn_scale_padded = keras_jagged_to_padded_dense(values=ops.expand_dims(attn_scale, axis=-1), offsets=[seq_offsets], max_lengths=[max_seq_len], padding_value=0.0) + attn_scale_padded = ops.expand_dims(ops.cast(attn_scale_padded, qk_attn.dtype), axis=1) + qk_attn = ops.silu(qk_attn) * attn_scale_padded + else: + qk_attn = ops.silu(qk_attn) / max_seq_len + + # Masking + seq_lengths = seq_offsets[1:] - seq_offsets[:-1] + valid_attn_mask = get_valid_attn_mask_keras(causal=causal, N=max_seq_len, seq_lengths=seq_lengths, **kwargs) + qk_attn = qk_attn * ops.expand_dims(ops.cast(valid_attn_mask, qk_attn.dtype), axis=1) + + # Dropout + if dropout_pr > 0.0 and training: + qk_attn = keras.layers.Dropout(dropout_pr)(qk_attn, training=training) + + # Output (Weighted Sum) + attn_dense = ops.einsum("bhxd,bhdv->bhxv", qk_attn, v) + flat_attn_dense = ops.reshape(ops.transpose(attn_dense, [0, 2, 1, 3]), [-1, max_seq_len, H * V_dim]) + + # Convert back to jagged + jagged_output = keras_dense_to_jagged(flat_attn_dense, [seq_offsets]) + L_out = ops.shape(jagged_output)[0] + return ops.reshape(jagged_output, [L_out, H, V_dim]) + + +def keras_cached_hstu_mha( + max_seq_len: int, alpha: float, delta_q: keras.KerasTensor, k: keras.KerasTensor, v: keras.KerasTensor, seq_offsets: keras.KerasTensor, num_targets: Optional[keras.KerasTensor] = None, max_attn_len: int = 0, contextual_seq_len: int = 0, enable_tma: bool = False, +) -> keras.KerasTensor: + """Core Keras implementation of the cached attention kernel (Delta Q attends to Full K/V).""" + L_delta, H, D = ops.shape(delta_q); B = ops.shape(seq_offsets)[0] - 1; DeltaSize = L_delta // B; V_dim = ops.shape(v)[2] + + # 1. Reshape Delta Q + delta_q = ops.transpose(ops.reshape(delta_q, (B, DeltaSize, H, D)), perm=[0, 2, 1, 3]) + + # 2. Reshape Full K and V (Inputs k, v are already flattened/jagged-like) + N_full = max_seq_len + k_full = ops.transpose(ops.reshape(k, (B, N_full, H, D)), [0, 2, 1, 3]) + v_full = ops.transpose(ops.reshape(v, (B, N_full, H, V_dim)), [0, 2, 1, 3]) + + # 3. 
Attention Score and Activation + qk_attn = ops.einsum("bhxa,bhya->bhxy", delta_q, k_full) * alpha + qk_attn = ops.silu(qk_attn) / max_seq_len + + # 4. Masking (Slice the mask to select only the rows corresponding to the new queries) + seq_lengths = seq_offsets[1:] - seq_offsets[:-1] + full_valid_attn_mask = get_valid_attn_mask_keras(causal=True, N=max_seq_len, seq_lengths=seq_lengths, num_targets=num_targets, max_attn_len=max_attn_len, contextual_seq_len=contextual_seq_len) + valid_attn_mask_sliced = full_valid_attn_mask[:, -DeltaSize:, :] + + qk_attn = qk_attn * ops.expand_dims(ops.cast(valid_attn_mask_sliced, qk_attn.dtype), axis=1) + + # 5. Output (Weighted Sum) + attn_output = ops.einsum("bhxd,bhdv->bhxv", qk_attn, v_full) + + # 6. Reshape and return [L_delta, H, V_dim] + attn_output = ops.transpose(attn_output, perm=[0, 2, 1, 3]) + return ops.reshape(attn_output, (-1, H, V_dim)) + + +def delta_hstu_mha( + max_seq_len: int, alpha: float, delta_q: keras.KerasTensor, k: keras.KerasTensor, v: keras.KerasTensor, seq_offsets: keras.KerasTensor, num_targets: Optional[keras.KerasTensor] = None, max_attn_len: int = 0, contextual_seq_len: int = 0, kernel=None, enable_tma: bool = False, +) -> keras.KerasTensor: + """Top-level wrapper for cached inference MHA (delegates to core cached kernel).""" + + L_delta, H, D = ops.shape(delta_q) + # Assertions are maintained by the layer/framework where possible. + + return keras_cached_hstu_mha( + max_seq_len=max_seq_len, alpha=alpha, delta_q=delta_q, k=k, v=v, seq_offsets=seq_offsets, + num_targets=num_targets, max_attn_len=max_attn_len, contextual_seq_len=contextual_seq_len, + ) diff --git a/keras_rs/src/layers/hstu_preprocess_attention.py b/keras_rs/src/layers/hstu_preprocess_attention.py new file mode 100644 index 0000000..04f70d3 --- /dev/null +++ b/keras_rs/src/layers/hstu_preprocess_attention.py @@ -0,0 +1,44 @@ +import keras +from keras import ops +from typing import Tuple, List, Optional + + +def keras_hstu_preprocess_and_attention( + x: keras.KerasTensor, norm_weight: keras.KerasTensor, norm_bias: keras.KerasTensor, norm_eps: float, num_heads: int, attn_dim: int, hidden_dim: int, + uvqk_weight: keras.KerasTensor, uvqk_bias: keras.KerasTensor, max_seq_len: int, seq_offsets: keras.KerasTensor, attn_alpha: float, causal: bool, + num_targets: Optional[keras.KerasTensor], max_attn_len: int, contextual_seq_len: int, recompute_uvqk_in_backward: bool, + recompute_normed_x_in_backward: bool, sort_by_length: bool, prefill: bool = False, + kernel=None, **kwargs +) -> Tuple: + """ + Keras 3 implementation of the H-STU preprocess and attention workflow. + Orchestrates the conversion of input X into U, Q, K, V and subsequent MHA computation. + """ + + # --- Assertions (Skipped internal torch asserts, simplified to Keras asserts for context) --- + assert max_seq_len > 0, "max_seq_len must be larger than 0" + assert ops.ndim(x) == 2, "x must be 2-D" + assert causal is True, "only causal attention is supported." + + # 1. Compute U, Q, K, V + # Note: hstu_compute_uqvk handles the initial Norm, Linear Projection, and Split. + u, q, k, v = hstu_compute_uqvk( + x=x, norm_weight=norm_weight, norm_bias=norm_bias, norm_eps=norm_eps, + num_heads=num_heads, attn_dim=attn_dim, hidden_dim=hidden_dim, + uvqk_weight=uvqk_weight, uvqk_bias=uvqk_bias, kernel=kernel, + ) + + # 2. 
Compute Attention + attn_output = keras_hstu_mha( + max_seq_len=max_seq_len, alpha=attn_alpha, q=q, k=k, v=v, + seq_offsets=seq_offsets, causal=causal, dropout_pr=0.0, + training=False, num_targets=num_targets, max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, sort_by_length=sort_by_length, + kernel=kernel, **kwargs + ) + + # Reshape: [L, H, D] -> [L, H * D] (Flattening for the final hstu_compute_output block) + attn_output = ops.reshape(attn_output, [-1, hidden_dim * num_heads]) + + # Returns u (gating), attention output, k, and v (for caching) + return u, attn_output, k, v diff --git a/keras_rs/src/layers/hstu_uqvk_output.py b/keras_rs/src/layers/hstu_uqvk_output.py new file mode 100644 index 0000000..47b5a1a --- /dev/null +++ b/keras_rs/src/layers/hstu_uqvk_output.py @@ -0,0 +1,81 @@ +import keras +from keras import ops +from typing import List, Optional, Tuple + +def keras_layer_norm( + x: keras.KerasTensor, + weight: keras.KerasTensor, + bias: keras.KerasTensor, + eps: float, +) -> keras.KerasTensor: + """ + Keras 3 functional Layer Normalization implementation. + Simulates F.layer_norm where scale/bias is applied externally. + """ + # 1. Normalize x + mean = ops.mean(x, axis=-1, keepdims=True) + variance = ops.mean(ops.square(x - mean), axis=-1, keepdims=True) + x_norm = (x - mean) / ops.sqrt(variance + eps) + + # 2. Apply weight and bias (Gamma * x_norm + Beta) + return x_norm * weight + bias + +def keras_addmm( + bias: keras.KerasTensor, + input: keras.KerasTensor, + mat2: keras.KerasTensor, +) -> keras.KerasTensor: + """Keras 3 equivalent of torch.addmm (bias + input @ mat2).""" + return ops.add(bias, ops.matmul(input, mat2)) + +def hstu_compute_uqvk( + x: keras.KerasTensor, + norm_weight: keras.KerasTensor, + norm_bias: keras.KerasTensor, + norm_eps: float, + num_heads: int, + attn_dim: int, + hidden_dim: int, + uvqk_weight: keras.KerasTensor, + uvqk_bias: keras.KerasTensor, + kernel=None, +) -> Tuple[keras.KerasTensor, keras.KerasTensor, keras.KerasTensor, keras.KerasTensor]: + """ + Computes the transformed tensors U, V, Q, and K from the input X. + """ + + # 1. Normalization + normed_x = keras_layer_norm( + x, + weight=norm_weight, + bias=norm_bias, + eps=norm_eps, + ) + + # 2. Combined Linear Projection (uvqk = bias + normed_x @ uvqk_weight) + uvqk = keras_addmm(uvqk_bias, normed_x, uvqk_weight) + + # 3. Calculate split sizes and slice + u_size = hidden_dim * num_heads + v_size = hidden_dim * num_heads + q_size = attn_dim * num_heads + k_size = attn_dim * num_heads + + start_u = 0 + start_v = u_size + start_q = u_size + v_size + start_k = u_size + v_size + q_size + L_out = ops.shape(uvqk)[0] + + u = ops.slice(uvqk, start_indices=[0, start_u], shape=[L_out, u_size]) + v = ops.slice(uvqk, start_indices=[0, start_v], shape=[L_out, v_size]) + q = ops.slice(uvqk, start_indices=[0, start_q], shape=[L_out, q_size]) + k = ops.slice(uvqk, start_indices=[0, start_k], shape=[L_out, k_size]) + + # 4. 
Activation and Reshape + u = ops.silu(u) + q = ops.reshape(q, [-1, num_heads, attn_dim]) + k = ops.reshape(k, [-1, num_heads, attn_dim]) + v = ops.reshape(v, [-1, num_heads, hidden_dim]) + + return u, q, k, v diff --git a/keras_rs/src/layers/jagged_tensors.py b/keras_rs/src/layers/jagged_tensors.py new file mode 100644 index 0000000..87e96e6 --- /dev/null +++ b/keras_rs/src/layers/jagged_tensors.py @@ -0,0 +1,112 @@ +import keras +from keras import ops +from typing import List, Optional, Tuple + +# --- Core Jagged/Dense Conversion Functions --- + +def keras_jagged_to_padded_dense(values, offsets, max_lengths, padding_value=0.0): + """ + Keras 3 implementation to convert jagged tensor (values) into a padded dense tensor [B, N, D_flat]. + Required by MHA kernel padding (keras_pad_qkv). + """ + offsets = offsets[0] if isinstance(offsets, list) else offsets + B = ops.shape(offsets)[0] - 1 + max_len = max_lengths[0] + D_flat = ops.shape(values)[-1] + if ops.shape(values)[0] == 0: + return ops.full((B, max_len, D_flat), padding_value, dtype=values.dtype) + + def pad_one(i): + start = offsets[i]; end = offsets[i+1] + seq_len = end - start + seq = ops.slice(values, [start, 0], [seq_len, D_flat]) + if ops.equal(seq_len, 0): + return ops.full((max_len, D_flat), padding_value, dtype=values.dtype) + if seq_len < max_len: + padding_shape = ops.stack([max_len - seq_len, D_flat]) + padding = ops.full(padding_shape, padding_value, dtype=values.dtype) + return ops.concatenate([seq, padding], axis=0) + else: + return seq[:max_len] + + idxs = ops.arange(B, dtype='int32') + return ops.map(pad_one, idxs) + +def keras_dense_to_jagged( + dense: keras.KerasTensor, + x_offsets: List[keras.KerasTensor], +) -> keras.KerasTensor: + """Keras 3 implementation to convert a padded dense tensor [B, N, D] back into a jagged tensor.""" + seq_offsets = x_offsets[0] + N = ops.shape(dense)[1] + D_flat = ops.shape(dense)[2] + token_range = ops.arange(N) + seq_lengths = seq_offsets[1:] - seq_offsets[:-1] + mask = ops.expand_dims(token_range, axis=0) < ops.expand_dims(seq_lengths, axis=1) + + flattened = ops.reshape(dense, [-1, D_flat]) + flattened_mask = ops.reshape(mask, [-1]) + + return flattened[flattened_mask] + +# --- Jagged Splitting and Concatenation Wrappers (Used by Caching Logic) --- + +def split_2D_jagged( + max_seq_len: int, values: keras.KerasTensor, total_len_left: Optional[int] = None, total_len_right: Optional[int] = None, max_len_left: Optional[int] = None, max_len_right: Optional[int] = None, offsets_left: Optional[keras.KerasTensor] = None, offsets_right: Optional[keras.KerasTensor] = None, kernel=None, +) -> Tuple[keras.KerasTensor, keras.KerasTensor]: + """Top-level wrapper for splitting a concatenated jagged tensor.""" + + def keras_split_2D_jagged_jagged(max_seq_len, values, offsets_left, offsets_right): + D_flat = ops.shape(values)[1]; offsets = offsets_left + offsets_right + padded_values_bnd = keras_jagged_to_padded_dense(values=values, offsets=[offsets], max_lengths=[max_seq_len], padding_value=0.0) + padded_values = ops.reshape(padded_values_bnd, [-1, D_flat]) + lengths_left = offsets_left[1:] - offsets_left[:-1]; lengths_right = offsets_right[1:] - offsets_right[:-1] + mask = ops.reshape(ops.arange(max_seq_len, dtype='int32'), [1, -1]) + lengths_left_broadcast = ops.reshape(lengths_left, [-1, 1]); lengths_right_combined = ops.reshape(lengths_left + lengths_right, [-1, 1]) + mask_left = mask < lengths_left_broadcast + mask_right = ops.logical_and(mask >= lengths_left_broadcast, mask < 
lengths_right_combined) + return padded_values[ops.reshape(mask_left, [-1])], padded_values[ops.reshape(mask_right, [-1])] + + def keras_split_2D_jagged_resolver(max_seq_len, values, max_len_left, max_len_right, offsets_left, offsets_right): + L_total = ops.shape(values)[0] + offsets_left_non_optional = offsets_left + if offsets_left is None: offsets_left_non_optional = max_len_left * ops.arange(L_total // max_len_left + 1, dtype='int32') + offsets_right_non_optional = offsets_right + if offsets_right is None: offsets_right_non_optional = max_len_right * ops.arange(L_total // max_len_right + 1, dtype='int32') + return keras_split_2D_jagged_jagged(max_seq_len=max_seq_len, values=values, offsets_left=offsets_left_non_optional, offsets_right=offsets_right_non_optional) + + return keras_split_2D_jagged_resolver(max_seq_len=max_seq_len, values=values, max_len_left=max_len_left, max_len_right=max_len_right, offsets_left=offsets_left, offsets_right=offsets_right) + + +def concat_2D_jagged( + max_seq_len: int, values_left: keras.KerasTensor, values_right: keras.KerasTensor, max_len_left: Optional[int] = None, max_len_right: Optional[int] = None, offsets_left: Optional[keras.KerasTensor] = None, offsets_right: Optional[keras.KerasTensor] = None, kernel=None, +) -> keras.KerasTensor: + """Top-level wrapper for concatenating 2D jagged tensors (used for KV cache construction).""" + + def keras_concat_2D_jagged_jagged(values_left, values_right, max_len_left, max_len_right, offsets_left, offsets_right): + max_seq_len = max_len_left + max_len_right + lengths_left = offsets_left[1:] - offsets_left[:-1]; lengths_right = offsets_right[1:] - offsets_right[:-1] + padded_left = keras_jagged_to_padded_dense(values=values_left, offsets=[offsets_left], max_lengths=[max_len_left], padding_value=0.0) + padded_right = keras_jagged_to_padded_dense(values=values_right, offsets=[offsets_right], max_lengths=[max_len_right], padding_value=0.0) + concatted_dense = ops.concatenate([padded_left, padded_right], axis=1) + + lengths_left_broadcast = ops.reshape(lengths_left, [-1, 1]); lengths_right_broadcast = ops.reshape(lengths_right, [-1, 1]) + mask = ops.reshape(ops.arange(max_seq_len, dtype='int32'), [1, -1]) + mask = ops.logical_or(mask < lengths_left_broadcast, ops.logical_and(mask >= max_len_left, mask < max_len_left + lengths_right_broadcast)) + return concatted_dense[ops.reshape(mask, [-1])] + + def pytorch_concat_2D_jagged_resolver(values_left, values_right, max_len_left, max_len_right, offsets_left, offsets_right): + L_total = ops.shape(values_left)[0] + offsets_left_non_optional = offsets_left + if offsets_left is None: offsets_left_non_optional = max_len_left * ops.arange(L_total // max_len_left + 1, dtype='int32') + offsets_right_non_optional = offsets_right + if offsets_right is None: offsets_right_non_optional = max_len_right * ops.arange(L_total // max_len_right + 1, dtype='int32') + + if max_len_left is None: max_len_left_final = ops.max(offsets_left_non_optional[1:] - offsets_left_non_optional[:-1]) + else: max_len_left_final = max_len_left + if max_len_right is None: max_len_right_final = ops.max(offsets_right_non_optional[1:] - offsets_right_non_optional[:-1]) + else: max_len_right_final = max_len_right + + return keras_concat_2D_jagged_jagged(values_left=values_left, values_right=values_right, max_len_left=max_len_left_final, max_len_right=max_len_right_final, offsets_left=offsets_left_non_optional, offsets_right=offsets_right_non_optional) + + return 
pytorch_concat_2D_jagged_resolver(values_left=values_left, values_right=values_right, max_len_left=max_len_left, max_len_right=max_len_right, offsets_left=offsets_left, offsets_right=offsets_right) diff --git a/keras_rs/src/layers/stu.py b/keras_rs/src/layers/stu.py new file mode 100644 index 0000000..b9ee8c7 --- /dev/null +++ b/keras_rs/src/layers/stu.py @@ -0,0 +1,222 @@ +import abc +from typing import List, Optional, Tuple +import keras +from keras import ops +from keras import layers + +from keras_rs.src.layers.common import fx_unwrap_optional_tensor +from keras_rs.src.layers.hstu_compute_output import hstu_compute_uqvk, hstu_compute_output +from keras_rs.src.layers.hstu_preprocess_attention import keras_hstu_preprocess_and_attention +from keras_rs.src.layers.hstu_mha_attention import delta_hstu_mha +from keras_rs.src.layers.jagged_tensors import split_2D_jagged, concat_2D_jagged + + +class STULayerConfig: + def __init__(self, embedding_dim: int, num_heads: int, hidden_dim: int, attention_dim: int, + output_dropout_ratio: float = 0.3, causal: bool = True, target_aware: bool = True, + max_attn_len: Optional[int] = None, attn_alpha: Optional[float] = None, + use_group_norm: bool = False, recompute_normed_x: bool = True, + recompute_uvqk: bool = True, recompute_y: bool = True, + sort_by_length: bool = True, contextual_seq_len: int = 0): + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.attention_dim = attention_dim + self.output_dropout_ratio = output_dropout_ratio + self.causal = causal + self.target_aware = target_aware + self.max_attn_len = max_attn_len + self.attn_alpha = attn_alpha + self.use_group_norm = use_group_norm + self.recompute_normed_x = recompute_normed_x + self.recompute_uvqk = recompute_uvqk + self.recompute_y = recompute_y + self.sort_by_length = sort_by_length + self.contextual_seq_len = contextual_seq_len + + +def _update_kv_cache( + max_seq_len: int, seq_offsets: keras.KerasTensor, k: Optional[keras.KerasTensor], v: Optional[keras.KerasTensor], max_kv_caching_len: int, kv_caching_lengths: Optional[keras.KerasTensor], orig_k_cache: Optional[keras.KerasTensor], orig_v_cache: Optional[keras.KerasTensor], orig_max_kv_caching_len: int, orig_kv_caching_offsets: Optional[keras.KerasTensor], +) -> Tuple[Optional[keras.KerasTensor], Optional[keras.KerasTensor], int, Optional[keras.KerasTensor]]: + + if kv_caching_lengths is not None: + # Keras equivalent of asynchronous_complete_cumsum + kv_caching_offsets = ops.cast(ops.cumsum(kv_caching_lengths, exclusive=True), dtype="int32") + delta_offsets = seq_offsets - kv_caching_offsets + + # NOTE: split_2D_jagged is available from jagged_tensors.py + k_cache, _ = split_2D_jagged(max_seq_len=max_seq_len, values=ops.reshape(fx_unwrap_optional_tensor(k), [-1, ops.shape(k)[-1]]), max_len_left=None, max_len_right=None, offsets_left=kv_caching_offsets, offsets_right=delta_offsets) + v_cache, _ = split_2D_jagged(max_seq_len=max_seq_len, values=ops.reshape(fx_unwrap_optional_tensor(v), [-1, ops.shape(v)[-1]]), max_len_left=None, max_len_right=None, offsets_left=kv_caching_offsets, offsets_right=delta_offsets) + + if max_kv_caching_len == 0: + max_kv_caching_len = ops.convert_to_numpy(ops.cast(ops.max(kv_caching_lengths), dtype="int32")).item() + return (k_cache, v_cache, max_kv_caching_len, kv_caching_offsets) + else: + return (orig_k_cache, orig_v_cache, orig_max_kv_caching_len, orig_kv_caching_offsets) + + +def _construct_full_kv( + delta_k: keras.KerasTensor, delta_v: 
keras.KerasTensor, k_cache: keras.KerasTensor, v_cache: keras.KerasTensor, max_kv_caching_len: int, kv_caching_offsets: keras.KerasTensor, +) -> Tuple[keras.KerasTensor, keras.KerasTensor, int, keras.KerasTensor]: + L = ops.shape(delta_k)[0] + B = ops.shape(kv_caching_offsets)[0] - 1 + delta_size = L // B + + # NOTE: concat_2D_jagged is available from jagged_tensors.py + full_k = concat_2D_jagged(max_seq_len=max_kv_caching_len + delta_size, values_left=k_cache, values_right=delta_k, max_len_left=max_kv_caching_len, max_len_right=delta_size, offsets_left=kv_caching_offsets, offsets_right=None) + full_v = concat_2D_jagged(max_seq_len=max_kv_caching_len + delta_size, values_left=v_cache, values_right=delta_v, max_len_left=max_kv_caching_len, max_len_right=delta_size, offsets_left=kv_caching_offsets, offsets_right=None) + + # Calculate new combined offsets + delta_size_broadcast = delta_size * ops.arange(B + 1, dtype=kv_caching_offsets.dtype) + full_kv_caching_offsets = kv_caching_offsets + delta_size_broadcast + + return (full_k, full_v, max_kv_caching_len + delta_size, full_kv_caching_offsets) + + +class STU(layers.Layer, abc.ABC): + """Abstract base class for STU layers.""" + @abc.abstractmethod + def cached_forward(self, delta_x: keras.KerasTensor, num_targets: keras.KerasTensor, max_kv_caching_len: int = 0, kv_caching_lengths: Optional[keras.KerasTensor] = None, training: Optional[bool] = None,) -> keras.KerasTensor: pass + @abc.abstractmethod + def call(self, x: keras.KerasTensor, x_lengths: keras.KerasTensor, x_offsets: keras.KerasTensor, max_seq_len: int, num_targets: keras.KerasTensor, max_kv_caching_len: int = 0, kv_caching_lengths: Optional[keras.KerasTensor] = None, training: Optional[bool] = None,) -> keras.KerasTensor: pass + + +class STULayer(layers.Layer): + # Initialize cache properties on the instance + max_kv_caching_len: int = 0 + k_cache: Optional[keras.KerasTensor] = None + v_cache: Optional[keras.KerasTensor] = None + kv_caching_offsets: Optional[keras.KerasTensor] = None + + def __init__(self, config: STULayerConfig, is_inference: bool = False, **kwargs): + super().__init__(**kwargs) + self._config = config + self._num_heads: int = config.num_heads + self._embedding_dim: int = config.embedding_dim + self._hidden_dim: int = config.hidden_dim + self._attention_dim: int = config.attention_dim + self._output_dropout_ratio: float = config.output_dropout_ratio + self._target_aware: bool = config.target_aware + self._causal: bool = config.causal + self._max_attn_len: int = config.max_attn_len or 0 + self._attn_alpha: float = config.attn_alpha or 1.0 / (self._attention_dim**0.5) + self._use_group_norm: bool = config.use_group_norm + self._recompute_normed_x: bool = config.recompute_normed_x + self._recompute_uvqk: bool = config.recompute_uvqk + self._recompute_y: bool = config.recompute_y + self._sort_by_length: bool = config.sort_by_length + self._contextual_seq_len: int = config.contextual_seq_len + self.reset_kv_cache() + + def build(self, input_shape): + D_in = input_shape[-1] + H = self._num_heads; A = self._attention_dim; V = self._hidden_dim + output_dim_total = (V * 2 + A * 2) * H + self._uvqk_weight = self.add_weight(shape=(D_in, output_dim_total), initializer='glorot_uniform', name='uvqk_weight') + self._uvqk_beta = self.add_weight(shape=(output_dim_total,), initializer='zeros', name='uvqk_beta') + self._input_norm_weight = self.add_weight(shape=(D_in,), initializer='ones', name='input_norm_weight') + self._input_norm_bias = self.add_weight(shape=(D_in,), 
initializer='zeros', name='input_norm_bias') + + self._output_weight = self.add_weight(shape=(V * H, self._embedding_dim), initializer='glorot_uniform', name='output_weight') + + output_norm_shape: int = (V * H if not self._use_group_norm else H) + self._output_norm_weight = self.add_weight(shape=(output_norm_shape,), initializer='ones', name='output_norm_weight') + self._output_norm_bias = self.add_weight(shape=(output_norm_shape,), initializer='zeros', name='output_norm_bias') + self.built = True + + def reset_kv_cache(self) -> None: + self.k_cache = None; self.v_cache = None + self.kv_caching_offsets = None; self.max_kv_caching_len = 0 + + def update_kv_cache( + self, max_seq_len: int, seq_offsets: keras.KerasTensor, k: Optional[keras.KerasTensor], v: Optional[keras.KerasTensor], max_kv_caching_len: int, kv_caching_lengths: Optional[keras.KerasTensor], + ) -> None: + # NOTE: Assumes _update_kv_cache is available + self.k_cache, self.v_cache, self.max_kv_caching_len, self.kv_caching_offsets = (_update_kv_cache(max_seq_len=max_seq_len, seq_offsets=seq_offsets, k=k, v=v, max_kv_caching_len=max_kv_caching_len, kv_caching_lengths=kv_caching_lengths, orig_k_cache=self.k_cache, orig_v_cache=self.v_cache, orig_max_kv_caching_len=self.max_kv_caching_len, orig_kv_caching_offsets=self.kv_caching_offsets,)) + + def construct_full_kv(self, delta_k: keras.KerasTensor, delta_v: keras.KerasTensor,) -> Tuple[keras.KerasTensor, keras.KerasTensor, int, keras.KerasTensor]: + # NOTE: Assumes _construct_full_kv is available + return _construct_full_kv(delta_k=delta_k, delta_v=delta_v, k_cache=fx_unwrap_optional_tensor(self.k_cache), v_cache=fx_unwrap_optional_tensor(self.v_cache), max_kv_caching_len=self.max_kv_caching_len, kv_caching_offsets=fx_unwrap_optional_tensor(self.kv_caching_offsets),) + + def call( # Standard Keras forward method + self, x: keras.KerasTensor, x_lengths: keras.KerasTensor, x_offsets: keras.KerasTensor, max_seq_len: int, num_targets: keras.KerasTensor, max_kv_caching_len: int = 0, kv_caching_lengths: Optional[keras.KerasTensor] = None, training: Optional[bool] = None, + ) -> keras.KerasTensor: + + u, attn_output, k, v = keras_hstu_preprocess_and_attention( + x=x, norm_weight=self._input_norm_weight, norm_bias=self._input_norm_bias, norm_eps=1e-6, + num_heads=self._num_heads, attn_dim=self._attention_dim, hidden_dim=self._hidden_dim, + uvqk_weight=self._uvqk_weight, uvqk_bias=self._uvqk_beta, + max_seq_len=max_seq_len, seq_offsets=x_offsets, attn_alpha=self._attn_alpha, + causal=self._causal, num_targets=num_targets if self._target_aware else None, + max_attn_len=self._max_attn_len, contextual_seq_len=self._contextual_seq_len, + recompute_uvqk_in_backward=self._recompute_uvqk, recompute_normed_x_in_backward=self._recompute_normed_x, + sort_by_length=self._sort_by_length, prefill=kv_caching_lengths is not None, + ) + + self.update_kv_cache(max_seq_len=max_seq_len, seq_offsets=x_offsets, k=k, v=v, max_kv_caching_len=max_kv_caching_len, kv_caching_lengths=kv_caching_lengths) + + return hstu_compute_output( + attn=attn_output, u=u, x=x, norm_weight=self._output_norm_weight, norm_bias=self._output_norm_bias, + norm_eps=1e-6, dropout_ratio=self._output_dropout_ratio, output_weight=self._output_weight, + group_norm=self._use_group_norm, num_heads=self._num_heads, linear_dim=self._hidden_dim, + concat_ux=True, training=training, recompute_y_in_backward=self._recompute_y, + ) + + def cached_forward( # Called for token-by-token generation + self, delta_x: keras.KerasTensor, num_targets: 
keras.KerasTensor, max_kv_caching_len: int = 0, kv_caching_lengths: Optional[keras.KerasTensor] = None, training: Optional[bool] = None, + ) -> keras.KerasTensor: + + delta_u, delta_q, delta_k, delta_v = hstu_compute_uqvk( + x=delta_x, norm_weight=self._input_norm_weight, norm_bias=self._input_norm_bias, norm_eps=1e-6, + num_heads=self._num_heads, attn_dim=self._attention_dim, hidden_dim=self._hidden_dim, + uvqk_weight=self._uvqk_weight, uvqk_bias=self._uvqk_beta, + ) + + A = self._attention_dim; V = self._hidden_dim; H = self._num_heads + k_flat = ops.reshape(delta_k, [-1, H * A]) + v_flat = ops.reshape(delta_v, [-1, H * V]) + + k_full, v_full, max_seq_len, seq_offsets = self.construct_full_kv(delta_k=k_flat, delta_v=v_flat) + + self.update_kv_cache(max_seq_len=max_seq_len, seq_offsets=seq_offsets, k=k_full, v=v_full, max_kv_caching_len=max_kv_caching_len, kv_caching_lengths=kv_caching_lengths) + + # Reshape K and V back to [L_full, H, D] for attention calculation + k = ops.reshape(k_full, [-1, H, A]) + v = ops.reshape(v_full, [-1, H, V]) + + + delta_attn_output = delta_hstu_mha( + max_seq_len=max_seq_len, alpha=self._attn_alpha, delta_q=delta_q, k=k, v=v, seq_offsets=seq_offsets, + num_targets=num_targets if self._target_aware else None, max_attn_len=self._max_attn_len, + contextual_seq_len=self._contextual_seq_len, + ) + + delta_attn_output = ops.reshape(delta_attn_output, [-1, V * H]) + + + return hstu_compute_output( + attn=delta_attn_output, u=delta_u, x=delta_x, norm_weight=self._output_norm_weight, norm_bias=self._output_norm_bias, + norm_eps=1e-6, dropout_ratio=self._output_dropout_ratio, output_weight=self._output_weight, + group_norm=self._use_group_norm, num_heads=self._num_heads, linear_dim=self._hidden_dim, + concat_ux=True, training=training, recompute_y_in_backward=self._recompute_y, + ) + + +class STUStack(layers.Layer): + def __init__(self, stu_layers: List[STULayer], is_inference: bool = False, **kwargs): + super().__init__(**kwargs) + self._stu_layers = stu_layers + + def call( + self, x: keras.KerasTensor, x_lengths: keras.KerasTensor, x_offsets: keras.KerasTensor, max_seq_len: int, num_targets: keras.KerasTensor, max_kv_caching_len: int = 0, kv_caching_lengths: Optional[keras.KerasTensor] = None, training: Optional[bool] = None, + ) -> keras.KerasTensor: + for layer in self._stu_layers: + x = layer(x=x, x_lengths=x_lengths, x_offsets=x_offsets, max_seq_len=max_seq_len, num_targets=num_targets, max_kv_caching_len=max_kv_caching_len, kv_caching_lengths=kv_caching_lengths, training=training) + return x + + def cached_forward( + self, delta_x: keras.KerasTensor, num_targets: keras.KerasTensor, max_kv_caching_len: int = 0, kv_caching_lengths: Optional[keras.KerasTensor] = None, training: Optional[bool] = None, + ) -> keras.KerasTensor: + for layer in self._stu_layers: + delta_x = layer.cached_forward(delta_x=delta_x, num_targets=num_targets, max_kv_caching_len=max_kv_caching_len, kv_caching_lengths=kv_caching_lengths, training=training) + return delta_x From a5e4627a2a0e7331551bc3b05d74223c77e5bc82 Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali <149650845+LakshmiKalaKadali@users.noreply.github.com> Date: Mon, 22 Sep 2025 18:23:48 +0530 Subject: [PATCH 2/9] Update keras_rs/src/layers/stu.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- keras_rs/src/layers/stu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_rs/src/layers/stu.py b/keras_rs/src/layers/stu.py index b9ee8c7..a3b46fe 100644 
--- a/keras_rs/src/layers/stu.py +++ b/keras_rs/src/layers/stu.py @@ -132,7 +132,7 @@ def update_kv_cache( self, max_seq_len: int, seq_offsets: keras.KerasTensor, k: Optional[keras.KerasTensor], v: Optional[keras.KerasTensor], max_kv_caching_len: int, kv_caching_lengths: Optional[keras.KerasTensor], ) -> None: # NOTE: Assumes _update_kv_cache is available - self.k_cache, self.v_cache, self.max_kv_caching_len, self.kv_caching_offsets = (_update_kv_cache(max_seq_len=max_seq_len, seq_offsets=seq_offsets, k=k, v=v, max_kv_caching_len=max_kv_caching_len, kv_caching_lengths=kv_caching_lengths, orig_k_cache=self.k_cache, orig_v_cache=self.v_cache, orig_max_kv_caching_len=self.max_kv_caching_len, orig_kv_caching_offsets=self.kv_caching_offsets,)) + self.k_cache, self.v_cache, self.max_kv_caching_len, self.kv_caching_offsets = _update_kv_cache(max_seq_len=max_seq_len, seq_offsets=seq_offsets, k=k, v=v, max_kv_caching_len=max_kv_caching_len, kv_caching_lengths=kv_caching_lengths, orig_k_cache=self.k_cache, orig_v_cache=self.v_cache, orig_max_kv_caching_len=self.max_kv_caching_len, orig_kv_caching_offsets=self.kv_caching_offsets) def construct_full_kv(self, delta_k: keras.KerasTensor, delta_v: keras.KerasTensor,) -> Tuple[keras.KerasTensor, keras.KerasTensor, int, keras.KerasTensor]: # NOTE: Assumes _construct_full_kv is available From b823720d7b29f6594e4616193251aa6e89a55b2c Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali <149650845+LakshmiKalaKadali@users.noreply.github.com> Date: Mon, 22 Sep 2025 18:24:17 +0530 Subject: [PATCH 3/9] Update keras_rs/src/layers/common.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- keras_rs/src/layers/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_rs/src/layers/common.py b/keras_rs/src/layers/common.py index 4e5b9f5..1ebea2c 100644 --- a/keras_rs/src/layers/common.py +++ b/keras_rs/src/layers/common.py @@ -17,7 +17,7 @@ def get_valid_attn_mask_keras( max_attn_len: int = 0, contextual_seq_len: int = 0, min_full_attn_seq_len: int = 0, -): +) -> keras.KerasTensor: """ Keras implementation of the valid attention mask generation, combining causality, sequence lengths, and target awareness. 
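
[Editor's note, not part of the patch series] The mask helper added in patch 1 and re-typed in patch 3 is easiest to read with a concrete call. The sketch below is illustrative only: the sequence lengths and target counts are invented for the example, and the import path simply mirrors the file added by this PR.

import keras
from keras import ops

from keras_rs.src.layers.common import get_valid_attn_mask_keras

# Two jagged sequences padded to N = 5; the second sequence ends with 2 target items.
seq_lengths = ops.convert_to_tensor([3, 5], dtype="int32")
num_targets = ops.convert_to_tensor([0, 2], dtype="int32")

mask = get_valid_attn_mask_keras(
    causal=True,
    N=5,
    seq_lengths=seq_lengths,
    num_targets=num_targets,
)
# mask has shape (2, 5, 5); entry [b, i, j] is True where query position i of
# sequence b may attend to key position j under the causal, length-aware and
# target-aware rules implemented above.
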
From e9fd20471f4f5c2f6a19480e80c9e7487c673f66 Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali <149650845+LakshmiKalaKadali@users.noreply.github.com> Date: Mon, 22 Sep 2025 18:24:34 +0530 Subject: [PATCH 4/9] Update keras_rs/src/layers/jagged_tensors.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- keras_rs/src/layers/jagged_tensors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_rs/src/layers/jagged_tensors.py b/keras_rs/src/layers/jagged_tensors.py index 87e96e6..84e90aa 100644 --- a/keras_rs/src/layers/jagged_tensors.py +++ b/keras_rs/src/layers/jagged_tensors.py @@ -95,7 +95,7 @@ def keras_concat_2D_jagged_jagged(values_left, values_right, max_len_left, max_l mask = ops.logical_or(mask < lengths_left_broadcast, ops.logical_and(mask >= max_len_left, mask < max_len_left + lengths_right_broadcast)) return concatted_dense[ops.reshape(mask, [-1])] - def pytorch_concat_2D_jagged_resolver(values_left, values_right, max_len_left, max_len_right, offsets_left, offsets_right): + def keras_concat_2D_jagged_resolver(values_left, values_right, max_len_left, max_len_right, offsets_left, offsets_right): L_total = ops.shape(values_left)[0] offsets_left_non_optional = offsets_left if offsets_left is None: offsets_left_non_optional = max_len_left * ops.arange(L_total // max_len_left + 1, dtype='int32') From d035515e365aefad3d4b1bccc1b1536791cbb68e Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali <149650845+LakshmiKalaKadali@users.noreply.github.com> Date: Thu, 25 Sep 2025 14:54:52 +0530 Subject: [PATCH 5/9] Update keras_rs/src/layers/hstu_mha_attention.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- keras_rs/src/layers/hstu_mha_attention.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/keras_rs/src/layers/hstu_mha_attention.py b/keras_rs/src/layers/hstu_mha_attention.py index d7b146d..0986e40 100644 --- a/keras_rs/src/layers/hstu_mha_attention.py +++ b/keras_rs/src/layers/hstu_mha_attention.py @@ -14,14 +14,16 @@ def keras_pad_qkv( Assumes keras_jagged_to_padded_dense is available globally. 
""" L, H, D = ops.shape(q); V_dim = ops.shape(v)[2] - values_q_k = ops.reshape(q, [L, H * D]); values_v = ops.reshape(v, [L, H * V_dim]) + values_q = ops.reshape(q, [L, H * D]); values_k = ops.reshape(k, [L, H * D]); values_v = ops.reshape(v, [L, H * V_dim]) # Pad Q, K, V - padded_q_k = keras_jagged_to_padded_dense(values=values_q_k, offsets=[seq_offsets], max_lengths=[N], padding_value=0.0) - padded_v = keras_jagged_to_padded_dense(values=values_v, offsets=[seq_offsets], max_lengths=[N], padding_value=0.0) + padded_q = keras_jagged_to_padded_dense(values=values_q, offsets=[seq_offsets], max_lengths=[N], padding_value=0.0) + padded_k = keras_jagged_to_padded_dense(values=values_k, offsets=[seq_offsets], max_lengths=[N], padding_value=0.0) + padded_v = keras_jagged_to_padded_dense(values=values_v, offsets=[seq_offsets], max_lengths=[N], padding_value=0.0) - B = ops.shape(padded_q_k)[0]; padded_q_k = ops.reshape(padded_q_k, [B, N, H, D]); padded_v = ops.reshape(padded_v, [B, N, H, V_dim]) - padded_q = ops.transpose(padded_q_k, [0, 2, 1, 3]); padded_k = ops.transpose(padded_q_k, [0, 2, 1, 3]) + B = ops.shape(padded_q)[0] + padded_q = ops.reshape(padded_q, [B, N, H, D]); padded_k = ops.reshape(padded_k, [B, N, H, D]); padded_v = ops.reshape(padded_v, [B, N, H, V_dim]) + padded_q = ops.transpose(padded_q, [0, 2, 1, 3]); padded_k = ops.transpose(padded_k, [0, 2, 1, 3]) padded_v = ops.transpose(padded_v, [0, 2, 1, 3]) return padded_q, padded_k, padded_v From acd294c2e18af7fa378b07ef1fc42be4d839dfff Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali <149650845+LakshmiKalaKadali@users.noreply.github.com> Date: Thu, 25 Sep 2025 14:55:35 +0530 Subject: [PATCH 6/9] Update keras_rs/src/layers/stu.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- keras_rs/src/layers/stu.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/keras_rs/src/layers/stu.py b/keras_rs/src/layers/stu.py index a3b46fe..79f6f52 100644 --- a/keras_rs/src/layers/stu.py +++ b/keras_rs/src/layers/stu.py @@ -45,8 +45,16 @@ def _update_kv_cache( delta_offsets = seq_offsets - kv_caching_offsets # NOTE: split_2D_jagged is available from jagged_tensors.py - k_cache, _ = split_2D_jagged(max_seq_len=max_seq_len, values=ops.reshape(fx_unwrap_optional_tensor(k), [-1, ops.shape(k)[-1]]), max_len_left=None, max_len_right=None, offsets_left=kv_caching_offsets, offsets_right=delta_offsets) - v_cache, _ = split_2D_jagged(max_seq_len=max_seq_len, values=ops.reshape(fx_unwrap_optional_tensor(v), [-1, ops.shape(v)[-1]]), max_len_left=None, max_len_right=None, offsets_left=kv_caching_offsets, offsets_right=delta_offsets) + if k is not None: + k_values = ops.reshape(k, [ops.shape(k)[0], -1]) + k_cache, _ = split_2D_jagged(max_seq_len=max_seq_len, values=k_values, max_len_left=None, max_len_right=None, offsets_left=kv_caching_offsets, offsets_right=delta_offsets) + else: + k_cache = fx_unwrap_optional_tensor(k) + if v is not None: + v_values = ops.reshape(v, [ops.shape(v)[0], -1]) + v_cache, _ = split_2D_jagged(max_seq_len=max_seq_len, values=v_values, max_len_left=None, max_len_right=None, offsets_left=kv_caching_offsets, offsets_right=delta_offsets) + else: + v_cache = fx_unwrap_optional_tensor(v) if max_kv_caching_len == 0: max_kv_caching_len = ops.convert_to_numpy(ops.cast(ops.max(kv_caching_lengths), dtype="int32")).item() From ddbb2ee07352b2e7424814f49bae7439e8630c22 Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali 
<149650845+LakshmiKalaKadali@users.noreply.github.com> Date: Thu, 25 Sep 2025 14:56:04 +0530 Subject: [PATCH 7/9] Update keras_rs/src/layers/stu.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- keras_rs/src/layers/stu.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/keras_rs/src/layers/stu.py b/keras_rs/src/layers/stu.py index 79f6f52..f395004 100644 --- a/keras_rs/src/layers/stu.py +++ b/keras_rs/src/layers/stu.py @@ -5,7 +5,8 @@ from keras import layers from keras_rs.src.layers.common import fx_unwrap_optional_tensor -from keras_rs.src.layers.hstu_compute_output import hstu_compute_uqvk, hstu_compute_output +from keras_rs.src.layers.hstu_compute_output import hstu_compute_output +from keras_rs.src.layers.hstu_uqvk_output import hstu_compute_uqvk from keras_rs.src.layers.hstu_preprocess_attention import keras_hstu_preprocess_and_attention from keras_rs.src.layers.hstu_mha_attention import delta_hstu_mha from keras_rs.src.layers.jagged_tensors import split_2D_jagged, concat_2D_jagged From 6e9ce4d9a4bb98e3a8b5e78612abc2e6e5a16119 Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali <149650845+LakshmiKalaKadali@users.noreply.github.com> Date: Thu, 25 Sep 2025 14:56:53 +0530 Subject: [PATCH 8/9] Update keras_rs/src/layers/jagged_tensors.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- keras_rs/src/layers/jagged_tensors.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/keras_rs/src/layers/jagged_tensors.py b/keras_rs/src/layers/jagged_tensors.py index 84e90aa..2355187 100644 --- a/keras_rs/src/layers/jagged_tensors.py +++ b/keras_rs/src/layers/jagged_tensors.py @@ -70,7 +70,10 @@ def keras_split_2D_jagged_jagged(max_seq_len, values, offsets_left, offsets_righ def keras_split_2D_jagged_resolver(max_seq_len, values, max_len_left, max_len_right, offsets_left, offsets_right): L_total = ops.shape(values)[0] offsets_left_non_optional = offsets_left - if offsets_left is None: offsets_left_non_optional = max_len_left * ops.arange(L_total // max_len_left + 1, dtype='int32') + if offsets_left is None: + if max_len_left is None: + raise ValueError("Either offsets_left or max_len_left must be provided.") + offsets_left_non_optional = max_len_left * ops.arange(L_total // max_len_left + 1, dtype='int32') offsets_right_non_optional = offsets_right if offsets_right is None: offsets_right_non_optional = max_len_right * ops.arange(L_total // max_len_right + 1, dtype='int32') return keras_split_2D_jagged_jagged(max_seq_len=max_seq_len, values=values, offsets_left=offsets_left_non_optional, offsets_right=offsets_right_non_optional) From 8c64861018e41fbcae7da2945a4988081c4477eb Mon Sep 17 00:00:00 2001 From: LakshmiKalaKadali <149650845+LakshmiKalaKadali@users.noreply.github.com> Date: Thu, 25 Sep 2025 14:57:54 +0530 Subject: [PATCH 9/9] Update keras_rs/src/layers/hstu_compute_output.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- keras_rs/src/layers/hstu_compute_output.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/keras_rs/src/layers/hstu_compute_output.py b/keras_rs/src/layers/hstu_compute_output.py index 9d35959..d40107c 100644 --- a/keras_rs/src/layers/hstu_compute_output.py +++ b/keras_rs/src/layers/hstu_compute_output.py @@ -31,9 +31,7 @@ def keras_norm_mul_dropout( raise NotImplementedError("Group Norm path not suitable for simple Keras ops conversion.") else: # 
Functional Layer Normalization (Simulated keras_layer_norm) - mean = ops.mean(x, axis=-1, keepdims=True) - variance = ops.mean(ops.square(x - mean), axis=-1, keepdims=True) - x_norm = (x - mean) / ops.sqrt(variance + eps) + x_norm = ops.layer_norm(x, axis=-1, epsilon=eps) # Apply weight and bias (Gamma * x_norm + Beta) y_norm = x_norm * weight + bias
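
[Editor's note, not part of the patch series] All of the kernels above consume the jagged layout used throughout HSTU: a batch of variable-length sequences is flattened into a single [sum(lengths), D] values tensor plus a [B + 1] offsets vector. A minimal sketch of building that pair with stock Keras ops follows; the lengths and dimensions are invented for illustration.

import keras
from keras import ops

lengths = ops.convert_to_tensor([3, 2], dtype="int32")  # two users with 3 and 2 events
x = keras.random.normal((5, 16))                        # sum(lengths) rows, embedding_dim = 16
x_offsets = ops.concatenate(                            # [0, 3, 5]; rows offsets[b]:offsets[b+1] belong to user b
    [ops.zeros((1,), dtype="int32"), ops.cumsum(lengths)], axis=0
)

# keras_jagged_to_padded_dense(values=x, offsets=[x_offsets], max_lengths=[3]) from
# jagged_tensors.py would then yield a dense [2, 3, 16] batch, and
# keras_dense_to_jagged inverts that conversion.
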
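
[Editor's note, not part of the patch series] For orientation, this is the intended call shape of the STULayer added in patch 1, using the jagged inputs built in the previous sketch. It shows the API surface only; all values are illustrative, and whether it executes end-to-end depends on the active backend's handling of the dynamic-shape helpers in jagged_tensors.py.

from keras import ops
from keras_rs.src.layers.stu import STULayer, STULayerConfig

config = STULayerConfig(
    embedding_dim=16,
    num_heads=2,
    hidden_dim=8,
    attention_dim=8,
    output_dropout_ratio=0.1,
)
layer = STULayer(config)

# x, lengths and x_offsets come from the previous sketch; num_targets marks how many
# trailing positions of each sequence are candidates rather than history.
num_targets = ops.zeros((2,), dtype="int32")
y = layer(
    x=x,
    x_lengths=lengths,
    x_offsets=x_offsets,
    max_seq_len=3,
    num_targets=num_targets,
    training=False,
)
# The residual design intends y to keep the jagged layout of x:
# one [embedding_dim] row per input event.
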