Commit 86cd7c0

multihead_attention: tests for updated parameters
1 parent bcda13d

1 file changed: 44 additions, 0 deletions


test/test_multihead_attention_layer.f90

@@ -2,6 +2,7 @@ program test_multihead_attention_layer
   use iso_fortran_env, only: stderr => error_unit
   use nf_multihead_attention_layer, only: multihead_attention_layer
   use nf_linear2d_layer, only: linear2d_layer
+  use nf_optimizers, only: sgd
   implicit none
 
   logical :: ok = .true.
@@ -21,6 +22,7 @@ program test_multihead_attention_layer
   call test_multihead_attention_combine_heads(attention, attention % sdpa, ok)
   call test_multihead_attention_forward(attention, ok)
   call test_multihead_attention_backward(attention, ok)
+  call test_multihead_attention_update_gradients(attention, ok)
   ! call test_multihead_attention_forward_reallife_shape(ok)
 
 contains
@@ -239,4 +241,46 @@ subroutine test_multihead_attention_backward(attention, ok)
       write(stderr, '(a)') 'backward returned incorrect values.. failed'
     end if
   end subroutine test_multihead_attention_backward
+
+  subroutine test_multihead_attention_update_gradients(attention, ok)
+    type(multihead_attention_layer), intent(in out) :: attention
+    logical, intent(in out) :: ok
+    real :: parameters(80)
+    real :: expected_parameters(80)
+    real :: updated_output(12)
+    real :: expected_updated_output(12) = [&
+        0.111365855, 0.115744293, 0.115733206, 0.185253710, 0.196646214, 0.196617395,&
+        -0.102874994, -0.118834510, -0.118794113, 0.179314315, 0.190210193, 0.190182626&
+    ]
+    type(sgd) :: optim
+
+    if (attention % get_num_params() /= 80) then
+      ok = .false.
+      write(stderr, '(a)') 'incorrect number of parameters.. failed'
+    end if
+
+    expected_parameters(1: 64) = 0.100000001
+    expected_parameters(65: 80) = 0.109999999
+    parameters = attention % get_params()
+    if (.not. all(parameters.eq.expected_parameters)) then
+      ok = .false.
+      write(stderr, '(a)') 'incorrect parameters.. failed'
+    end if
+
+    optim = SGD(learning_rate=0.01)
+    call optim % minimize(parameters, attention % get_gradients())
+    call attention % set_params(parameters)
+
+    call attention % forward(&
+        reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4]),&
+        reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4]),&
+        reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4])&
+    )
+
+    updated_output = reshape(attention % output, [12])
+    if (.not. all(updated_output.eq.expected_updated_output)) then
+      ok = .false.
+      write(stderr, '(a)') 'incorrect output after parameters update.. failed'
+    end if
+  end subroutine test_multihead_attention_update_gradients
 end program test_multihead_attention_layer
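For context, the new test exercises one plain SGD step (optim % minimize) and then re-runs the forward pass against hard-coded expected values. Vanilla SGD updates each parameter against its gradient, w = w - learning_rate * dw. Below is a minimal, self-contained Fortran sketch of that update rule; the program name and the numeric values are illustrative only and are not part of the library's API.

program sgd_step_sketch
  ! Sketch of the update rule the test relies on: a single vanilla
  ! SGD step, params = params - learning_rate * gradients.
  ! Values are made up for illustration; not the library API.
  implicit none
  real :: params(4), grads(4)
  real, parameter :: learning_rate = 0.01

  params = [0.1, 0.1, 0.11, 0.11]   ! initial weights, mirroring the test's expected_parameters
  grads  = [0.5, -0.2, 0.3, 0.0]    ! illustrative gradients

  ! Step each parameter against its gradient.
  params = params - learning_rate * grads

  print '(4f10.6)', params
end program sgd_step_sketch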
