@@ -2,6 +2,7 @@ program test_multihead_attention_layer
  use iso_fortran_env, only: stderr => error_unit
  use nf_multihead_attention_layer, only: multihead_attention_layer
  use nf_linear2d_layer, only: linear2d_layer
+  use nf_optimizers, only: sgd
  implicit none

  logical :: ok = .true.
@@ -21,6 +22,7 @@ program test_multihead_attention_layer
  call test_multihead_attention_combine_heads(attention, attention % sdpa, ok)
  call test_multihead_attention_forward(attention, ok)
  call test_multihead_attention_backward(attention, ok)
+  call test_multihead_attention_update_gradients(attention, ok)
  ! call test_multihead_attention_forward_reallife_shape(ok)

contains
@@ -239,4 +241,46 @@ subroutine test_multihead_attention_backward(attention, ok)
      write(stderr, '(a)') 'backward returned incorrect values.. failed'
    end if
  end subroutine test_multihead_attention_backward
+
+  subroutine test_multihead_attention_update_gradients(attention, ok)
+    type(multihead_attention_layer), intent(in out) :: attention
+    logical, intent(in out) :: ok
+    real :: parameters(80)
+    real :: expected_parameters(80)
+    real :: updated_output(12)
+    real :: expected_updated_output(12) = [&
+        0.111365855, 0.115744293, 0.115733206, 0.185253710, 0.196646214, 0.196617395,&
+        -0.102874994, -0.118834510, -0.118794113, 0.179314315, 0.190210193, 0.190182626&
+    ]
+    type(sgd) :: optim
+
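+    ! the layer is expected to expose 80 trainable parameters in total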
+    if (attention % get_num_params() /= 80) then
+      ok = .false.
+      write(stderr, '(a)') 'incorrect number of parameters.. failed'
+    end if
+
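+    ! get_params should return the expected initial parameter values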
+    expected_parameters(1:64) = 0.100000001
+    expected_parameters(65:80) = 0.109999999
+    parameters = attention % get_params()
+    if (.not. all(parameters .eq. expected_parameters)) then
+      ok = .false.
+      write(stderr, '(a)') 'incorrect parameters.. failed'
+    end if
+
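+    ! take a single SGD step on the flattened parameters and write them back to the layer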
+    optim = SGD(learning_rate=0.01)
+    call optim % minimize(parameters, attention % get_gradients())
+    call attention % set_params(parameters)
+
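+    ! forward pass with the updated parameters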
+    call attention % forward(&
+        reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4]),&
+        reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4]),&
+        reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4])&
+    )
+
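+    ! the output should now match the values expected after the parameter update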
+    updated_output = reshape(attention % output, [12])
+    if (.not. all(updated_output .eq. expected_updated_output)) then
+      ok = .false.
+      write(stderr, '(a)') 'incorrect output after parameters update.. failed'
+    end if
+  end subroutine test_multihead_attention_update_gradients
end program test_multihead_attention_layer