62 changes: 62 additions & 0 deletions keras_rs/src/layers/common.py
@@ -0,0 +1,62 @@
import keras
from keras import ops
from typing import List, Optional, Tuple

def fx_unwrap_optional_tensor(optional: Optional[keras.KerasTensor]) -> keras.KerasTensor:
"""Helper to unwrap optional tensors, returning a zero-tensor for uninitialized cache."""
if optional is None:
# Returning a zero-tensor is necessary for graph tracing when the cache is uninitialized.
return ops.zeros((0,), dtype='float32')
return optional

def get_valid_attn_mask_keras(
causal: bool,
N: int,
seq_lengths: keras.KerasTensor,
num_targets: Optional[keras.KerasTensor] = None,
max_attn_len: int = 0,
contextual_seq_len: int = 0,
min_full_attn_seq_len: int = 0,


Review comment (medium): The parameter min_full_attn_seq_len is defined but never used within the function. It should be removed to improve code clarity.

) -> keras.KerasTensor:
"""
Keras implementation of the valid attention mask generation, combining
causality, sequence lengths, and target awareness.
"""
ids = ops.reshape(ops.arange(0, N, dtype="int32"), (1, N))
max_ids = ops.reshape(seq_lengths, (-1, 1, 1))
B = ops.shape(seq_lengths)[0]

if contextual_seq_len > 0:
ids = ids - contextual_seq_len + 1
ids = ops.maximum(ids, 0)
max_ids = max_ids - contextual_seq_len + 1

if num_targets is not None:
max_ids = max_ids - ops.reshape(num_targets, (-1, 1, 1))
ids = ops.minimum(ids, max_ids)
row_ids = ops.broadcast_to(ops.reshape(ids, (-1, N, 1)), (B, N, N))
col_ids = ops.broadcast_to(ops.reshape(ids, (-1, 1, N)), (B, N, N))
else:
row_ids = ops.broadcast_to(ops.reshape(ids, (N, 1)), (N, N))
col_ids = ops.transpose(row_ids)
row_ids = ops.reshape(row_ids, (1, N, N))
col_ids = ops.reshape(col_ids, (1, N, N))
max_ids = None
Comment on lines +40 to +44


Review comment (high): The logic in this else block does not correctly handle batches. It creates row_ids and col_ids with a shape of (1, N, N), ignoring the batch size B calculated earlier. If the batch size is greater than 1, this will lead to incorrect masking due to broadcasting. The mask should have a shape of (B, N, N) to be consistent with the if num_targets is not None branch.

Suggested change:

    else:
        row_ids = ops.broadcast_to(ops.reshape(ids, (1, N, 1)), (B, N, N))
        col_ids = ops.broadcast_to(ops.reshape(ids, (1, 1, N)), (B, N, N))
        max_ids = None


row_col_dist = row_ids - col_ids
valid_attn_mask = ops.reshape(ops.eye(N, dtype="bool"), (1, N, N))

if not causal:
row_col_dist = ops.where(row_col_dist > 0, row_col_dist, -row_col_dist)

valid_attn_mask = ops.logical_or(valid_attn_mask, row_col_dist > 0)

if max_attn_len > 0:
valid_attn_mask = ops.logical_and(valid_attn_mask, row_col_dist <= max_attn_len)

if contextual_seq_len > 0 and max_ids is not None:
valid_attn_mask = ops.logical_or(
valid_attn_mask, ops.logical_and(row_ids == 0, col_ids < max_ids)
)

return valid_attn_mask
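
Note: a minimal usage sketch of the mask helper, assuming eager execution and that the module path matches the file path in this diff; the sequence lengths and attention window below are illustrative only.

from keras import ops
from keras_rs.src.layers.common import get_valid_attn_mask_keras

# Two sequences of lengths 3 and 5, padded to N = 6, with causal masking and a
# local attention window of 2 positions.
seq_lengths = ops.convert_to_tensor([3, 5], dtype="int32")
mask = get_valid_attn_mask_keras(
    causal=True,
    N=6,
    seq_lengths=seq_lengths,
    max_attn_len=2,
)
# With num_targets=None this branch returns shape (1, 6, 6); True marks the
# positions a query may attend to (see the batch-handling review note above).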
123 changes: 123 additions & 0 deletions keras_rs/src/layers/hstu_compute_output.py
@@ -0,0 +1,123 @@
import keras
from keras import ops
from typing import List, Optional, Tuple

def keras_norm_mul_dropout(
x: keras.KerasTensor,
u: keras.KerasTensor,
weight: keras.KerasTensor,
bias: keras.KerasTensor,
eps: float,
dropout_ratio: float,
training: bool,
silu_u: bool = False,
concat_ux: bool = False,
group_norm: bool = False,
num_heads: int = 1,
linear_dim: int = -1,
Comment on lines +16 to +17


Review comment (medium): The parameters num_heads and linear_dim are unused in this function and should be removed.

) -> keras.KerasTensor:
"""
Keras 3 equivalent of pytorch_norm_mul_dropout.
Applies normalization, element-wise multiplication with u, and dropout.
Assumes keras_layer_norm is available (though the logic is inlined here).
"""
x = ops.convert_to_tensor(x, dtype='float32')
u = ops.convert_to_tensor(u, dtype='float32')

if silu_u:
u = ops.silu(u)

if group_norm:
raise NotImplementedError("Group Norm path not suitable for simple Keras ops conversion.")
else:
# Functional Layer Normalization (Simulated keras_layer_norm)
x_norm = ops.layer_norm(x, axis=-1, epsilon=eps)

# Apply weight and bias (Gamma * x_norm + Beta)
y_norm = x_norm * weight + bias

# Apply u multiplication (Element-wise gating)
y = u * y_norm

if concat_ux:
y = ops.concatenate([u, x, y], axis=1)

# Dropout (using Keras layer for correct training=True/False behavior)
y = keras.layers.Dropout(dropout_ratio)(y, training=training)

return ops.cast(y, dtype=x.dtype)
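
Note: ops.layer_norm may not be present in every Keras 3 release. If it is unavailable, the same normalization can be computed from ops.mean and ops.var; a minimal sketch (illustration only, not part of this PR):

def manual_layer_norm(x, eps):
    # Normalize over the last axis: (x - mean) / sqrt(var + eps).
    mean = ops.mean(x, axis=-1, keepdims=True)
    var = ops.var(x, axis=-1, keepdims=True)
    return (x - mean) / ops.sqrt(var + eps)

# Possible drop-in for the x_norm line above:
# x_norm = manual_layer_norm(x, eps)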

def keras_hstu_compute_output(
attn: keras.KerasTensor,
u: keras.KerasTensor,
x: keras.KerasTensor,
norm_weight: keras.KerasTensor,
norm_bias: keras.KerasTensor,
output_weight: keras.KerasTensor,
eps: float,
dropout_ratio: float,
training: bool,
silu_u: bool = False,
concat_ux: bool = False,
group_norm: bool = False,
num_heads: int = 1,
linear_dim: int = -1,
) -> keras.KerasTensor:
"""
Core kernel for the final residual block calculation (Attn Output -> Norm/Dropout -> MatMul -> Residual Add).
"""
y = keras_norm_mul_dropout(
x=attn,
u=u,
weight=norm_weight,
bias=norm_bias,
eps=eps,
dropout_ratio=dropout_ratio,
training=training,
silu_u=silu_u,
concat_ux=concat_ux,
group_norm=group_norm,
num_heads=num_heads,
linear_dim=linear_dim,
)

# Final output: Residual addition of input (x) and transformed attention output (y @ output_weight)
output = ops.add(x, ops.matmul(y, output_weight))

return output

def hstu_compute_output(
attn: keras.KerasTensor,
u: keras.KerasTensor,
x: keras.KerasTensor,
norm_weight: keras.KerasTensor,
norm_bias: keras.KerasTensor,
norm_eps: float,
output_weight: keras.KerasTensor,
num_heads: int,
linear_dim: int,
dropout_ratio: float,
training: bool,
concat_ux: bool,
group_norm: bool,
recompute_y_in_backward: bool,


Review comment (medium): The parameter recompute_y_in_backward is unused in this function and should be removed.

) -> keras.KerasTensor:
"""
Top-level wrapper for the output computation, delegates to the core Keras kernel.
"""
return keras_hstu_compute_output(
attn=attn,
u=u,
x=x,
norm_weight=norm_weight,
norm_bias=norm_bias,
output_weight=output_weight,
eps=norm_eps,
dropout_ratio=dropout_ratio,
training=training,
silu_u=False,
concat_ux=concat_ux,
group_norm=group_norm,
num_heads=num_heads,
linear_dim=linear_dim,
)
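
Note: an illustrative eager-mode smoke test for the output block, assuming the ops.layer_norm call above resolves on the installed Keras version (see the note after keras_norm_mul_dropout); the shapes and identity-style norm parameters are made up for the example.

import keras
from keras import ops

L, H, D = 10, 2, 8  # token count, heads, per-head hidden dim (illustrative)
attn = keras.random.normal((L, H * D))
u = keras.random.normal((L, H * D))
x = keras.random.normal((L, H * D))

out = hstu_compute_output(
    attn=attn, u=u, x=x,
    norm_weight=ops.ones((H * D,)), norm_bias=ops.zeros((H * D,)),
    norm_eps=1e-6, output_weight=keras.random.normal((H * D, H * D)),
    num_heads=H, linear_dim=D, dropout_ratio=0.0, training=False,
    concat_ux=False, group_norm=False, recompute_y_in_backward=False,
)
# out == x + (u * LayerNorm(attn)) @ output_weight, shape (L, H * D);
# dropout is a no-op at ratio 0.0.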
111 changes: 111 additions & 0 deletions keras_rs/src/layers/hstu_mha_attention.py
@@ -0,0 +1,111 @@
import keras
from keras import ops
from typing import Tuple, Optional
from keras import layers

# --- Assumed Imports ---
# Assumes keras_jagged_to_padded_dense, keras_dense_to_jagged, and get_valid_attn_mask_keras are available from other modules.

def keras_pad_qkv(
q: keras.KerasTensor,
k: keras.KerasTensor,
v: keras.KerasTensor,
seq_offsets: keras.KerasTensor,
N: int,
) -> Tuple[keras.KerasTensor, keras.KerasTensor, keras.KerasTensor]:
"""
Helper to pad Q, K, V from jagged to dense format for MHA.
Assumes keras_jagged_to_padded_dense is available globally.
"""
L, H, D = ops.shape(q)
V_dim = ops.shape(v)[2]
values_q = ops.reshape(q, [L, H * D])
values_k = ops.reshape(k, [L, H * D])
values_v = ops.reshape(v, [L, H * V_dim])

# Pad Q, K, V
padded_q = keras_jagged_to_padded_dense(values=values_q, offsets=[seq_offsets], max_lengths=[N], padding_value=0.0)
padded_k = keras_jagged_to_padded_dense(values=values_k, offsets=[seq_offsets], max_lengths=[N], padding_value=0.0)
padded_v = keras_jagged_to_padded_dense(values=values_v, offsets=[seq_offsets], max_lengths=[N], padding_value=0.0)

B = ops.shape(padded_q)[0]
padded_q = ops.reshape(padded_q, [B, N, H, D])
padded_k = ops.reshape(padded_k, [B, N, H, D])
padded_v = ops.reshape(padded_v, [B, N, H, V_dim])
padded_q = ops.transpose(padded_q, [0, 2, 1, 3])
padded_k = ops.transpose(padded_k, [0, 2, 1, 3])
padded_v = ops.transpose(padded_v, [0, 2, 1, 3])
return padded_q, padded_k, padded_v
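
Note: keras_jagged_to_padded_dense and keras_dense_to_jagged are assumed to be provided by another module of this PR. For reference, a minimal eager-only sketch of the padding helper's contract for the single-length-dimension case (an assumption for illustration, not the PR's implementation):

def keras_jagged_to_padded_dense(values, offsets, max_lengths, padding_value=0.0):
    # values: (total_tokens, F); offsets: [seq_offsets] with shape (B + 1,);
    # max_lengths: [N]. Returns a dense (B, N, F) tensor, padding each sequence
    # to length N with padding_value (assumes every sequence fits within N).
    seq_offsets = offsets[0]
    N = max_lengths[0]
    B = int(ops.shape(seq_offsets)[0]) - 1
    F = int(ops.shape(values)[-1])
    rows = []
    for b in range(B):
        start = int(seq_offsets[b])
        end = int(seq_offsets[b + 1])
        row = values[start:end]
        filler = ops.full((N - (end - start), F), padding_value, dtype=values.dtype)
        rows.append(ops.concatenate([row, filler], axis=0))
    return ops.stack(rows, axis=0)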


def keras_hstu_mha(
max_seq_len: int,
alpha: float,
q: keras.KerasTensor,
k: keras.KerasTensor,
v: keras.KerasTensor,
seq_offsets: keras.KerasTensor,
causal: bool = True,
dropout_pr: float = 0.0,
training: bool = True,
attn_scale: Optional[keras.KerasTensor] = None,
**kwargs
) -> keras.KerasTensor:
"""Core Keras implementation of the full Multi-Head Attention kernel (Non-Cached)."""
L, H, _ = ops.shape(q)
V_dim = ops.shape(v)[2]
q, k, v = keras_pad_qkv(q, k, v, seq_offsets, max_seq_len)
qk_attn = ops.einsum("bhxa,bhya->bhxy", q, k) * alpha

# Activation and Scaling
if attn_scale is not None:
if ops.ndim(attn_scale) > 0:
attn_scale_padded = keras_jagged_to_padded_dense(values=ops.expand_dims(attn_scale, axis=-1), offsets=[seq_offsets], max_lengths=[max_seq_len], padding_value=0.0)
attn_scale_padded = ops.expand_dims(ops.cast(attn_scale_padded, qk_attn.dtype), axis=1)
qk_attn = ops.silu(qk_attn) * attn_scale_padded
else:
qk_attn = ops.silu(qk_attn) / max_seq_len

# Masking
seq_lengths = seq_offsets[1:] - seq_offsets[:-1]
valid_attn_mask = get_valid_attn_mask_keras(causal=causal, N=max_seq_len, seq_lengths=seq_lengths, **kwargs)
qk_attn = qk_attn * ops.expand_dims(ops.cast(valid_attn_mask, qk_attn.dtype), axis=1)

# Dropout
if dropout_pr > 0.0 and training:
qk_attn = keras.layers.Dropout(dropout_pr)(qk_attn, training=training)

# Output (Weighted Sum)
attn_dense = ops.einsum("bhxd,bhdv->bhxv", qk_attn, v)
flat_attn_dense = ops.reshape(ops.transpose(attn_dense, [0, 2, 1, 3]), [-1, max_seq_len, H * V_dim])

# Convert back to jagged
jagged_output = keras_dense_to_jagged(flat_attn_dense, [seq_offsets])
L_out = ops.shape(jagged_output)[0]
return ops.reshape(jagged_output, [L_out, H, V_dim])


def keras_cached_hstu_mha(
max_seq_len: int,
alpha: float,
delta_q: keras.KerasTensor,
k: keras.KerasTensor,
v: keras.KerasTensor,
seq_offsets: keras.KerasTensor,
num_targets: Optional[keras.KerasTensor] = None,
max_attn_len: int = 0,
contextual_seq_len: int = 0,
enable_tma: bool = False,


Review comment (medium): The parameter enable_tma is unused in this function and should be removed.

) -> keras.KerasTensor:
"""Core Keras implementation of the cached attention kernel (Delta Q attends to Full K/V)."""
L_delta, H, D = ops.shape(delta_q)
B = ops.shape(seq_offsets)[0] - 1
DeltaSize = L_delta // B
V_dim = ops.shape(v)[2]

# 1. Reshape Delta Q
delta_q = ops.transpose(ops.reshape(delta_q, (B, DeltaSize, H, D)), perm=[0, 2, 1, 3])

# 2. Reshape Full K and V (Inputs k, v are already flattened/jagged-like)
N_full = max_seq_len
k_full = ops.transpose(ops.reshape(k, (B, N_full, H, D)), [0, 2, 1, 3])
v_full = ops.transpose(ops.reshape(v, (B, N_full, H, V_dim)), [0, 2, 1, 3])

# 3. Attention Score and Activation
qk_attn = ops.einsum("bhxa,bhya->bhxy", delta_q, k_full) * alpha
qk_attn = ops.silu(qk_attn) / max_seq_len

# 4. Masking (Slice the mask to select only the rows corresponding to the new queries)
seq_lengths = seq_offsets[1:] - seq_offsets[:-1]
full_valid_attn_mask = get_valid_attn_mask_keras(causal=True, N=max_seq_len, seq_lengths=seq_lengths, num_targets=num_targets, max_attn_len=max_attn_len, contextual_seq_len=contextual_seq_len)
valid_attn_mask_sliced = full_valid_attn_mask[:, -DeltaSize:, :]

qk_attn = qk_attn * ops.expand_dims(ops.cast(valid_attn_mask_sliced, qk_attn.dtype), axis=1)

# 5. Output (Weighted Sum)
attn_output = ops.einsum("bhxd,bhdv->bhxv", qk_attn, v_full)

# 6. Reshape and return [L_delta, H, V_dim]
attn_output = ops.transpose(attn_output, perm=[0, 2, 1, 3])
return ops.reshape(attn_output, (-1, H, V_dim))


def delta_hstu_mha(
max_seq_len: int,
alpha: float,
delta_q: keras.KerasTensor,
k: keras.KerasTensor,
v: keras.KerasTensor,
seq_offsets: keras.KerasTensor,
num_targets: Optional[keras.KerasTensor] = None,
max_attn_len: int = 0,
contextual_seq_len: int = 0,
kernel=None,
enable_tma: bool = False,


Review comment (medium): The parameters kernel and enable_tma are unused in this function and should be removed.

) -> keras.KerasTensor:
"""Top-level wrapper for cached inference MHA (delegates to core cached kernel)."""

L_delta, H, D = ops.shape(delta_q)


Review comment (medium): The local variables L_delta, H, and D are defined but never used. They should be removed.

Suggested change
L_delta, H, D = ops.shape(delta_q)
# L_delta, H, D = ops.shape(delta_q)

# Assertions are maintained by the layer/framework where possible.

return keras_cached_hstu_mha(
max_seq_len=max_seq_len, alpha=alpha, delta_q=delta_q, k=k, v=v, seq_offsets=seq_offsets,
num_targets=num_targets, max_attn_len=max_attn_len, contextual_seq_len=contextual_seq_len,
)
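
Note: an illustrative cached-decoding call, assuming eager execution; the shapes are made up (B = 2 cached sequences of length 6, one new query token per sequence).

import keras
from keras import ops

B, N, H, D = 2, 6, 2, 4
delta_q = keras.random.normal((B * 1, H, D))  # one new query per sequence, flattened
k = keras.random.normal((B * N, H, D))        # full cached keys, flattened
v = keras.random.normal((B * N, H, D))        # full cached values, flattened
seq_offsets = ops.convert_to_tensor([0, 6, 12], dtype="int32")

out = delta_hstu_mha(
    max_seq_len=N, alpha=1.0 / D, delta_q=delta_q, k=k, v=v,
    seq_offsets=seq_offsets,
)
# out has shape (B * DeltaSize, H, D) == (2, 2, 4)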
44 changes: 44 additions & 0 deletions keras_rs/src/layers/hstu_preprocess_attention.py
@@ -0,0 +1,44 @@
import keras
from keras import ops
from typing import Tuple, List, Optional


def keras_hstu_preprocess_and_attention(
x: keras.KerasTensor, norm_weight: keras.KerasTensor, norm_bias: keras.KerasTensor, norm_eps: float, num_heads: int, attn_dim: int, hidden_dim: int,
uvqk_weight: keras.KerasTensor, uvqk_bias: keras.KerasTensor, max_seq_len: int, seq_offsets: keras.KerasTensor, attn_alpha: float, causal: bool,
num_targets: Optional[keras.KerasTensor], max_attn_len: int, contextual_seq_len: int, recompute_uvqk_in_backward: bool,
recompute_normed_x_in_backward: bool, sort_by_length: bool, prefill: bool = False,
kernel=None, **kwargs
Comment on lines +9 to +11


Review comment (medium): The parameters recompute_uvqk_in_backward, recompute_normed_x_in_backward, sort_by_length, and kernel are unused in this function and its callees. They appear to be remnants from a PyTorch implementation and should be removed.

) -> Tuple:


Review comment (medium): The return type hint Tuple is not specific. For better readability and type checking, please specify the types of the elements in the tuple, for example: -> Tuple[keras.KerasTensor, keras.KerasTensor, keras.KerasTensor, keras.KerasTensor].

"""
Keras 3 implementation of the H-STU preprocess and attention workflow.
Orchestrates the conversion of input X into U, Q, K, V and subsequent MHA computation.
"""

# --- Assertions (Skipped internal torch asserts, simplified to Keras asserts for context) ---
assert max_seq_len > 0, "max_seq_len must be larger than 0"
assert ops.ndim(x) == 2, "x must be 2-D"
assert causal is True, "only causal attention is supported."

# 1. Compute U, Q, K, V
# Note: hstu_compute_uqvk handles the initial Norm, Linear Projection, and Split.
u, q, k, v = hstu_compute_uqvk(
x=x, norm_weight=norm_weight, norm_bias=norm_bias, norm_eps=norm_eps,
num_heads=num_heads, attn_dim=attn_dim, hidden_dim=hidden_dim,
uvqk_weight=uvqk_weight, uvqk_bias=uvqk_bias, kernel=kernel,
)

# 2. Compute Attention
attn_output = keras_hstu_mha(
max_seq_len=max_seq_len, alpha=attn_alpha, q=q, k=k, v=v,
seq_offsets=seq_offsets, causal=causal, dropout_pr=0.0,
training=False, num_targets=num_targets, max_attn_len=max_attn_len,
contextual_seq_len=contextual_seq_len, sort_by_length=sort_by_length,
kernel=kernel, **kwargs
)

# Reshape: [L, H, D] -> [L, H * D] (Flattening for the final hstu_compute_output block)
attn_output = ops.reshape(attn_output, [-1, hidden_dim * num_heads])

# Returns u (gating), attention output, k, and v (for caching)
return u, attn_output, k, v