10
10
use nf_maxpool2d_layer, only: maxpool2d_layer
11
11
use nf_reshape_layer, only: reshape3d_layer
12
12
use nf_linear2d_layer, only: linear2d_layer
13
+ use nf_self_attention_layer, only: self_attention_layer
13
14
use nf_optimizers, only: optimizer_base_type
14
15
15
16
contains
@@ -50,6 +51,8 @@ pure module subroutine backward_1d(self, previous, gradient)
50
51
call this_layer % backward(prev_layer % output, gradient)
51
52
type is (linear2d_layer)
52
53
call this_layer % backward(prev_layer % output, gradient)
54
+ type is (self_attention_layer)
55
+ call this_layer % backward(prev_layer % output, gradient)
53
56
end select
54
57
55
58
end select
@@ -72,6 +75,19 @@ pure module subroutine backward_2d(self, previous, gradient)
72
75
call this_layer % backward(prev_layer % output, gradient)
73
76
type is (linear2d_layer)
74
77
call this_layer % backward(prev_layer % output, gradient)
78
+ type is (self_attention_layer)
79
+ call this_layer % backward(prev_layer % output, gradient)
80
+ end select
81
+
82
+ type is (self_attention_layer)
83
+
84
+ select type (prev_layer = > previous % p)
85
+ type is (input2d_layer)
86
+ call this_layer % backward(prev_layer % output, gradient)
87
+ type is (linear2d_layer)
88
+ call this_layer % backward(prev_layer % output, gradient)
89
+ type is (self_attention_layer)
90
+ call this_layer % backward(prev_layer % output, gradient)
75
91
end select
76
92
77
93
end select
@@ -219,6 +235,20 @@ pure module subroutine forward(self, input)
219
235
call this_layer % forward(prev_layer % output)
220
236
type is (linear2d_layer)
221
237
call this_layer % forward(prev_layer % output)
238
+ type is (self_attention_layer)
239
+ call this_layer % forward(prev_layer % output)
240
+ end select
241
+
242
+ type is (self_attention_layer)
243
+
244
+ ! Upstream layers permitted: input2d, linear2d
245
+ select type (prev_layer = > input % p)
246
+ type is (input2d_layer)
247
+ call this_layer % forward(prev_layer % output)
248
+ type is (linear2d_layer)
249
+ call this_layer % forward(prev_layer % output)
250
+ type is (self_attention_layer)
251
+ call this_layer % forward(prev_layer % output)
222
252
end select
223
253
224
254
end select
@@ -258,6 +288,8 @@ pure module subroutine get_output_2d(self, output)
258
288
allocate (output, source= this_layer % output)
259
289
type is (linear2d_layer)
260
290
allocate (output, source= this_layer % output)
291
+ type is (self_attention_layer)
292
+ allocate (output, source= this_layer % output)
261
293
class default
262
294
error stop ' 2-d output can only be read from an input2d or linear2d layer.'
263
295
@@ -357,6 +389,8 @@ elemental module function get_num_params(self) result(num_params)
357
389
num_params = 0
358
390
type is (linear2d_layer)
359
391
num_params = this_layer % get_num_params()
392
+ type is (self_attention_layer)
393
+ num_params = this_layer % get_num_params()
360
394
class default
361
395
error stop ' Unknown layer type.'
362
396
end select
@@ -386,6 +420,8 @@ module function get_params(self) result(params)
386
420
! No parameters to get.
387
421
type is (linear2d_layer)
388
422
params = this_layer % get_params()
423
+ type is (self_attention_layer)
424
+ params = this_layer % get_params()
389
425
class default
390
426
error stop ' Unknown layer type.'
391
427
end select
@@ -415,6 +451,8 @@ module function get_gradients(self) result(gradients)
415
451
! No gradients to get.
416
452
type is (linear2d_layer)
417
453
gradients = this_layer % get_gradients()
454
+ type is (self_attention_layer)
455
+ gradients = this_layer % get_gradients()
418
456
class default
419
457
error stop ' Unknown layer type.'
420
458
end select
@@ -465,6 +503,9 @@ module subroutine set_params(self, params)
465
503
type is (linear2d_layer)
466
504
call this_layer % set_params(params)
467
505
506
+ type is (self_attention_layer)
507
+ call this_layer % set_params(params)
508
+
468
509
type is (maxpool2d_layer)
469
510
! No parameters to set.
470
511
write (stderr, ' (a)' ) ' Warning: calling set_params() ' &
0 commit comments