
Commit 0900990

multihead_attention: plumbing
1 parent 110cda9 commit 0900990

File tree: 5 files changed, +61 -2 lines changed


src/nf.f90

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@ module nf
   use nf_datasets_mnist, only: label_digits, load_mnist
   use nf_layer, only: layer
   use nf_layer_constructors, only: &
-    conv2d, dense, flatten, input, maxpool2d, reshape, linear2d
+    conv2d, dense, flatten, input, maxpool2d, reshape, linear2d, self_attention
   use nf_loss, only: mse, quadratic
   use nf_metrics, only: corr, maxabs
   use nf_network, only: network
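
With self_attention now re-exported from the top-level nf module, user code can pull it in alongside the other layer constructors. A minimal sketch, not part of this commit, assuming the two-argument input constructor for 2-d inputs; all sizes (16-token sequence, model dimension 64, 4 heads) are made up:

program attention_plumbing_demo
  use nf, only: network, input, self_attention
  implicit none
  type(network) :: net

  ! Input of shape (sequence_length, model_dimension), followed by
  ! multihead self-attention; sizes here are illustrative only.
  net = network([ &
    input(16, 64), &
    self_attention(16, 64, 4) &
  ])
end program attention_plumbing_demo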

src/nf/nf_layer_constructors.f90

Lines changed: 6 additions & 1 deletion
@@ -8,7 +8,7 @@ module nf_layer_constructors
   implicit none

   private
-  public :: conv2d, dense, flatten, input, maxpool2d, reshape, linear2d
+  public :: conv2d, dense, flatten, input, maxpool2d, reshape, linear2d, self_attention

   interface input

@@ -190,6 +190,11 @@ module function linear2d(sequence_length, out_features) result(res)
       type(layer) :: res
     end function linear2d

+    module function self_attention(sequence_length, model_dimension, n_heads) result(res)
+      integer, intent(in) :: sequence_length, model_dimension, n_heads
+      type(layer) :: res
+    end function self_attention
+
   end interface

 end module nf_layer_constructors
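
The interface body above declares only the constructor's signature; the implementation lives in the submodule shown next. For readers new to this Fortran 2008 pattern, here is a minimal self-contained sketch with made-up names:

! Module declares the signature; the submodule provides the body.
module things
  implicit none
  interface
    module function make_thing(n) result(res)
      integer, intent(in) :: n
      integer :: res
    end function make_thing
  end interface
end module things

submodule (things) things_impl
contains
  module function make_thing(n) result(res)
    integer, intent(in) :: n
    integer :: res
    res = 2 * n
  end function make_thing
end submodule things_impl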

src/nf/nf_layer_constructors_submodule.f90

Lines changed: 10 additions & 0 deletions
@@ -10,6 +10,7 @@
   use nf_maxpool2d_layer, only: maxpool2d_layer
   use nf_reshape_layer, only: reshape3d_layer
   use nf_linear2d_layer, only: linear2d_layer
+  use nf_self_attention_layer, only: self_attention_layer
   use nf_activation, only: activation_function, relu, sigmoid

   implicit none
@@ -159,4 +160,13 @@ module function linear2d(sequence_length, out_features) result(res)
     allocate(res % p, source=linear2d_layer(out_features))
   end function linear2d

+  module function self_attention(sequence_length, model_dimension, n_heads) result(res)
+    integer, intent(in) :: sequence_length, model_dimension, n_heads
+    type(layer) :: res
+
+    res % name = 'self_attention'
+    res % layer_shape = [sequence_length, model_dimension]
+    allocate(res % p, source=self_attention_layer(n_heads))
+  end function self_attention
+
 end submodule nf_layer_constructors_submodule
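
As the constructor shows, the generic layer container records the layer's name and output shape, and allocates the concrete self_attention_layer behind the polymorphic component p. An illustrative call, not part of the commit and assuming the name and layer_shape components are publicly readable; sizes made up:

program constructor_demo
  use nf_layer, only: layer
  use nf_layer_constructors, only: self_attention
  implicit none
  type(layer) :: sa

  sa = self_attention(sequence_length=16, model_dimension=64, n_heads=4)
  print *, sa % name         ! prints: self_attention
  print *, sa % layer_shape  ! prints: 16 64
end program constructor_demo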

src/nf/nf_layer_submodule.f90

Lines changed: 41 additions & 0 deletions
@@ -10,6 +10,7 @@
   use nf_maxpool2d_layer, only: maxpool2d_layer
   use nf_reshape_layer, only: reshape3d_layer
   use nf_linear2d_layer, only: linear2d_layer
+  use nf_self_attention_layer, only: self_attention_layer
   use nf_optimizers, only: optimizer_base_type

 contains
@@ -50,6 +51,8 @@ pure module subroutine backward_1d(self, previous, gradient)
         call this_layer % backward(prev_layer % output, gradient)
       type is(linear2d_layer)
         call this_layer % backward(prev_layer % output, gradient)
+      type is(self_attention_layer)
+        call this_layer % backward(prev_layer % output, gradient)
       end select

   end select
@@ -72,6 +75,19 @@ pure module subroutine backward_2d(self, previous, gradient)
         call this_layer % backward(prev_layer % output, gradient)
       type is(linear2d_layer)
         call this_layer % backward(prev_layer % output, gradient)
+      type is(self_attention_layer)
+        call this_layer % backward(prev_layer % output, gradient)
+      end select
+
+    type is(self_attention_layer)
+
+      select type(prev_layer => previous % p)
+      type is(input2d_layer)
+        call this_layer % backward(prev_layer % output, gradient)
+      type is(linear2d_layer)
+        call this_layer % backward(prev_layer % output, gradient)
+      type is(self_attention_layer)
+        call this_layer % backward(prev_layer % output, gradient)
       end select

   end select
@@ -219,6 +235,20 @@ pure module subroutine forward(self, input)
         call this_layer % forward(prev_layer % output)
       type is(linear2d_layer)
         call this_layer % forward(prev_layer % output)
+      type is(self_attention_layer)
+        call this_layer % forward(prev_layer % output)
+      end select
+
+    type is(self_attention_layer)
+
+      ! Upstream layers permitted: input2d, linear2d, self_attention
+      select type(prev_layer => input % p)
+      type is(input2d_layer)
+        call this_layer % forward(prev_layer % output)
+      type is(linear2d_layer)
+        call this_layer % forward(prev_layer % output)
+      type is(self_attention_layer)
+        call this_layer % forward(prev_layer % output)
       end select

   end select
@@ -258,6 +288,8 @@ pure module subroutine get_output_2d(self, output)
       allocate(output, source=this_layer % output)
     type is(linear2d_layer)
       allocate(output, source=this_layer % output)
+    type is(self_attention_layer)
+      allocate(output, source=this_layer % output)
     class default
       error stop '2-d output can only be read from an input2d or linear2d layer.'

@@ -357,6 +389,8 @@ elemental module function get_num_params(self) result(num_params)
       num_params = 0
     type is (linear2d_layer)
       num_params = this_layer % get_num_params()
+    type is (self_attention_layer)
+      num_params = this_layer % get_num_params()
     class default
       error stop 'Unknown layer type.'
   end select
@@ -386,6 +420,8 @@ module function get_params(self) result(params)
       ! No parameters to get.
     type is (linear2d_layer)
       params = this_layer % get_params()
+    type is (self_attention_layer)
+      params = this_layer % get_params()
     class default
       error stop 'Unknown layer type.'
   end select
@@ -415,6 +451,8 @@ module function get_gradients(self) result(gradients)
       ! No gradients to get.
     type is (linear2d_layer)
       gradients = this_layer % get_gradients()
+    type is (self_attention_layer)
+      gradients = this_layer % get_gradients()
     class default
       error stop 'Unknown layer type.'
   end select
@@ -465,6 +503,9 @@ module subroutine set_params(self, params)
     type is (linear2d_layer)
       call this_layer % set_params(params)

+    type is (self_attention_layer)
+      call this_layer % set_params(params)
+
     type is (maxpool2d_layer)
       ! No parameters to set.
       write(stderr, '(a)') 'Warning: calling set_params() ' &
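
All of these hunks extend select type constructs, which is how the generic layer dispatches to concrete layer types at runtime: every new layer type needs its own "type is" branch at each dispatch site. A standalone sketch of the mechanism, with stand-in types rather than the library's own:

program select_type_demo
  implicit none
  type :: base_layer
  end type base_layer
  type, extends(base_layer) :: linear2d_like
  end type linear2d_like
  type, extends(base_layer) :: self_attention_like
  end type self_attention_like
  class(base_layer), allocatable :: p

  allocate(self_attention_like :: p)
  select type (p)
  type is (linear2d_like)
    print *, 'linear2d code path'
  type is (self_attention_like)
    ! Without this branch, dispatch would fall through to class default.
    print *, 'self_attention code path'
  class default
    error stop 'Unknown layer type.'
  end select
end program select_type_demo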

src/nf/nf_network_submodule.f90

Lines changed: 3 additions & 0 deletions
@@ -9,6 +9,7 @@
   use nf_maxpool2d_layer, only: maxpool2d_layer
   use nf_reshape_layer, only: reshape3d_layer
   use nf_linear2d_layer, only: linear2d_layer
+  use nf_self_attention_layer, only: self_attention_layer
   use nf_layer, only: layer
   use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d, reshape
   use nf_loss, only: quadratic
@@ -158,6 +159,8 @@ module subroutine backward(self, output, loss)
         call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
       type is(linear2d_layer)
         call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+      type is(self_attention_layer)
+        call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
       end select
     end if

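
For context, this hunk sits inside the network-level backward sweep, which walks the layers in reverse and hands each layer the gradient held by the layer after it. A condensed, hedged reconstruction of the surrounding loop, not verbatim from the source (the real subroutine also seeds the last layer's gradient from the loss derivative):

! Condensed sketch of the backward sweep; not verbatim.
do n = size(self % layers), 2, -1
  if (n /= size(self % layers)) then
    select type (next_layer => self % layers(n + 1) % p)
    type is (self_attention_layer)
      ! New branch from this commit: propagate the downstream gradient.
      call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
    end select
  end if
end do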
