From dfb9a11c5f54aad4851725e88d1dee6b17296471 Mon Sep 17 00:00:00 2001 From: Chamath Adithya Date: Sat, 30 May 2026 01:51:15 +0530 Subject: [PATCH] fix: add missing output projection in MultiHeadAttention and optimize training --- README.md | 5 ++++- scripts/train_transformer.py | 4 ++++ src/models/attention.py | 3 +++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index c3ccd4e..0bbc5a4 100644 --- a/README.md +++ b/README.md @@ -640,6 +640,7 @@ class MultiHeadAttention(nn.Module): def __init__(self, n_head, n_embed, context_length): super().__init__() self.heads = nn.ModuleList([Head(n_embed // n_head, n_embed, context_length) for _ in range(n_head)]) + self.proj = nn.Linear(n_embed, n_embed) def forward(self, x): """ @@ -653,10 +654,12 @@ class MultiHeadAttention(nn.Module): """ # Concatenate the output of each head along the last dimension (C) x = torch.cat([h(x) for h in self.heads], dim=-1) + # Apply final linear projection + x = self.proj(x) return x ``` -Now that we have defined the MultiHeadAttention class, which combines multiple attention heads, the __init__ method initializes a list of Head instances (a total of n_head), each with a head_size of n_embed // n_head. The forward method applies each attention head to the input x and concatenates their outputs along the last dimension, merging the information learned by each head. +Now that we have defined the MultiHeadAttention class, which combines multiple attention heads, the __init__ method initializes a list of Head instances (a total of n_head), each with a head_size of n_embed // n_head, along with a final projection layer. The forward method applies each attention head to the input x, concatenates their outputs along the last dimension, and projects them linearly to merge the information learned by each head. ### Transformer Block diff --git a/scripts/train_transformer.py b/scripts/train_transformer.py index 8cbb14a..de063ce 100644 --- a/scripts/train_transformer.py +++ b/scripts/train_transformer.py @@ -103,6 +103,10 @@ def estimate_loss(steps: int) -> Dict[str, float]: # Backpropagate the loss and update the model parameters. optimizer.zero_grad(set_to_none=True) loss.backward() + + # Clip gradients to prevent exploding gradients + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + optimizer.step() # Periodically evaluate the model on training and development data. diff --git a/src/models/attention.py b/src/models/attention.py index c57ccd7..1596c94 100644 --- a/src/models/attention.py +++ b/src/models/attention.py @@ -80,6 +80,7 @@ def __init__(self, n_head: int, n_embed: int, context_length: int) -> None: """ super().__init__() self.heads = nn.ModuleList([Head(n_embed // n_head, n_embed, context_length) for _ in range(n_head)]) + self.proj = nn.Linear(n_embed, n_embed) def forward(self, x: torch.Tensor) -> torch.Tensor: """ @@ -93,6 +94,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ # Concatenate the output of each head along the last dimension (C) x = torch.cat([h(x) for h in self.heads], dim=-1) + # Apply final linear projection + x = self.proj(x) return x if __name__ == '__main__':