-
Notifications
You must be signed in to change notification settings - Fork 179
added residual attention, removed normal residual and gated attention #11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -194,7 +194,7 @@ def __call__(self, mel): | |
|
|
||
|
|
||
| class EncoderBlock(nn.Module): | ||
| """Pre-norm self-attention + pre-norm FFN with gated residual connections.""" | ||
| """Layer input is read from accumulated states, then transformed by attention/FFN.""" | ||
| num_heads: int | ||
| num_kv_heads: int | ||
| d_model: int | ||
|
|
@@ -206,27 +206,34 @@ class EncoderBlock(nn.Module): | |
| no_feedforward: bool = True | ||
|
|
||
| @nn.compact | ||
| def __call__(self, x, mask=None, rope=None, ffn_mask=None, deterministic=True): | ||
| gate = nn.sigmoid(self.param("attn_gate", jinit.zeros, ())).astype(self.dtype) | ||
| residual = x | ||
| def __call__(self, x_accum, layer_idx, mask=None, rope=None, ffn_mask=None, deterministic=True): | ||
| q = self.param("residual_query", default_init(), (self.d_model,)).astype(jnp.float32) | ||
| k = x_accum.astype(jnp.float32) | ||
| v = ZCRMSNorm(dtype=self.dtype, name="residual_value_norm")(x_accum) | ||
|
|
||
| logits = jnp.einsum("btld,d->btl", k, q) / jnp.sqrt(jnp.float32(self.d_model)) | ||
| valid_layers = (jnp.arange(x_accum.shape[2]) <= layer_idx)[None, None, :] | ||
| logits = jnp.where(valid_layers, logits, jnp.finfo(logits.dtype).min) | ||
| alphas = nn.sigmoid(logits).astype(self.dtype) | ||
| x = jnp.einsum("btl,btld->btd", alphas, v) | ||
|
|
||
| x = ZCRMSNorm(dtype=self.dtype)(x) | ||
| x = MultiHeadAttention(self.num_heads, self.num_kv_heads, self.d_model, self.num_layers, self.dtype, name="self_attn")( | ||
| x, x, mask=mask, rope=rope | ||
| ) | ||
| x = residual + gate * nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) | ||
| x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) | ||
|
|
||
| if not self.no_feedforward: | ||
| ffn_gate = nn.sigmoid(self.param("ffn_gate", jinit.zeros, ())).astype(self.dtype) | ||
| residual = x | ||
| x = ZCRMSNorm(dtype=self.dtype)(x) | ||
| x = FeedForward(self.d_model, self.d_ff, self.num_layers, self.dtype, self.activation)(x, ffn_mask=ffn_mask) | ||
| x = residual + ffn_gate * nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) | ||
| x = ffn_gate * nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) | ||
|
|
||
| return x | ||
|
|
||
|
|
||
| class _EncoderScanBody(nn.Module): | ||
| """Wraps EncoderBlock for nn.scan: carry = (x, mask, rope, ffn_mask).""" | ||
| """Wraps EncoderBlock for nn.scan: carry = (mask, rope, ffn_mask, x_accum).""" | ||
| num_heads: int | ||
| num_kv_heads: int | ||
| d_model: int | ||
|
|
@@ -239,14 +246,15 @@ class _EncoderScanBody(nn.Module): | |
| deterministic: bool = True | ||
|
|
||
| @nn.compact | ||
| def __call__(self, carry, _): | ||
| x, mask, rope, ffn_mask = carry | ||
| def __call__(self, carry, layer_idx): | ||
| mask, rope, ffn_mask, x_accum = carry | ||
| x = EncoderBlock( | ||
| self.num_heads, self.num_kv_heads, self.d_model, self.d_ff, | ||
| self.num_layers, self.dtype, self.activation, self.dropout_rate, | ||
| self.no_feedforward, | ||
| )(x, mask, rope, ffn_mask, self.deterministic) | ||
| return (x, mask, rope, ffn_mask), None | ||
| )(x_accum, layer_idx, mask, rope, ffn_mask, self.deterministic) | ||
| x_accum = x_accum.at[:, :, layer_idx + 1, :].set(x) | ||
| return (mask, rope, ffn_mask, x_accum), None | ||
|
|
||
|
|
||
| class Encoder(nn.Module): | ||
|
|
@@ -258,19 +266,22 @@ def __call__(self, x, mask=None, rope=None, ffn_mask=None, deterministic=True): | |
| cfg = self.config | ||
| dt = cfg.jax_dtype | ||
| x = x.astype(dt) | ||
| x_accum = jnp.zeros((x.shape[0], x.shape[1], cfg.num_encoder_layers + 1, cfg.d_model), dtype=dt) | ||
| x_accum = x_accum.at[:, :, 0, :].set(x) | ||
|
Comment on lines
+269
to
+270
|
||
|
|
||
| ScanBlock = nn.scan( | ||
| nn.remat(_EncoderScanBody), | ||
| variable_axes={"params": 0}, | ||
| split_rngs={"params": True, "dropout": True}, | ||
| length=cfg.num_encoder_layers, | ||
| ) | ||
| (x, _, _, _), _ = ScanBlock( | ||
| (_, _, _, x_accum), _ = ScanBlock( | ||
| cfg.num_heads, cfg.num_kv_heads, cfg.d_model, cfg.d_ff, | ||
| cfg.total_layers, dt, cfg.activation, cfg.dropout_rate, | ||
| cfg.no_feedforward, deterministic, name="layers", | ||
| )((x, mask, rope, ffn_mask), None) | ||
| )((mask, rope, ffn_mask, x_accum), jnp.arange(cfg.num_encoder_layers, dtype=jnp.int32)) | ||
|
|
||
| x = x_accum[:, :, cfg.num_encoder_layers, :] | ||
| x = ZCRMSNorm(dtype=dt, name="final_norm")(x) | ||
| return x, mask | ||
|
|
||
|
|
@@ -287,36 +298,41 @@ class DecoderBlock(nn.Module): | |
| no_feedforward: bool = True | ||
|
|
||
| @nn.compact | ||
| def __call__(self, x, encoder_out, self_mask=None, cross_mask=None, rope=None, ffn_mask=None, deterministic=True): | ||
| self_gate = nn.sigmoid(self.param("self_attn_gate", jinit.zeros, ())).astype(self.dtype) | ||
| residual = x | ||
| def __call__(self, encoder_out, x_accum, layer_idx, self_mask=None, cross_mask=None, rope=None, ffn_mask=None, deterministic=True): | ||
| q = self.param("residual_query", default_init(), (self.d_model,)).astype(jnp.float32) | ||
| k = x_accum.astype(jnp.float32) | ||
| v = ZCRMSNorm(dtype=self.dtype, name="residual_value_norm")(x_accum) | ||
|
|
||
| logits = jnp.einsum("btld,d->btl", k, q) / jnp.sqrt(jnp.float32(self.d_model)) | ||
| valid_layers = (jnp.arange(x_accum.shape[2]) <= layer_idx)[None, None, :] | ||
| logits = jnp.where(valid_layers, logits, jnp.finfo(logits.dtype).min) | ||
| alphas = nn.sigmoid(logits).astype(self.dtype) | ||
| x = jnp.einsum("btl,btld->btd", alphas, v) | ||
|
|
||
|
Comment on lines
+302
to
+311
|
||
| x = ZCRMSNorm(dtype=self.dtype)(x) | ||
| x = MultiHeadAttention(self.num_heads, self.num_kv_heads, self.d_model, self.num_layers, self.dtype, name="self_attn")( | ||
| x, x, mask=self_mask, rope=rope | ||
| ) | ||
| x = residual + self_gate * nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) | ||
| x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) | ||
|
|
||
| cross_gate = nn.sigmoid(self.param("cross_attn_gate", jinit.zeros, ())).astype(self.dtype) | ||
| residual = x | ||
| x = ZCRMSNorm(dtype=self.dtype)(x) | ||
| x = MultiHeadAttention(self.num_heads, self.num_kv_heads, self.d_model, self.num_layers, self.dtype, name="cross_attn")( | ||
| x, encoder_out, mask=cross_mask | ||
| ) | ||
| x = residual + cross_gate * nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) | ||
| x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) | ||
|
|
||
| if not self.no_feedforward: | ||
| ffn_gate = nn.sigmoid(self.param("ffn_gate", jinit.zeros, ())).astype(self.dtype) | ||
| residual = x | ||
| x = ZCRMSNorm(dtype=self.dtype)(x) | ||
| x = FeedForward(self.d_model, self.d_ff, self.num_layers, self.dtype, self.activation)(x, ffn_mask=ffn_mask) | ||
| x = residual + ffn_gate * nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) | ||
| x = ffn_gate * nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) | ||
|
|
||
| return x | ||
|
|
||
|
|
||
|
|
||
| class _DecoderScanBody(nn.Module): | ||
| """Wraps DecoderBlock for nn.scan: carry = (x, encoder_out, self_mask, cross_mask, rope, ffn_mask).""" | ||
| """Wraps DecoderBlock for nn.scan: carry = (encoder_out, self_mask, cross_mask, rope, ffn_mask, x_accum).""" | ||
| num_heads: int | ||
| num_kv_heads: int | ||
| d_model: int | ||
|
|
@@ -329,14 +345,15 @@ class _DecoderScanBody(nn.Module): | |
| deterministic: bool = True | ||
|
|
||
| @nn.compact | ||
| def __call__(self, carry, _): | ||
| x, encoder_out, self_mask, cross_mask, rope, ffn_mask = carry | ||
| def __call__(self, carry, layer_idx): | ||
| encoder_out, self_mask, cross_mask, rope, ffn_mask, x_accum = carry | ||
| x = DecoderBlock( | ||
| self.num_heads, self.num_kv_heads, self.d_model, self.d_ff, | ||
| self.num_layers, self.dtype, self.activation, self.dropout_rate, | ||
| self.no_feedforward, | ||
| )(x, encoder_out, self_mask, cross_mask, rope, ffn_mask, self.deterministic) | ||
| return (x, encoder_out, self_mask, cross_mask, rope, ffn_mask), None | ||
| )(encoder_out, x_accum, layer_idx, self_mask, cross_mask, rope, ffn_mask, self.deterministic) | ||
| x_accum = x_accum.at[:, :, layer_idx + 1, :].set(x) | ||
| return (encoder_out, self_mask, cross_mask, rope, ffn_mask, x_accum), None | ||
|
|
||
|
|
||
| class Decoder(nn.Module): | ||
|
|
@@ -347,19 +364,22 @@ def __call__(self, x, encoder_out, self_mask=None, cross_mask=None, rope=None, f | |
| cfg = self.config | ||
| dt = cfg.jax_dtype | ||
| x = x.astype(dt) | ||
| x_accum = jnp.zeros((x.shape[0], x.shape[1], cfg.num_decoder_layers + 1, cfg.d_model), dtype=dt) | ||
| x_accum = x_accum.at[:, :, 0, :].set(x) | ||
|
Comment on lines
+367
to
+368
|
||
|
|
||
| ScanBlock = nn.scan( | ||
| nn.remat(_DecoderScanBody), | ||
| variable_axes={"params": 0}, | ||
| split_rngs={"params": True, "dropout": True}, | ||
| length=cfg.num_decoder_layers, | ||
| ) | ||
| (x, _, _, _, _, _), _ = ScanBlock( | ||
| (_, _, _, _, _, x_accum), _ = ScanBlock( | ||
| cfg.num_heads, cfg.num_kv_heads, cfg.d_model, cfg.d_ff, | ||
| cfg.total_layers, dt, cfg.activation, cfg.dropout_rate, | ||
| cfg.no_feedforward, deterministic, name="layers", | ||
| )((x, encoder_out, self_mask, cross_mask, rope, ffn_mask), None) | ||
| )((encoder_out, self_mask, cross_mask, rope, ffn_mask, x_accum), jnp.arange(cfg.num_decoder_layers, dtype=jnp.int32)) | ||
|
|
||
| x = x_accum[:, :, cfg.num_decoder_layers, :] | ||
| x = ZCRMSNorm(dtype=dt)(x) | ||
| return x | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This residual-attention computation reads the full
x_accumtensor (B×T×(L+1)×D) and computes logits/values over all layers each iteration, making the encoder block’s cost scale ~O(L^2) in number of layers and increasing per-step memory bandwidth substantially. If this is not strictly required, consider restricting the aggregation to a bounded window of previous layers or maintaining a smaller running summary so each layer doesn’t repeatedly process the entire accumulated state.