Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 50 additions & 30 deletions src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def __call__(self, mel):


class EncoderBlock(nn.Module):
"""Pre-norm self-attention + pre-norm FFN with gated residual connections."""
"""Layer input is read from accumulated states, then transformed by attention/FFN."""
num_heads: int
num_kv_heads: int
d_model: int
Expand All @@ -206,27 +206,34 @@ class EncoderBlock(nn.Module):
no_feedforward: bool = True

@nn.compact
def __call__(self, x, mask=None, rope=None, ffn_mask=None, deterministic=True):
gate = nn.sigmoid(self.param("attn_gate", jinit.zeros, ())).astype(self.dtype)
residual = x
def __call__(self, x_accum, layer_idx, mask=None, rope=None, ffn_mask=None, deterministic=True):
q = self.param("residual_query", default_init(), (self.d_model,)).astype(jnp.float32)
k = x_accum.astype(jnp.float32)
v = ZCRMSNorm(dtype=self.dtype, name="residual_value_norm")(x_accum)

logits = jnp.einsum("btld,d->btl", k, q) / jnp.sqrt(jnp.float32(self.d_model))
valid_layers = (jnp.arange(x_accum.shape[2]) <= layer_idx)[None, None, :]
logits = jnp.where(valid_layers, logits, jnp.finfo(logits.dtype).min)
Comment on lines +211 to +216

Copilot AI Mar 27, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This residual-attention computation reads the full x_accum tensor (B×T×(L+1)×D) and computes logits/values over all layers each iteration, making the encoder block’s cost scale ~O(L^2) in number of layers and increasing per-step memory bandwidth substantially. If this is not strictly required, consider restricting the aggregation to a bounded window of previous layers or maintaining a smaller running summary so each layer doesn’t repeatedly process the entire accumulated state.

Suggested change
k = x_accum.astype(jnp.float32)
v = ZCRMSNorm(dtype=self.dtype, name="residual_value_norm")(x_accum)
logits = jnp.einsum("btld,d->btl", k, q) / jnp.sqrt(jnp.float32(self.d_model))
valid_layers = (jnp.arange(x_accum.shape[2]) <= layer_idx)[None, None, :]
logits = jnp.where(valid_layers, logits, jnp.finfo(logits.dtype).min)
# Only use accumulated states up to and including the current layer to avoid
# repeatedly processing future layers that are always masked out.
x_prefix = x_accum[:, :, : layer_idx + 1, :]
k = x_prefix.astype(jnp.float32)
v = ZCRMSNorm(dtype=self.dtype, name="residual_value_norm")(x_prefix)
logits = jnp.einsum("btld,d->btl", k, q) / jnp.sqrt(jnp.float32(self.d_model))

Copilot uses AI. Check for mistakes.
alphas = nn.sigmoid(logits).astype(self.dtype)
x = jnp.einsum("btl,btld->btd", alphas, v)

x = ZCRMSNorm(dtype=self.dtype)(x)
x = MultiHeadAttention(self.num_heads, self.num_kv_heads, self.d_model, self.num_layers, self.dtype, name="self_attn")(
x, x, mask=mask, rope=rope
)
x = residual + gate * nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)

if not self.no_feedforward:
ffn_gate = nn.sigmoid(self.param("ffn_gate", jinit.zeros, ())).astype(self.dtype)
residual = x
x = ZCRMSNorm(dtype=self.dtype)(x)
x = FeedForward(self.d_model, self.d_ff, self.num_layers, self.dtype, self.activation)(x, ffn_mask=ffn_mask)
x = residual + ffn_gate * nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
x = ffn_gate * nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)

return x


class _EncoderScanBody(nn.Module):
"""Wraps EncoderBlock for nn.scan: carry = (x, mask, rope, ffn_mask)."""
"""Wraps EncoderBlock for nn.scan: carry = (mask, rope, ffn_mask, x_accum)."""
num_heads: int
num_kv_heads: int
d_model: int
Expand All @@ -239,14 +246,15 @@ class _EncoderScanBody(nn.Module):
deterministic: bool = True

