Adding STU layer #154
base: main
@@ -0,0 +1,62 @@
import keras
from keras import ops
from typing import List, Optional, Tuple


def fx_unwrap_optional_tensor(optional: Optional[keras.KerasTensor]) -> keras.KerasTensor:
    """Helper to unwrap optional tensors, returning a zero-tensor for an uninitialized cache."""
    if optional is None:
        # Returning a zero-tensor is necessary for graph tracing when the cache is uninitialized.
        return ops.zeros((0,), dtype='float32')
    return optional

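A quick illustration of the sentinel behavior (not part of the diff; shapes are shown only for reference):

```python
empty = fx_unwrap_optional_tensor(None)
print(ops.shape(empty))    # (0,): zero-length placeholder for an uninitialized cache
cached = fx_unwrap_optional_tensor(ops.ones((2, 3)))
print(ops.shape(cached))   # (2, 3): an existing cache tensor passes through unchanged
```
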
def get_valid_attn_mask_keras(
    causal: bool,
    N: int,
    seq_lengths: keras.KerasTensor,
    num_targets: Optional[keras.KerasTensor] = None,
    max_attn_len: int = 0,
    contextual_seq_len: int = 0,
    min_full_attn_seq_len: int = 0,
) -> keras.KerasTensor:
    """
    Keras implementation of the valid attention mask generation, combining
    causality, sequence lengths, and target awareness.
    """
    ids = ops.reshape(ops.arange(0, N, dtype="int32"), (1, N))
    max_ids = ops.reshape(seq_lengths, (-1, 1, 1))
    B = ops.shape(seq_lengths)[0]

    if contextual_seq_len > 0:
        ids = ids - contextual_seq_len + 1
        ids = ops.maximum(ids, 0)
        max_ids = max_ids - contextual_seq_len + 1

    if num_targets is not None:
        max_ids = max_ids - ops.reshape(num_targets, (-1, 1, 1))
        ids = ops.minimum(ids, max_ids)
        row_ids = ops.broadcast_to(ops.reshape(ids, (-1, N, 1)), (B, N, N))
        col_ids = ops.broadcast_to(ops.reshape(ids, (-1, 1, N)), (B, N, N))
    else:
        row_ids = ops.broadcast_to(ops.reshape(ids, (N, 1)), (N, N))
        col_ids = ops.transpose(row_ids)
        row_ids = ops.reshape(row_ids, (1, N, N))
        col_ids = ops.reshape(col_ids, (1, N, N))
        max_ids = None
Review comment on lines +40 to +44: the logic in this else: branch, with the suggested change:

    else:
        row_ids = ops.broadcast_to(ops.reshape(ids, (1, N, 1)), (B, N, N))
        col_ids = ops.broadcast_to(ops.reshape(ids, (1, 1, N)), (B, N, N))
        max_ids = None

    row_col_dist = row_ids - col_ids
    valid_attn_mask = ops.reshape(ops.eye(N, dtype="bool"), (1, N, N))

    if not causal:
        row_col_dist = ops.where(row_col_dist > 0, row_col_dist, -row_col_dist)

    valid_attn_mask = ops.logical_or(valid_attn_mask, row_col_dist > 0)

    if max_attn_len > 0:
        valid_attn_mask = ops.logical_and(valid_attn_mask, row_col_dist <= max_attn_len)

    if contextual_seq_len > 0 and max_ids is not None:
        valid_attn_mask = ops.logical_or(
            valid_attn_mask, ops.logical_and(row_ids == 0, col_ids < max_ids)
        )

    return valid_attn_mask
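A minimal usage sketch of the mask helper (illustrative, not part of the diff). With num_targets=None the mask is shared across the batch as (1, N, N); once per-sequence targets are passed it is materialized per sequence as (B, N, N):

```python
from keras import ops

seq_lengths = ops.convert_to_tensor([2, 4], dtype="int32")   # two jagged sequences
mask = get_valid_attn_mask_keras(causal=True, N=4, seq_lengths=seq_lengths)
print(ops.shape(mask))     # (1, 4, 4): lower-triangular causal mask, broadcast over the batch

num_targets = ops.convert_to_tensor([1, 1], dtype="int32")
mask_t = get_valid_attn_mask_keras(
    causal=True, N=4, seq_lengths=seq_lengths, num_targets=num_targets
)
print(ops.shape(mask_t))   # (2, 4, 4): target-aware variant, one mask per sequence
```
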
@@ -0,0 +1,123 @@
import keras
from keras import ops
from typing import List, Optional, Tuple


def keras_norm_mul_dropout(
    x: keras.KerasTensor,
    u: keras.KerasTensor,
    weight: keras.KerasTensor,
    bias: keras.KerasTensor,
    eps: float,
    dropout_ratio: float,
    training: bool,
    silu_u: bool = False,
    concat_ux: bool = False,
    group_norm: bool = False,
    num_heads: int = 1,
    linear_dim: int = -1,
) -> keras.KerasTensor:
""" | ||
Keras 3 equivalent of pytorch_norm_mul_dropout. | ||
Applies normalization, element-wise multiplication with u, and dropout. | ||
Assumes keras_layer_norm is available (though the logic is inlined here). | ||
""" | ||
x = ops.convert_to_tensor(x, dtype='float32') | ||
u = ops.convert_to_tensor(u, dtype='float32') | ||
LakshmiKalaKadali marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if silu_u: | ||
u = ops.silu(u) | ||
|
||
if group_norm: | ||
raise NotImplementedError("Group Norm path not suitable for simple Keras ops conversion.") | ||
else: | ||
# Functional Layer Normalization (Simulated keras_layer_norm) | ||
x_norm = ops.layer_norm(x, axis=-1, epsilon=eps) | ||
|
||
# Apply weight and bias (Gamma * x_norm + Beta) | ||
y_norm = x_norm * weight + bias | ||
|
||
# Apply u multiplication (Element-wise gating) | ||
y = u * y_norm | ||
|
||
if concat_ux: | ||
y = ops.concatenate([u, x, y], axis=1) | ||
|
||
# Dropout (using Keras layer for correct training=True/False behavior) | ||
y = keras.layers.Dropout(dropout_ratio)(y, training=training) | ||
|
||
return ops.cast(y, dtype=x.dtype) | ||
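
A small shape sketch (illustrative only, assuming the function above is in scope). With dropout_ratio=0.0, a unit gamma and zero beta, this reduces to u times a plain layer norm of x:

```python
import keras
from keras import ops

x = keras.random.normal((4, 8))
u = keras.random.normal((4, 8))

y = keras_norm_mul_dropout(
    x=x, u=u,
    weight=ops.ones((8,)), bias=ops.zeros((8,)),
    eps=1e-6, dropout_ratio=0.0, training=False,
)
print(ops.shape(y))   # (4, 8): u * LayerNorm(x), no dropout at inference
```
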

def keras_hstu_compute_output(
    attn: keras.KerasTensor,
    u: keras.KerasTensor,
    x: keras.KerasTensor,
    norm_weight: keras.KerasTensor,
    norm_bias: keras.KerasTensor,
    output_weight: keras.KerasTensor,
    eps: float,
    dropout_ratio: float,
    training: bool,
    silu_u: bool = False,
    concat_ux: bool = False,
    group_norm: bool = False,
    num_heads: int = 1,
    linear_dim: int = -1,
) -> keras.KerasTensor:
    """
    Core kernel for the final residual block calculation
    (attn output -> norm/dropout -> matmul -> residual add).
    """
    y = keras_norm_mul_dropout(
        x=attn,
        u=u,
        weight=norm_weight,
        bias=norm_bias,
        eps=eps,
        dropout_ratio=dropout_ratio,
        training=training,
        silu_u=silu_u,
        concat_ux=concat_ux,
        group_norm=group_norm,
        num_heads=num_heads,
        linear_dim=linear_dim,
    )

    # Final output: residual addition of input (x) and transformed attention output (y @ output_weight)
    output = ops.add(x, ops.matmul(y, output_weight))

    return output

def hstu_compute_output(
    attn: keras.KerasTensor,
    u: keras.KerasTensor,
    x: keras.KerasTensor,
    norm_weight: keras.KerasTensor,
    norm_bias: keras.KerasTensor,
    norm_eps: float,
    output_weight: keras.KerasTensor,
    num_heads: int,
    linear_dim: int,
    dropout_ratio: float,
    training: bool,
    concat_ux: bool,
    group_norm: bool,
    recompute_y_in_backward: bool,
) -> keras.KerasTensor:
    """
    Top-level wrapper for the output computation; delegates to the core Keras kernel.
    """
    return keras_hstu_compute_output(
        attn=attn,
        u=u,
        x=x,
        norm_weight=norm_weight,
        norm_bias=norm_bias,
        output_weight=output_weight,
        eps=norm_eps,
        dropout_ratio=dropout_ratio,
        training=training,
        silu_u=False,
        concat_ux=concat_ux,
        group_norm=group_norm,
        num_heads=num_heads,
        linear_dim=linear_dim,
    )
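
A usage sketch of the wrapper above with toy shapes (illustrative, not part of the diff):

```python
import keras
from keras import ops

attn = keras.random.normal((6, 8))   # attention block output, [L, hidden]
u = keras.random.normal((6, 8))      # gating branch
x = keras.random.normal((6, 8))      # residual input

out = hstu_compute_output(
    attn=attn, u=u, x=x,
    norm_weight=ops.ones((8,)), norm_bias=ops.zeros((8,)), norm_eps=1e-6,
    output_weight=keras.random.normal((8, 8)),
    num_heads=1, linear_dim=8, dropout_ratio=0.0, training=False,
    concat_ux=False, group_norm=False, recompute_y_in_backward=False,
)
print(ops.shape(out))   # (6, 8): x + (u * LayerNorm(attn)) @ output_weight
```
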
@@ -0,0 +1,111 @@
import keras
from keras import ops
from typing import Tuple, Optional
from keras import layers

# --- Assumed imports ---
# Assumes keras_jagged_to_padded_dense, keras_dense_to_jagged, and
# get_valid_attn_mask_keras are available from other modules.


def keras_pad_qkv(
    q: keras.KerasTensor,
    k: keras.KerasTensor,
    v: keras.KerasTensor,
    seq_offsets: keras.KerasTensor,
    N: int,
) -> Tuple[keras.KerasTensor, keras.KerasTensor, keras.KerasTensor]:
    """
    Helper to pad Q, K, V from jagged to dense format for MHA.
    Assumes keras_jagged_to_padded_dense is available globally.
    """
    L, H, D = ops.shape(q)
    V_dim = ops.shape(v)[2]
    values_q = ops.reshape(q, [L, H * D])
    values_k = ops.reshape(k, [L, H * D])
    values_v = ops.reshape(v, [L, H * V_dim])

    # Pad Q, K, V to dense [B, N, H * D]
    padded_q = keras_jagged_to_padded_dense(values=values_q, offsets=[seq_offsets], max_lengths=[N], padding_value=0.0)
    padded_k = keras_jagged_to_padded_dense(values=values_k, offsets=[seq_offsets], max_lengths=[N], padding_value=0.0)
    padded_v = keras_jagged_to_padded_dense(values=values_v, offsets=[seq_offsets], max_lengths=[N], padding_value=0.0)

    B = ops.shape(padded_q)[0]
    padded_q = ops.reshape(padded_q, [B, N, H, D])
    padded_k = ops.reshape(padded_k, [B, N, H, D])
    padded_v = ops.reshape(padded_v, [B, N, H, V_dim])
    padded_q = ops.transpose(padded_q, [0, 2, 1, 3])
    padded_k = ops.transpose(padded_k, [0, 2, 1, 3])
    padded_v = ops.transpose(padded_v, [0, 2, 1, 3])
    return padded_q, padded_k, padded_v
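
The two jagged/dense helpers referenced above come from another module of this PR and are not shown in the diff. Purely as a reading aid, here is a naive eager-mode sketch that is shape-compatible with the calls in this file; it is not the PR's actual implementation and is not graph-traceable:

```python
from keras import ops

def keras_jagged_to_padded_dense(values, offsets, max_lengths, padding_value=0.0):
    """values: [L_total, D]; offsets: [offsets tensor of shape [B+1]]; returns [B, N, D]."""
    off = ops.convert_to_numpy(offsets[0])
    N = max_lengths[0]
    D = int(ops.shape(values)[1])
    rows = []
    for start, stop in zip(off[:-1], off[1:]):
        seq = values[int(start):int(stop)]
        pad = ops.full((N - int(stop - start), D), padding_value, dtype=values.dtype)
        rows.append(ops.concatenate([seq, pad], axis=0))
    return ops.stack(rows, axis=0)

def keras_dense_to_jagged(dense, offsets_list):
    """dense: [B, N, D]; returns [L_total, D], dropping the padded tail of each sequence."""
    off = ops.convert_to_numpy(offsets_list[0])
    lengths = off[1:] - off[:-1]
    return ops.concatenate([dense[i, : int(n)] for i, n in enumerate(lengths)], axis=0)
```
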

def keras_hstu_mha(
    max_seq_len: int,
    alpha: float,
    q: keras.KerasTensor,
    k: keras.KerasTensor,
    v: keras.KerasTensor,
    seq_offsets: keras.KerasTensor,
    causal: bool = True,
    dropout_pr: float = 0.0,
    training: bool = True,
    attn_scale: Optional[keras.KerasTensor] = None,
    **kwargs,
) -> keras.KerasTensor:
    """Core Keras implementation of the full Multi-Head Attention kernel (non-cached)."""
    L, H, _ = ops.shape(q)
    V_dim = ops.shape(v)[2]
    q, k, v = keras_pad_qkv(q, k, v, seq_offsets, max_seq_len)
    qk_attn = ops.einsum("bhxa,bhya->bhxy", q, k) * alpha

    # Activation and scaling
    if attn_scale is not None:
        if ops.ndim(attn_scale) > 0:
            attn_scale_padded = keras_jagged_to_padded_dense(
                values=ops.expand_dims(attn_scale, axis=-1),
                offsets=[seq_offsets],
                max_lengths=[max_seq_len],
                padding_value=0.0,
            )
            attn_scale_padded = ops.expand_dims(ops.cast(attn_scale_padded, qk_attn.dtype), axis=1)
            qk_attn = ops.silu(qk_attn) * attn_scale_padded
    else:
        qk_attn = ops.silu(qk_attn) / max_seq_len

    # Masking
    seq_lengths = seq_offsets[1:] - seq_offsets[:-1]
    valid_attn_mask = get_valid_attn_mask_keras(causal=causal, N=max_seq_len, seq_lengths=seq_lengths, **kwargs)
    qk_attn = qk_attn * ops.expand_dims(ops.cast(valid_attn_mask, qk_attn.dtype), axis=1)

    # Dropout
    if dropout_pr > 0.0 and training:
        qk_attn = keras.layers.Dropout(dropout_pr)(qk_attn, training=training)

    # Output (weighted sum over values)
    attn_dense = ops.einsum("bhxd,bhdv->bhxv", qk_attn, v)
    flat_attn_dense = ops.reshape(ops.transpose(attn_dense, [0, 2, 1, 3]), [-1, max_seq_len, H * V_dim])

    # Convert back to jagged [L, H, V_dim]
    jagged_output = keras_dense_to_jagged(flat_attn_dense, [seq_offsets])
    L_out = ops.shape(jagged_output)[0]
    return ops.reshape(jagged_output, [L_out, H, V_dim])
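
An illustrative call with two jagged sequences of lengths 3 and 5 (assumes get_valid_attn_mask_keras and jagged helpers, e.g. the sketch above, are in scope; not part of the diff):

```python
import keras
from keras import ops

H, D, N = 2, 4, 5
seq_offsets = ops.convert_to_tensor([0, 3, 8], dtype="int32")   # lengths 3 and 5
L_total = 8
q = keras.random.normal((L_total, H, D))
k = keras.random.normal((L_total, H, D))
v = keras.random.normal((L_total, H, D))

out = keras_hstu_mha(
    max_seq_len=N, alpha=1.0 / D, q=q, k=k, v=v,
    seq_offsets=seq_offsets, causal=True, dropout_pr=0.0, training=False,
)
print(ops.shape(out))   # (8, 2, 4): jagged output, [L_total, H, V_dim]
```
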

def keras_cached_hstu_mha(
    max_seq_len: int,
    alpha: float,
    delta_q: keras.KerasTensor,
    k: keras.KerasTensor,
    v: keras.KerasTensor,
    seq_offsets: keras.KerasTensor,
    num_targets: Optional[keras.KerasTensor] = None,
    max_attn_len: int = 0,
    contextual_seq_len: int = 0,
    enable_tma: bool = False,
) -> keras.KerasTensor:
    """Core Keras implementation of the cached attention kernel (delta Q attends to full K/V)."""
    L_delta, H, D = ops.shape(delta_q)
    B = ops.shape(seq_offsets)[0] - 1
    DeltaSize = L_delta // B
    V_dim = ops.shape(v)[2]

    # 1. Reshape delta Q to [B, H, DeltaSize, D]
    delta_q = ops.transpose(ops.reshape(delta_q, (B, DeltaSize, H, D)), perm=[0, 2, 1, 3])

    # 2. Reshape full K and V (inputs k, v are already flattened/jagged-like)
    N_full = max_seq_len
    k_full = ops.transpose(ops.reshape(k, (B, N_full, H, D)), [0, 2, 1, 3])
    v_full = ops.transpose(ops.reshape(v, (B, N_full, H, V_dim)), [0, 2, 1, 3])

    # 3. Attention score and activation
    qk_attn = ops.einsum("bhxa,bhya->bhxy", delta_q, k_full) * alpha
    qk_attn = ops.silu(qk_attn) / max_seq_len

    # 4. Masking (slice the mask to select only the rows corresponding to the new queries)
    seq_lengths = seq_offsets[1:] - seq_offsets[:-1]
    full_valid_attn_mask = get_valid_attn_mask_keras(
        causal=True,
        N=max_seq_len,
        seq_lengths=seq_lengths,
        num_targets=num_targets,
        max_attn_len=max_attn_len,
        contextual_seq_len=contextual_seq_len,
    )
    valid_attn_mask_sliced = full_valid_attn_mask[:, -DeltaSize:, :]

    qk_attn = qk_attn * ops.expand_dims(ops.cast(valid_attn_mask_sliced, qk_attn.dtype), axis=1)

    # 5. Output (weighted sum over the cached values)
    attn_output = ops.einsum("bhxd,bhdv->bhxv", qk_attn, v_full)

    # 6. Reshape and return [L_delta, H, V_dim]
    attn_output = ops.transpose(attn_output, perm=[0, 2, 1, 3])
    return ops.reshape(attn_output, (-1, H, V_dim))

def delta_hstu_mha(
    max_seq_len: int,
    alpha: float,
    delta_q: keras.KerasTensor,
    k: keras.KerasTensor,
    v: keras.KerasTensor,
    seq_offsets: keras.KerasTensor,
    num_targets: Optional[keras.KerasTensor] = None,
    max_attn_len: int = 0,
    contextual_seq_len: int = 0,
    kernel=None,
    enable_tma: bool = False,
) -> keras.KerasTensor:
    """Top-level wrapper for cached inference MHA (delegates to the core cached kernel)."""
    L_delta, H, D = ops.shape(delta_q)

    # Assertions are maintained by the layer/framework where possible.
    return keras_cached_hstu_mha(
        max_seq_len=max_seq_len, alpha=alpha, delta_q=delta_q, k=k, v=v, seq_offsets=seq_offsets,
        num_targets=num_targets, max_attn_len=max_attn_len, contextual_seq_len=contextual_seq_len,
    )
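
An illustrative cached-decode call (not part of the diff): two sequences of full length N=4, with one new query position each, so the output has L_delta = B * DeltaSize rows:

```python
import keras
from keras import ops

B, N, H, D = 2, 4, 2, 3
seq_offsets = ops.convert_to_tensor([0, 4, 8], dtype="int32")
delta_q = keras.random.normal((B * 1, H, D))   # DeltaSize = 1 new position per sequence
k = keras.random.normal((B * N, H, D))         # cached keys, flattened to [B*N, H, D]
v = keras.random.normal((B * N, H, D))         # cached values

out = delta_hstu_mha(
    max_seq_len=N, alpha=1.0 / D,
    delta_q=delta_q, k=k, v=v, seq_offsets=seq_offsets,
)
print(ops.shape(out))   # (2, 2, 3): [L_delta, H, V_dim]
```
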
@@ -0,0 +1,44 @@
import keras
from keras import ops
from typing import Tuple, List, Optional


def keras_hstu_preprocess_and_attention(
    x: keras.KerasTensor,
    norm_weight: keras.KerasTensor,
    norm_bias: keras.KerasTensor,
    norm_eps: float,
    num_heads: int,
    attn_dim: int,
    hidden_dim: int,
    uvqk_weight: keras.KerasTensor,
    uvqk_bias: keras.KerasTensor,
    max_seq_len: int,
    seq_offsets: keras.KerasTensor,
    attn_alpha: float,
    causal: bool,
    num_targets: Optional[keras.KerasTensor],
    max_attn_len: int,
    contextual_seq_len: int,
    recompute_uvqk_in_backward: bool,
    recompute_normed_x_in_backward: bool,
    sort_by_length: bool,
    prefill: bool = False,
    kernel=None,
    **kwargs,
) -> Tuple:
    """
    Keras 3 implementation of the H-STU preprocess and attention workflow.
    Orchestrates the conversion of input X into U, Q, K, V and the subsequent MHA computation.
    """
    # --- Assertions (internal torch asserts skipped, simplified to Keras-level asserts) ---
    assert max_seq_len > 0, "max_seq_len must be larger than 0"
    assert ops.ndim(x) == 2, "x must be 2-D"
    assert causal is True, "only causal attention is supported."

    # 1. Compute U, Q, K, V
    # Note: hstu_compute_uqvk handles the initial norm, linear projection, and split.
    u, q, k, v = hstu_compute_uqvk(
        x=x, norm_weight=norm_weight, norm_bias=norm_bias, norm_eps=norm_eps,
        num_heads=num_heads, attn_dim=attn_dim, hidden_dim=hidden_dim,
        uvqk_weight=uvqk_weight, uvqk_bias=uvqk_bias, kernel=kernel,
    )

    # 2. Compute attention
    attn_output = keras_hstu_mha(
        max_seq_len=max_seq_len, alpha=attn_alpha, q=q, k=k, v=v,
        seq_offsets=seq_offsets, causal=causal, dropout_pr=0.0,
        training=False, num_targets=num_targets, max_attn_len=max_attn_len,
        contextual_seq_len=contextual_seq_len, sort_by_length=sort_by_length,
        kernel=kernel, **kwargs
    )

    # Reshape: [L, H, D] -> [L, H * D] (flatten for the final hstu_compute_output block)
    attn_output = ops.reshape(attn_output, [-1, hidden_dim * num_heads])

    # Returns u (gating), attention output, k, and v (for caching)
    return u, attn_output, k, v
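
hstu_compute_uqvk is referenced above but is not part of this diff. Purely as a reading aid, here is a rough stand-in based only on the docstring note (norm, one fused projection, then a split into U, V, Q, K); the real split order, weight layout, and shapes in the PR may differ:

```python
from keras import ops

def hstu_compute_uqvk(x, norm_weight, norm_bias, norm_eps, num_heads, attn_dim,
                      hidden_dim, uvqk_weight, uvqk_bias, kernel=None):
    # Layer norm over the packed token matrix [L, D_in]
    mean = ops.mean(x, axis=-1, keepdims=True)
    var = ops.var(x, axis=-1, keepdims=True)
    x_norm = (x - mean) / ops.sqrt(var + norm_eps) * norm_weight + norm_bias

    # One fused projection; uvqk_weight is assumed to have output width
    # 2 * hidden_dim * num_heads + 2 * attn_dim * num_heads.
    uvqk = ops.matmul(x_norm, uvqk_weight) + uvqk_bias
    u_size = hidden_dim * num_heads
    q_size = attn_dim * num_heads
    u, v, q, k = ops.split(uvqk, [u_size, 2 * u_size, 2 * u_size + q_size], axis=1)

    # Per-head reshape for the attention kernel; u stays [L, hidden_dim * num_heads].
    q = ops.reshape(q, (-1, num_heads, attn_dim))
    k = ops.reshape(k, (-1, num_heads, attn_dim))
    v = ops.reshape(v, (-1, num_heads, hidden_dim))
    return u, q, k, v
```
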
Review comment (on get_valid_attn_mask_keras): The parameter min_full_attn_seq_len is defined but never used within the function. It should be removed to improve code clarity.