Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,7 @@ class MultiHeadAttention(nn.Module):
def __init__(self, n_head, n_embed, context_length):
super().__init__()
self.heads = nn.ModuleList([Head(n_embed // n_head, n_embed, context_length) for _ in range(n_head)])
self.proj = nn.Linear(n_embed, n_embed)

def forward(self, x):
"""
Expand All @@ -653,10 +654,12 @@ class MultiHeadAttention(nn.Module):
"""
# Concatenate the output of each head along the last dimension (C)
x = torch.cat([h(x) for h in self.heads], dim=-1)
# Apply final linear projection
x = self.proj(x)
return x
```

Now that we have defined the MultiHeadAttention class, which combines multiple attention heads, the __init__ method initializes a list of Head instances (a total of n_head), each with a head_size of n_embed // n_head. The forward method applies each attention head to the input x and concatenates their outputs along the last dimension, merging the information learned by each head.
Now that we have defined the MultiHeadAttention class, which combines multiple attention heads, the __init__ method initializes a list of Head instances (a total of n_head), each with a head_size of n_embed // n_head, along with a final projection layer. The forward method applies each attention head to the input x, concatenates their outputs along the last dimension, and projects them linearly to merge the information learned by each head.

### Transformer Block

Expand Down
4 changes: 4 additions & 0 deletions scripts/train_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ def estimate_loss(steps: int) -> Dict[str, float]:
# Backpropagate the loss and update the model parameters.
optimizer.zero_grad(set_to_none=True)
loss.backward()

# Clip gradients to prevent exploding gradients
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

optimizer.step()

# Periodically evaluate the model on training and development data.
Expand Down
3 changes: 3 additions & 0 deletions src/models/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def __init__(self, n_head: int, n_embed: int, context_length: int) -> None:
"""
super().__init__()
self.heads = nn.ModuleList([Head(n_embed // n_head, n_embed, context_length) for _ in range(n_head)])
self.proj = nn.Linear(n_embed, n_embed)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Expand All @@ -93,6 +94,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
# Concatenate the output of each head along the last dimension (C)
x = torch.cat([h(x) for h in self.heads], dim=-1)
# Apply final linear projection
x = self.proj(x)
return x

if __name__ == '__main__':
Expand Down