@nn.compact
def __call__(self, carry, _):
x, mask, rope, ffn_mask = carry
def __call__(self, carry, layer_idx):
mask, rope, ffn_mask, x_accum = carry
x = EncoderBlock(
self.num_heads, self.num_kv_heads, self.d_model, self.d_ff,
self.num_layers, self.dtype, self.activation, self.dropout_rate,
self.no_feedforward,
)(x, mask, rope, ffn_mask, self.deterministic)
return (x, mask, rope, ffn_mask), None
)(x_accum, layer_idx, mask, rope, ffn_mask, self.deterministic)
x_accum = x_accum.at[:, :, layer_idx + 1, :].set(x)
return (mask, rope, ffn_mask, x_accum), None


class Encoder(nn.Module):
Expand All @@ -258,19 +266,22 @@ def __call__(self, x, mask=None, rope=None, ffn_mask=None, deterministic=True):
cfg = self.config
dt = cfg.jax_dtype
x = x.astype(dt)
x_accum = jnp.zeros((x.shape[0], x.shape[1], cfg.num_encoder_layers + 1, cfg.d_model), dtype=dt)
x_accum = x_accum.at[:, :, 0, :].set(x)
Comment on lines +269 to +270

Copilot AI Mar 27, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Allocating and carrying x_accum with shape (B, T, num_encoder_layers+1, d_model) through the scan significantly increases activation memory (and likely compilation time) vs. a standard scan carry. If you only need the final layer output and a limited history for residual attention, consider reducing the stored history (e.g., windowed buffer) or computing the needed aggregation without materializing the full 4D tensor.

Copilot uses AI. Check for mistakes.

ScanBlock = nn.scan(
nn.remat(_EncoderScanBody),
variable_axes={"params": 0},
split_rngs={"params": True, "dropout": True},
length=cfg.num_encoder_layers,
)
(x, _, _, _), _ = ScanBlock(
(_, _, _, x_accum), _ = ScanBlock(
cfg.num_heads, cfg.num_kv_heads, cfg.d_model, cfg.d_ff,
cfg.total_layers, dt, cfg.activation, cfg.dropout_rate,
cfg.no_feedforward, deterministic, name="layers",
)((x, mask, rope, ffn_mask), None)
)((mask, rope, ffn_mask, x_accum), jnp.arange(cfg.num_encoder_layers, dtype=jnp.int32))

x = x_accum[:, :, cfg.num_encoder_layers, :]
x = ZCRMSNorm(dtype=dt, name="final_norm")(x)
return x, mask

Expand All @@ -287,36 +298,41 @@ class DecoderBlock(nn.Module):
no_feedforward: bool = True

@nn.compact
def __call__(self, x, encoder_out, self_mask=None, cross_mask=None, rope=None, ffn_mask=None, deterministic=True):
self_gate = nn.sigmoid(self.param("self_attn_gate", jinit.zeros, ())).astype(self.dtype)
residual = x
def __call__(self, encoder_out, x_accum, layer_idx, self_mask=None, cross_mask=None, rope=None, ffn_mask=None, deterministic=True):
q = self.param("residual_query", default_init(), (self.d_model,)).astype(jnp.float32)
k = x_accum.astype(jnp.float32)
v = ZCRMSNorm(dtype=self.dtype, name="residual_value_norm")(x_accum)

logits = jnp.einsum("btld,d->btl", k, q) / jnp.sqrt(jnp.float32(self.d_model))
valid_layers = (jnp.arange(x_accum.shape[2]) <= layer_idx)[None, None, :]
logits = jnp.where(valid_layers, logits, jnp.finfo(logits.dtype).min)
alphas = nn.sigmoid(logits).astype(self.dtype)
x = jnp.einsum("btl,btld->btd", alphas, v)

Comment on lines +302 to +311

Copilot AI Mar 27, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The decoder block also performs residual aggregation over the entire x_accum each layer iteration, which similarly scales compute ~O(L^2) with num_decoder_layers and can become a bottleneck for deeper models/longer sequences. If possible, consider limiting the number of previous layers attended to or using a more incremental aggregation strategy to keep per-layer work closer to O(L).

Copilot uses AI. Check for mistakes.
x = ZCRMSNorm(dtype=self.dtype)(x)
x = MultiHeadAttention(self.num_heads, self.num_kv_heads, self.d_model, self.num_layers, self.dtype, name="self_attn")(
x, x, mask=self_mask, rope=rope
)
x = residual + self_gate * nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)

