From a8d5fd0c99a0ce3e2864080fd772c813246af683 Mon Sep 17 00:00:00 2001 From: Karen Mosoyan Date: Fri, 27 Mar 2026 13:08:58 -0700 Subject: [PATCH] added residual attention, removed normal residual and gated attention --- src/model.py | 80 ++++++++++++++++++++++++++++++++-------------------- 1 file changed, 50 insertions(+), 30 deletions(-) diff --git a/src/model.py b/src/model.py index 731b8c0..f8be658 100644 --- a/src/model.py +++ b/src/model.py @@ -194,7 +194,7 @@ def __call__(self, mel): class EncoderBlock(nn.Module): - """Pre-norm self-attention + pre-norm FFN with gated residual connections.""" + """Layer input is read from accumulated states, then transformed by attention/FFN.""" num_heads: int num_kv_heads: int d_model: int @@ -206,27 +206,34 @@ class EncoderBlock(nn.Module): no_feedforward: bool = True @nn.compact - def __call__(self, x, mask=None, rope=None, ffn_mask=None, deterministic=True): - gate = nn.sigmoid(self.param("attn_gate", jinit.zeros, ())).astype(self.dtype) - residual = x + def __call__(self, x_accum, layer_idx, mask=None, rope=None, ffn_mask=None, deterministic=True): + q = self.param("residual_query", default_init(), (self.d_model,)).astype(jnp.float32) + k = x_accum.astype(jnp.float32) + v = ZCRMSNorm(dtype=self.dtype, name="residual_value_norm")(x_accum) + + logits = jnp.einsum("btld,d->btl", k, q) / jnp.sqrt(jnp.float32(self.d_model)) + valid_layers = (jnp.arange(x_accum.shape[2]) <= layer_idx)[None, None, :] + logits = jnp.where(valid_layers, logits, jnp.finfo(logits.dtype).min) + alphas = nn.sigmoid(logits).astype(self.dtype) + x = jnp.einsum("btl,btld->btd", alphas, v) + x = ZCRMSNorm(dtype=self.dtype)(x) x = MultiHeadAttention(self.num_heads, self.num_kv_heads, self.d_model, self.num_layers, self.dtype, name="self_attn")( x, x, mask=mask, rope=rope ) - x = residual + gate * nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) + x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) if not self.no_feedforward: ffn_gate = nn.sigmoid(self.param("ffn_gate", jinit.zeros, ())).astype(self.dtype) - residual = x x = ZCRMSNorm(dtype=self.dtype)(x) x = FeedForward(self.d_model, self.d_ff, self.num_layers, self.dtype, self.activation)(x, ffn_mask=ffn_mask) - x = residual + ffn_gate * nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) + x = ffn_gate * nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) return x class _EncoderScanBody(nn.Module): - """Wraps EncoderBlock for nn.scan: carry = (x, mask, rope, ffn_mask).""" + """Wraps EncoderBlock for nn.scan: carry = (mask, rope, ffn_mask, x_accum).""" num_heads: int num_kv_heads: int d_model: int @@ -239,14 +246,15 @@ class _EncoderScanBody(nn.Module): deterministic: bool = True @nn.compact - def __call__(self, carry, _): - x, mask, rope, ffn_mask = carry + def __call__(self, carry, layer_idx): + mask, rope, ffn_mask, x_accum = carry x = EncoderBlock( self.num_heads, self.num_kv_heads, self.d_model, self.d_ff, self.num_layers, self.dtype, self.activation, self.dropout_rate, self.no_feedforward, - )(x, mask, rope, ffn_mask, self.deterministic) - return (x, mask, rope, ffn_mask), None + )(x_accum, layer_idx, mask, rope, ffn_mask, self.deterministic) + x_accum = x_accum.at[:, :, layer_idx + 1, :].set(x) + return (mask, rope, ffn_mask, x_accum), None class Encoder(nn.Module): @@ -258,6 +266,8 @@ def __call__(self, x, mask=None, rope=None, ffn_mask=None, deterministic=True): cfg = self.config dt = cfg.jax_dtype x = x.astype(dt) + x_accum = jnp.zeros((x.shape[0], x.shape[1], cfg.num_encoder_layers + 1, cfg.d_model), dtype=dt) + x_accum = x_accum.at[:, :, 0, :].set(x) ScanBlock = nn.scan( nn.remat(_EncoderScanBody), @@ -265,12 +275,13 @@ def __call__(self, x, mask=None, rope=None, ffn_mask=None, deterministic=True): split_rngs={"params": True, "dropout": True}, length=cfg.num_encoder_layers, ) - (x, _, _, _), _ = ScanBlock( + (_, _, _, x_accum), _ = ScanBlock( cfg.num_heads, cfg.num_kv_heads, cfg.d_model, cfg.d_ff, cfg.total_layers, dt, cfg.activation, cfg.dropout_rate, cfg.no_feedforward, deterministic, name="layers", - )((x, mask, rope, ffn_mask), None) + )((mask, rope, ffn_mask, x_accum), jnp.arange(cfg.num_encoder_layers, dtype=jnp.int32)) + x = x_accum[:, :, cfg.num_encoder_layers, :] x = ZCRMSNorm(dtype=dt, name="final_norm")(x) return x, mask @@ -287,36 +298,41 @@ class DecoderBlock(nn.Module): no_feedforward: bool = True @nn.compact - def __call__(self, x, encoder_out, self_mask=None, cross_mask=None, rope=None, ffn_mask=None, deterministic=True): - self_gate = nn.sigmoid(self.param("self_attn_gate", jinit.zeros, ())).astype(self.dtype) - residual = x + def __call__(self, encoder_out, x_accum, layer_idx, self_mask=None, cross_mask=None, rope=None, ffn_mask=None, deterministic=True): + q = self.param("residual_query", default_init(), (self.d_model,)).astype(jnp.float32) + k = x_accum.astype(jnp.float32) + v = ZCRMSNorm(dtype=self.dtype, name="residual_value_norm")(x_accum) + + logits = jnp.einsum("btld,d->btl", k, q) / jnp.sqrt(jnp.float32(self.d_model)) + valid_layers = (jnp.arange(x_accum.shape[2]) <= layer_idx)[None, None, :] + logits = jnp.where(valid_layers, logits, jnp.finfo(logits.dtype).min) + alphas = nn.sigmoid(logits).astype(self.dtype) + x = jnp.einsum("btl,btld->btd", alphas, v) + x = ZCRMSNorm(dtype=self.dtype)(x) x = MultiHeadAttention(self.num_heads, self.num_kv_heads, self.d_model, self.num_layers, self.dtype, name="self_attn")( x, x, mask=self_mask, rope=rope ) - x = residual + self_gate * nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) + x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) - cross_gate = nn.sigmoid(self.param("cross_attn_gate", jinit.zeros, ())).astype(self.dtype) - residual = x x = ZCRMSNorm(dtype=self.dtype)(x) x = MultiHeadAttention(self.num_heads, self.num_kv_heads, self.d_model, self.num_layers, self.dtype, name="cross_attn")( x, encoder_out, mask=cross_mask ) - x = residual + cross_gate * nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) + x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) if not self.no_feedforward: ffn_gate = nn.sigmoid(self.param("ffn_gate", jinit.zeros, ())).astype(self.dtype) - residual = x x = ZCRMSNorm(dtype=self.dtype)(x) x = FeedForward(self.d_model, self.d_ff, self.num_layers, self.dtype, self.activation)(x, ffn_mask=ffn_mask) - x = residual + ffn_gate * nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) + x = ffn_gate * nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) return x class _DecoderScanBody(nn.Module): - """Wraps DecoderBlock for nn.scan: carry = (x, encoder_out, self_mask, cross_mask, rope, ffn_mask).""" + """Wraps DecoderBlock for nn.scan: carry = (encoder_out, self_mask, cross_mask, rope, ffn_mask, x_accum).""" num_heads: int num_kv_heads: int d_model: int @@ -329,14 +345,15 @@ class _DecoderScanBody(nn.Module): deterministic: bool = True @nn.compact - def __call__(self, carry, _): - x, encoder_out, self_mask, cross_mask, rope, ffn_mask = carry + def __call__(self, carry, layer_idx): + encoder_out, self_mask, cross_mask, rope, ffn_mask, x_accum = carry x = DecoderBlock( self.num_heads, self.num_kv_heads, self.d_model, self.d_ff, self.num_layers, self.dtype, self.activation, self.dropout_rate, self.no_feedforward, - )(x, encoder_out, self_mask, cross_mask, rope, ffn_mask, self.deterministic) - return (x, encoder_out, self_mask, cross_mask, rope, ffn_mask), None + )(encoder_out, x_accum, layer_idx, self_mask, cross_mask, rope, ffn_mask, self.deterministic) + x_accum = x_accum.at[:, :, layer_idx + 1, :].set(x) + return (encoder_out, self_mask, cross_mask, rope, ffn_mask, x_accum), None class Decoder(nn.Module): @@ -347,6 +364,8 @@ def __call__(self, x, encoder_out, self_mask=None, cross_mask=None, rope=None, f cfg = self.config dt = cfg.jax_dtype x = x.astype(dt) + x_accum = jnp.zeros((x.shape[0], x.shape[1], cfg.num_decoder_layers + 1, cfg.d_model), dtype=dt) + x_accum = x_accum.at[:, :, 0, :].set(x) ScanBlock = nn.scan( nn.remat(_DecoderScanBody), @@ -354,12 +373,13 @@ def __call__(self, x, encoder_out, self_mask=None, cross_mask=None, rope=None, f split_rngs={"params": True, "dropout": True}, length=cfg.num_decoder_layers, ) - (x, _, _, _, _, _), _ = ScanBlock( + (_, _, _, _, _, x_accum), _ = ScanBlock( cfg.num_heads, cfg.num_kv_heads, cfg.d_model, cfg.d_ff, cfg.total_layers, dt, cfg.activation, cfg.dropout_rate, cfg.no_feedforward, deterministic, name="layers", - )((x, encoder_out, self_mask, cross_mask, rope, ffn_mask), None) + )((encoder_out, self_mask, cross_mask, rope, ffn_mask, x_accum), jnp.arange(cfg.num_decoder_layers, dtype=jnp.int32)) + x = x_accum[:, :, cfg.num_decoder_layers, :] x = ZCRMSNorm(dtype=dt)(x) return x