
Multihead Attention Fixes #209


Merged
40 changes: 34 additions & 6 deletions src/nf/nf_multihead_attention.f90
@@ -39,10 +39,25 @@ module nf_multihead_attention_layer
real, allocatable :: k_input(:, :)
real, allocatable :: v_input(:, :)
real, allocatable :: o_input(:, :)

! temporary storages for forward and backward passes
real, allocatable :: normalized_attention(:, :, :)
real, allocatable :: q_or_dq(:, :, :)
real, allocatable :: k_or_dk(:, :, :)
real, allocatable :: v_or_dv(:, :, :)
real, allocatable :: d_output(:, :, :)
real, allocatable :: v_heads(:, :, :)
real, allocatable :: k_heads(:, :, :)
real, allocatable :: q_heads(:, :, :)
real, allocatable :: d_sdpa(:, :)
real, allocatable :: jacobian(:, :)
real, allocatable :: d_normalize(:, :, :)
contains

procedure :: common_backward
procedure :: common_forward
procedure :: sdpa_forward
Collaborator:
I was wondering what sdpa was until I found it in one of your comments below (that is, Scaled Dot Product Attention).
I suggest adding a comment to explain it.

Collaborator (Author):
will do!
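For readers who hit the same question: below is a minimal, self-contained sketch of what scaled dot product attention (SDPA) computes for a single head. The sizes and variable names are illustrative only; none of this code is taken from the PR.

program sdpa_demo
  ! Scaled Dot Product Attention (SDPA) for one head:
  !   attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
  ! Illustrative sizes, not taken from the neural-fortran sources.
  implicit none
  integer, parameter :: seq_len = 3, head_size = 4
  real :: q(seq_len, head_size), k(seq_len, head_size), v(seq_len, head_size)
  real :: scores(seq_len, seq_len), weights(seq_len, seq_len)
  real :: output(seq_len, head_size)
  integer :: i

  call random_number(q)
  call random_number(k)
  call random_number(v)

  ! Raw attention scores, scaled by the square root of the head size
  scores = matmul(q, transpose(k)) / sqrt(real(head_size))

  ! Row-wise softmax turns scores into attention weights
  do i = 1, seq_len
    weights(i, :) = exp(scores(i, :) - maxval(scores(i, :)))
    weights(i, :) = weights(i, :) / sum(weights(i, :))
  end do

  ! Each output row is a weighted average of the value rows
  output = matmul(weights, v)

  print *, 'attention output shape:', shape(output)
end program sdpa_demo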

procedure :: sdpa_backward
procedure :: get_num_params
procedure :: get_params
procedure :: get_gradients
@@ -68,25 +83,38 @@ end function multihead_attention_layer_cons

interface

- pure module subroutine common_backward(self, input, gradient)
+ pure module subroutine common_backward(self, input, gradient, attention_mask)
!! General backprop for MultiHead Attention mechanism
!! Might be used for both Self and Cross Attention
!! Self Attention: sum output gradients
!! Cross Attention: use them separately
class(multihead_attention_layer), intent(in out) :: self
real, intent(in) :: input(:, :)
real, intent(in) :: gradient(:, :)
real, optional, intent(in) :: attention_mask(:, :)
end subroutine common_backward

- pure module subroutine common_forward(self, query, key, value)
+ pure module subroutine common_forward(self, query, key, value, attention_mask)
!! General forward propagation for MultiHead Attention Mechanism
!! Might be used for both Self and Cross Attention
!! Self Attention: pass the same value thrice
!! Cross Attention: pass three values for your query, key and value
class(multihead_attention_layer), intent(in out) :: self
real, intent(in) :: query(:, :), key(:, :), value(:, :)
real, optional, intent(in) :: attention_mask(:, :)
end subroutine common_forward
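To make the Self vs. Cross Attention notes in the two interfaces above concrete, here is a hedged usage sketch. The names layer, x, decoder_x, encoder_out and d_output are placeholders, and constructing and initializing the layer is omitted because the constructor arguments are not part of this diff.

! Self attention: query, key and value are the same sequence,
! so the same array is passed three times.
call layer % common_forward(x, x, x)

! Cross attention: the query comes from one sequence while key and
! value come from another, e.g. an encoder output.
call layer % common_forward(decoder_x, encoder_out, encoder_out)

! Backward pass for the self-attention case: d_output is the gradient
! of the loss with respect to this layer's output. The attention_mask
! argument is optional in both directions.
call layer % common_backward(x, d_output)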

pure module subroutine sdpa_forward(self, attention_mask)
Collaborator (Author):
Put Scaled Dot Product Attention into a separate method. This adds more flexibility: in some cases we need to manipulate the input projections, such as KV caching for Llama and Qwen2.

class(multihead_attention_layer), intent(in out) :: self
real, intent(in), optional :: attention_mask(:, :)
end subroutine sdpa_forward

pure module subroutine sdpa_backward(self, gradient, attention_mask)
class(multihead_attention_layer), intent(in out) :: self
real, intent(in) :: gradient(:, :)
real, intent(in), optional :: attention_mask(:, :)
end subroutine sdpa_backward
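The optional attention_mask(:, :) argument now threads through common_forward, common_backward, sdpa_forward and sdpa_backward, and normalize_attention_matrix below documents it as a (sequence_length, sequence_length) array. A hedged sketch of building a causal mask and passing it along follows; treating the mask as additive (large negative values suppressing attention to future positions) is an assumption of this sketch, not something stated in the diff, and the other names reuse the placeholders from the previous sketch.

! Hypothetical causal mask: row i is assumed to index the query position,
! so positions j > i are suppressed with a large negative value.
! sequence_length is assumed to be defined to match the layer.
real, allocatable :: mask(:, :)
integer :: i, j

allocate(mask(sequence_length, sequence_length))
mask = 0.
do i = 1, sequence_length
  do j = i + 1, sequence_length
    mask(i, j) = -1e9
  end do
end do

call layer % common_forward(x, x, x, attention_mask=mask)
call layer % common_backward(x, d_output, attention_mask=mask)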

pure module subroutine init(self, input_shape)
!! Initialize the layer data structures.
!!
@@ -119,7 +147,7 @@ pure module subroutine normalize_attention_matrix(self, attention_mask)
!! Output dims: sequence_length, sequence_length, n_heads
class(multihead_attention_layer), intent(in out) :: self
!! (sequence_length, sequence_length, n_heads)
- real, optional, intent(in) :: attention_mask(:, :, :)
+ real, optional, intent(in) :: attention_mask(:, :)
!! (sequence_length, sequence_length, n_heads)
end subroutine normalize_attention_matrix

@@ -143,18 +171,18 @@ elemental module function get_num_params(self) result(num_params)
end function get_num_params

module function get_params(self) result(params)
- class(multihead_attention_layer), intent(in), target :: self
+ class(multihead_attention_layer), intent(in) :: self
real, allocatable :: params(:)
end function get_params

module function get_gradients(self) result(gradients)
- class(multihead_attention_layer), intent(in), target :: self
+ class(multihead_attention_layer), intent(in) :: self
real, allocatable :: gradients(:)
end function get_gradients

module subroutine set_params(self, params)
class(multihead_attention_layer), intent(in out) :: self
- real, intent(in), target :: params(:)
+ real, intent(in) :: params(:)
end subroutine set_params

module subroutine init_base(self, input_shape)