cross_gate = nn.sigmoid(self.param("cross_attn_gate", jinit.zeros, ())).astype(self.dtype)
residual = x
x = ZCRMSNorm(dtype=self.dtype)(x)
x = MultiHeadAttention(self.num_heads, self.num_kv_heads, self.d_model, self.num_layers, self.dtype, name="cross_attn")(
x, encoder_out, mask=cross_mask
)
x = residual + cross_gate * nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)

if not self.no_feedforward:
ffn_gate = nn.sigmoid(self.param("ffn_gate", jinit.zeros, ())).astype(self.dtype)
residual = x
x = ZCRMSNorm(dtype=self.dtype)(x)
x = FeedForward(self.d_model, self.d_ff, self.num_layers, self.dtype, self.activation)(x, ffn_mask=ffn_mask)
x = residual + ffn_gate * nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
x = ffn_gate * nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)

return x



class _DecoderScanBody(nn.Module):
"""Wraps DecoderBlock for nn.scan: carry = (x, encoder_out, self_mask, cross_mask, rope, ffn_mask)."""
"""Wraps DecoderBlock for nn.scan: carry = (encoder_out, self_mask, cross_mask, rope, ffn_mask, x_accum)."""
num_heads: int
num_kv_heads: int
d_model: int
Expand All @@ -329,14 +345,15 @@ class _DecoderScanBody(nn.Module):
deterministic: bool = True

@nn.compact
def __call__(self, carry, _):
x, encoder_out, self_mask, cross_mask, rope, ffn_mask = carry
def __call__(self, carry, layer_idx):
encoder_out, self_mask, cross_mask, rope, ffn_mask, x_accum = carry
x = DecoderBlock(
self.num_heads, self.num_kv_heads, self.d_model, self.d_ff,
self.num_layers, self.dtype, self.activation, self.dropout_rate,
self.no_feedforward,
)(x, encoder_out, self_mask, cross_mask, rope, ffn_mask, self.deterministic)
return (x, encoder_out, self_mask, cross_mask, rope, ffn_mask), None
)(encoder_out, x_accum, layer_idx, self_mask, cross_mask, rope, ffn_mask, self.deterministic)
x_accum = x_accum.at[:, :, layer_idx + 1, :].set(x)
return (encoder_out, self_mask, cross_mask, rope, ffn_mask, x_accum), None


class Decoder(nn.Module):
Expand All @@ -347,19 +364,22 @@ def __call__(self, x, encoder_out, self_mask=None, cross_mask=None, rope=None, f
cfg = self.config
dt = cfg.jax_dtype
x = x.astype(dt)
x_accum = jnp.zeros((x.shape[0], x.shape[1], cfg.num_decoder_layers + 1, cfg.d_model), dtype=dt)
x_accum = x_accum.at[:, :, 0, :].set(x)
Comment on lines +367 to +368

Copilot AI Mar 27, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to the encoder, x_accum is allocated as a full (B, T, num_decoder_layers+1, d_model) buffer and threaded through the scan, which can be very memory-intensive for large batch/sequence/layer counts. If memory is a concern, consider storing fewer layer states (windowed) or restructuring so scan outputs are collected separately instead of keeping the entire history in the carry.

Copilot uses AI. Check for mistakes.

ScanBlock = nn.scan(
nn.remat(_DecoderScanBody),
variable_axes={"params": 0},
split_rngs={"params": True, "dropout": True},
length=cfg.num_decoder_layers,
)
(x, _, _, _, _, _), _ = ScanBlock(
(_, _, _, _, _, x_accum), _ = ScanBlock(
cfg.num_heads, cfg.num_kv_heads, cfg.d_model, cfg.d_ff,
cfg.total_layers, dt, cfg.activation, cfg.dropout_rate,
cfg.no_feedforward, deterministic, name="layers",
)((x, encoder_out, self_mask, cross_mask, rope, ffn_mask), None)
)((encoder_out, self_mask, cross_mask, rope, ffn_mask, x_accum), jnp.arange(cfg.num_decoder_layers, dtype=jnp.int32))

x = x_accum[:, :, cfg.num_decoder_layers, :]
x = ZCRMSNorm(dtype=dt)(x)
return x

Expand Down
Loading