
Commit e82d565

Authored by jvdp1 (Vandenplas, Jeremie) and milancurcic

Proposition of API for the method network % evaluate (#182)

* Proposition of API for the method evaluate
* nf_metric -> nf_metrics for consistency with Python frameworks
* Add nf_metrics.f90 to the CMake build
* Make corr metric public
* Formatting
* Bump minor version
* Make metrics accessible via nf
* Evaluate metrics in MNIST example
* Add simple tests for metrics
* Addition of maxabs
* Update example
* Remove multi-metrics variant of net % evaluate
* Mention metrics in README

Co-authored-by: Vandenplas, Jeremie <[email protected]>
Co-authored-by: milancurcic <[email protected]>

1 parent 6dfaed0 · commit e82d565

11 files changed: +204 −12 lines
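For orientation, here is a minimal sketch of the API this commit introduces (a hypothetical toy program, not part of the diff; a short training call runs first so that the network's default quadratic loss is allocated before `evaluate` is called):

```fortran
program evaluate_demo
  !! Hypothetical toy program exercising the new net % evaluate API.
  use nf, only: dense, input, network, sgd, corr
  implicit none
  type(network) :: net
  real :: x(3, 4), y(2, 4)
  real, allocatable :: metrics(:,:)

  call random_number(x)  ! 4 samples with 3 inputs each ...
  call random_number(y)  ! ... mapped to 2 outputs each

  net = network([input(3), dense(2)])

  ! One epoch of training; this also sets the default (quadratic) loss.
  call net % train(x, y, batch_size=4, epochs=1, optimizer=sgd(learning_rate=1.))

  ! Column 1 holds the loss; column 2 the optional metric (Pearson corr here).
  metrics = net % evaluate(x, y, metric=corr())
  print '(a,f8.5)', 'mean loss: ', sum(metrics(:,1)) / size(metrics, 1)
  print '(a,f8.5)', 'mean corr: ', sum(metrics(:,2)) / size(metrics, 1)

end program evaluate_demo
```

As the diffs below show, `evaluate` returns one row per output neuron and one column per requested metric, each metric evaluated across the batch.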

CMakeLists.txt (+1)

```diff
@@ -50,6 +50,7 @@ add_library(neural-fortran
   src/nf/nf_loss_submodule.f90
   src/nf/nf_maxpool2d_layer.f90
   src/nf/nf_maxpool2d_layer_submodule.f90
+  src/nf/nf_metrics.f90
   src/nf/nf_network.f90
   src/nf/nf_network_submodule.f90
   src/nf/nf_optimizers.f90
```

README.md (+1)

```diff
@@ -20,6 +20,7 @@ Read the paper [here](https://arxiv.org/abs/1902.06714).
 * Stochastic gradient descent optimizers: Classic, momentum, Nesterov momentum,
   RMSProp, Adagrad, Adam, AdamW
 * More than a dozen activation functions and their derivatives
+* Loss functions and metrics: Quadratic, Mean Squared Error, Pearson Correlation etc.
 * Loading dense and convolutional models from Keras HDF5 (.h5) files
 * Data-based parallelism
```

example/dense_mnist.f90 (+12 −4)

```diff
@@ -1,6 +1,6 @@
 program dense_mnist
 
-  use nf, only: dense, input, network, sgd, label_digits, load_mnist
+  use nf, only: dense, input, network, sgd, label_digits, load_mnist, corr
 
   implicit none
 
@@ -38,9 +38,17 @@ program dense_mnist
       optimizer=sgd(learning_rate=3.) &
     )
 
-    if (this_image() == 1) &
-      print '(a,i2,a,f5.2,a)', 'Epoch ', n, ' done, Accuracy: ', accuracy( &
-        net, validation_images, label_digits(validation_labels)) * 100, ' %'
+    block
+      real, allocatable :: output_metrics(:,:)
+      real, allocatable :: mean_metrics(:)
+      ! 2 metrics; 1st is default loss function (quadratic), other is Pearson corr.
+      output_metrics = net % evaluate(validation_images, label_digits(validation_labels), metric=corr())
+      mean_metrics = sum(output_metrics, 1) / size(output_metrics, 1)
+      if (this_image() == 1) &
+        print '(a,i2,3(a,f6.3))', 'Epoch ', n, ' done, Accuracy: ', &
+          accuracy(net, validation_images, label_digits(validation_labels)) * 100, &
+          '%, Loss: ', mean_metrics(1), ', Pearson correlation: ', mean_metrics(2)
+    end block
 
   end do epochs
 
```

fpm.toml (+1 −1)

```diff
@@ -1,5 +1,5 @@
 name = "neural-fortran"
-version = "0.16.1"
+version = "0.17.0"
 license = "MIT"
 author = "Milan Curcic"
 maintainer = "[email protected]"
```

src/nf.f90 (+1)

```diff
@@ -5,6 +5,7 @@ module nf
   use nf_layer_constructors, only: &
     conv2d, dense, flatten, input, maxpool2d, reshape
   use nf_loss, only: mse, quadratic
+  use nf_metrics, only: corr, maxabs
   use nf_network, only: network
   use nf_optimizers, only: sgd, rmsprop, adam, adagrad
   use nf_activation, only: activation_function, elu, exponential, &
```

src/nf/nf_loss.f90 (+2 −7)

```diff
@@ -7,25 +7,20 @@ module nf_loss
   !! loss type that extends the abstract loss derived type, and that
   !! implements concrete eval and derivative methods that accept vectors.
 
+  use nf_metrics, only: metric_type
   implicit none
 
   private
   public :: loss_type
   public :: mse
   public :: quadratic
 
-  type, abstract :: loss_type
+  type, extends(metric_type), abstract :: loss_type
   contains
-    procedure(loss_interface), nopass, deferred :: eval
     procedure(loss_derivative_interface), nopass, deferred :: derivative
   end type loss_type
 
   abstract interface
-    pure function loss_interface(true, predicted) result(res)
-      real, intent(in) :: true(:)
-      real, intent(in) :: predicted(:)
-      real :: res
-    end function loss_interface
     pure function loss_derivative_interface(true, predicted) result(res)
       real, intent(in) :: true(:)
       real, intent(in) :: predicted(:)
```
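Because `loss_type` now extends `metric_type`, every concrete loss also satisfies the metric interface, so an existing loss can be passed as the optional `metric` argument. A fragment sketch (reusing the hypothetical `net`, `x`, `y` names from the demo program above; `mse` is exported by the `nf` module):

```fortran
! Any loss now doubles as a metric; column 2 of the result holds MSE.
metrics = net % evaluate(x, y, metric=mse())
```

This is exactly what the new test program below does.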

src/nf/nf_metrics.f90 (new file, +72)

```fortran
module nf_metrics

  !! This module provides a collection of metric functions.

  implicit none

  private
  public :: metric_type
  public :: corr
  public :: maxabs

  type, abstract :: metric_type
  contains
    procedure(metric_interface), nopass, deferred :: eval
  end type metric_type

  abstract interface
    pure function metric_interface(true, predicted) result(res)
      real, intent(in) :: true(:)
      real, intent(in) :: predicted(:)
      real :: res
    end function metric_interface
  end interface

  type, extends(metric_type) :: corr
    !! Pearson correlation
  contains
    procedure, nopass :: eval => corr_eval
  end type corr

  type, extends(metric_type) :: maxabs
    !! Maximum absolute difference
  contains
    procedure, nopass :: eval => maxabs_eval
  end type maxabs

contains

  pure module function corr_eval(true, predicted) result(res)
    !! Pearson correlation function:
    !!
    real, intent(in) :: true(:)
      !! True values, i.e. labels from training datasets
    real, intent(in) :: predicted(:)
      !! Values predicted by the network
    real :: res
      !! Resulting correlation value
    real :: m_true, m_pred

    m_true = sum(true) / size(true)
    m_pred = sum(predicted) / size(predicted)

    res = dot_product(true - m_true, predicted - m_pred) / &
      sqrt(sum((true - m_true)**2)*sum((predicted - m_pred)**2))

  end function corr_eval

  pure function maxabs_eval(true, predicted) result(res)
    !! Maximum absolute difference function:
    !!
    real, intent(in) :: true(:)
      !! True values, i.e. labels from training datasets
    real, intent(in) :: predicted(:)
      !! Values predicted by the network
    real :: res
      !! Resulting maximum absolute difference value

    res = maxval(abs(true - predicted))

  end function maxabs_eval

end module nf_metrics
```
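For reference, the two new metrics compute the sample Pearson correlation and the maximum absolute difference, matching the code above:

$$
r = \frac{\sum_i (t_i - \bar{t})\,(p_i - \bar{p})}
         {\sqrt{\sum_i (t_i - \bar{t})^2 \, \sum_i (p_i - \bar{p})^2}},
\qquad
d_\infty = \max_i \, \lvert t_i - p_i \rvert,
$$

where $t$ are the true values, $p$ the predictions, and $\bar{t}$, $\bar{p}$ their means.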

src/nf/nf_network.f90 (+13)

```diff
@@ -3,6 +3,7 @@ module nf_network
   !! This module provides the network type to create new models.
 
   use nf_layer, only: layer
+  use nf_metrics, only: metric_type
   use nf_loss, only: loss_type
   use nf_optimizers, only: optimizer_base_type
 
@@ -28,13 +29,15 @@ module nf_network
     procedure :: train
     procedure :: update
 
+    procedure, private :: evaluate_batch_1d
     procedure, private :: forward_1d
     procedure, private :: forward_3d
     procedure, private :: predict_1d
     procedure, private :: predict_3d
     procedure, private :: predict_batch_1d
     procedure, private :: predict_batch_3d
 
+    generic :: evaluate => evaluate_batch_1d
     generic :: forward => forward_1d, forward_3d
     generic :: predict => predict_1d, predict_3d, predict_batch_1d, predict_batch_3d
 
@@ -62,6 +65,16 @@ end function network_from_keras
 
   end interface network
 
+  interface evaluate
+    module function evaluate_batch_1d(self, input_data, output_data, metric) result(res)
+      class(network), intent(in out) :: self
+      real, intent(in) :: input_data(:,:)
+      real, intent(in) :: output_data(:,:)
+      class(metric_type), intent(in), optional :: metric
+      real, allocatable :: res(:,:)
+    end function evaluate_batch_1d
+  end interface evaluate
+
   interface forward
 
     pure module subroutine forward_1d(self, input)
```

src/nf/nf_network_submodule.f90 (+30)

```diff
@@ -337,6 +337,36 @@ pure module subroutine backward(self, output, loss)
   end subroutine backward
 
 
+  module function evaluate_batch_1d(self, input_data, output_data, metric) result(res)
+    class(network), intent(in out) :: self
+    real, intent(in) :: input_data(:,:)
+    real, intent(in) :: output_data(:,:)
+    class(metric_type), intent(in), optional :: metric
+    real, allocatable :: res(:,:)
+
+    integer :: i, n
+    real, allocatable :: output(:,:)
+
+    output = self % predict(input_data)
+
+    n = 1
+    if (present(metric)) n = n + 1
+
+    allocate(res(size(output, dim=1), n))
+
+    do concurrent (i = 1:size(output, dim=1))
+      res(i,1) = self % loss % eval(output_data(i,:), output(i,:))
+    end do
+
+    if (.not. present(metric)) return
+
+    do concurrent (i = 1:size(output, dim=1))
+      res(i,2) = metric % eval(output_data(i,:), output(i,:))
+    end do
+
+  end function evaluate_batch_1d
+
+
   pure module subroutine forward_1d(self, input)
     class(network), intent(in out) :: self
     real, intent(in) :: input(:)
```
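Since `metric_type` is public and abstract with a single deferred `eval`, downstream code can define its own metrics and pass them to `net % evaluate`. A hypothetical sketch (the `mae` type below is illustrative, not part of this commit):

```fortran
module my_metrics
  !! Hypothetical user-defined metric built on the new metric_type API.
  use nf_metrics, only: metric_type
  implicit none

  private
  public :: mae

  type, extends(metric_type) :: mae
    !! Mean absolute error
  contains
    procedure, nopass :: eval => mae_eval
  end type mae

contains

  pure function mae_eval(true, predicted) result(res)
    real, intent(in) :: true(:)
    real, intent(in) :: predicted(:)
    real :: res
    res = sum(abs(true - predicted)) / size(true)
  end function mae_eval

end module my_metrics
```

A call like `net % evaluate(x, y, metric=mae())` would then place the mean absolute error in the second column of the result.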

test/CMakeLists.txt (+1)

```diff
@@ -17,6 +17,7 @@ foreach(execid
   conv2d_network
   optimizers
   loss
+  metrics
 )
 add_executable(test_${execid} test_${execid}.f90)
 target_link_libraries(test_${execid} PRIVATE neural-fortran h5fortran::h5fortran jsonfortran::jsonfortran ${LIBS})
```

test/test_metrics.f90 (new file, +70)

```fortran
program test_metrics
  use iso_fortran_env, only: stderr => error_unit
  use nf, only: dense, input, network, sgd, mse
  implicit none
  type(network) :: net
  logical :: ok = .true.

  ! Minimal 2-layer network
  net = network([ &
    input(1), &
    dense(1) &
  ])

  training: block
    real :: x(1), y(1)
    real :: tolerance = 1e-3
    integer :: n
    integer, parameter :: num_iterations = 1000
    real :: quadratic_loss, mse_metric
    real, allocatable :: metrics(:,:)

    x = [0.1234567]
    y = [0.7654321]

    do n = 1, num_iterations
      call net % forward(x)
      call net % backward(y)
      call net % update(sgd(learning_rate=1.))
      if (all(abs(net % predict(x) - y) < tolerance)) exit
    end do

    ! Returns only one metric, based on the default loss function (quadratic).
    metrics = net % evaluate(reshape(x, [1, 1]), reshape(y, [1, 1]))
    quadratic_loss = metrics(1,1)

    if (.not. all(shape(metrics) == [1, 1])) then
      write(stderr, '(a)') 'metrics array is the correct shape (1, 1).. failed'
      ok = .false.
    end if

    ! Returns two metrics, one from the loss function and another specified by the user.
    metrics = net % evaluate(reshape(x, [1, 1]), reshape(y, [1, 1]), metric=mse())

    if (.not. all(shape(metrics) == [1, 2])) then
      write(stderr, '(a)') 'metrics array is the correct shape (1, 2).. failed'
      ok = .false.
    end if

    mse_metric = metrics(1,2)

    if (.not. all(metrics < 1e-5)) then
      write(stderr, '(a)') 'value for all metrics is expected.. failed'
      ok = .false.
    end if

    if (.not. metrics(1,1) == quadratic_loss) then
      write(stderr, '(a)') 'first metric should be the same as that of the loss function.. failed'
      ok = .false.
    end if

  end block training

  if (ok) then
    print '(a)', 'test_metrics: All tests passed.'
  else
    write(stderr, '(a)') 'test_metrics: One or more tests failed.'
    stop 1
  end if

end program test_metrics
```
