Commit f7b6006

jvdp1 (Vandenplas, Jeremie) and milancurcic authored
Addition of the Loss derived type and of the MSE loss function (#175)
* Addition of the abstract DT loss_type and of the DT quadratic
* Support of the loss_type for the derivative loss function
* Addition of the MSE loss function
* Add documentation
* Test program placeholder
* Add loss test to CMake config
* Minimal test for expected values
* Bump version and copyright years

---------

Co-authored-by: Vandenplas, Jeremie <[email protected]>
Co-authored-by: milancurcic <[email protected]>
1 parent cf47114 commit f7b6006

9 files changed: +185 −21 lines

LICENSE (+1 −1)

@@ -1,6 +1,6 @@
 MIT License
 
-Copyright (c) 2018-2023 neural-fortran contributors
+Copyright (c) 2018-2024 neural-fortran contributors
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal

fpm.toml (+2 −2)

@@ -1,9 +1,9 @@
 name = "neural-fortran"
-version = "0.15.1"
+version = "0.16.0"
 license = "MIT"
 author = "Milan Curcic"
 maintainer = "[email protected]"
-copyright = "Copyright 2018-2023, neural-fortran contributors"
+copyright = "Copyright 2018-2024, neural-fortran contributors"
 
 [build]
 external-modules = "hdf5"

src/nf.f90 (+1 −0)

@@ -4,6 +4,7 @@ module nf
   use nf_layer, only: layer
   use nf_layer_constructors, only: &
     conv2d, dense, flatten, input, maxpool2d, reshape
+  use nf_loss, only: mse, quadratic
   use nf_network, only: network
   use nf_optimizers, only: sgd, rmsprop, adam, adagrad
   use nf_activation, only: activation_function, elu, exponential, &

src/nf/nf_loss.f90 (+72 −8)

@@ -1,28 +1,92 @@
 module nf_loss
 
-  !! This module will eventually provide a collection of loss functions and
-  !! their derivatives. For the time being it provides only the quadratic
-  !! function.
+  !! This module provides a collection of loss functions and their derivatives.
+  !! The implementation is based on an abstract loss derived type
+  !! which has the required eval and derivative methods.
+  !! An implementation of a new loss type thus requires writing a concrete
+  !! loss type that extends the abstract loss derived type, and that
+  !! implements concrete eval and derivative methods that accept vectors.
 
   implicit none
 
   private
-  public :: quadratic, quadratic_derivative
+  public :: loss_type
+  public :: mse
+  public :: quadratic
+
+  type, abstract :: loss_type
+  contains
+    procedure(loss_interface), nopass, deferred :: eval
+    procedure(loss_derivative_interface), nopass, deferred :: derivative
+  end type loss_type
+
+  abstract interface
+    pure function loss_interface(true, predicted) result(res)
+      real, intent(in) :: true(:)
+      real, intent(in) :: predicted(:)
+      real :: res
+    end function loss_interface
+    pure function loss_derivative_interface(true, predicted) result(res)
+      real, intent(in) :: true(:)
+      real, intent(in) :: predicted(:)
+      real :: res(size(true))
+    end function loss_derivative_interface
+  end interface
+
+  type, extends(loss_type) :: mse
+    !! Mean Square Error loss function
+  contains
+    procedure, nopass :: eval => mse_eval
+    procedure, nopass :: derivative => mse_derivative
+  end type mse
+
+  type, extends(loss_type) :: quadratic
+    !! Quadratic loss function
+  contains
+    procedure, nopass :: eval => quadratic_eval
+    procedure, nopass :: derivative => quadratic_derivative
+  end type quadratic
 
   interface
 
-    pure module function quadratic(true, predicted) result(res)
-      !! Quadratic loss function:
+    pure module function mse_eval(true, predicted) result(res)
+      !! Mean Square Error loss function:
+      !!
+      !!   L = sum((predicted - true)**2) / size(true)
+      !!
+      real, intent(in) :: true(:)
+        !! True values, i.e. labels from training datasets
+      real, intent(in) :: predicted(:)
+        !! Values predicted by the network
+      real :: res
+        !! Resulting loss value
+    end function mse_eval
+
+    pure module function mse_derivative(true, predicted) result(res)
+      !! First derivative of the Mean Square Error loss function:
       !!
-      !!   L = (predicted - true)**2 / 2
+      !!   L' = 2 * (predicted - true) / size(true)
       !!
       real, intent(in) :: true(:)
         !! True values, i.e. labels from training datasets
       real, intent(in) :: predicted(:)
         !! Values predicted by the network
       real :: res(size(true))
         !! Resulting loss values
-    end function quadratic
+    end function mse_derivative
+
+    pure module function quadratic_eval(true, predicted) result(res)
+      !! Quadratic loss function:
+      !!
+      !!   L = sum((predicted - true)**2) / 2
+      !!
+      real, intent(in) :: true(:)
+        !! True values, i.e. labels from training datasets
+      real, intent(in) :: predicted(:)
+        !! Values predicted by the network
+      real :: res
+        !! Resulting loss value
+    end function quadratic_eval
 
     pure module function quadratic_derivative(true, predicted) result(res)
       !! First derivative of the quadratic loss function:
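The docstring above describes the extension mechanism: a new loss extends the abstract type and binds concrete eval and derivative procedures. As a hedged illustration, here is a minimal sketch of a user-defined mean-absolute-error loss built on this API; the custom_loss module, the mae type, and its functions are hypothetical and not part of this commit:

module custom_loss
  ! Hypothetical user extension of the new loss API (not part of this commit)
  use nf_loss, only: loss_type
  implicit none

  type, extends(loss_type) :: mae
    !! Mean Absolute Error loss function (illustrative)
  contains
    procedure, nopass :: eval => mae_eval
    procedure, nopass :: derivative => mae_derivative
  end type mae

contains

  pure function mae_eval(true, predicted) result(res)
    real, intent(in) :: true(:)
    real, intent(in) :: predicted(:)
    real :: res
    res = sum(abs(predicted - true)) / size(true)
  end function mae_eval

  pure function mae_derivative(true, predicted) result(res)
    real, intent(in) :: true(:)
    real, intent(in) :: predicted(:)
    real :: res(size(true))
    ! Subgradient of |predicted - true|, averaged over the vector length
    res = sign(1., predicted - true) / size(true)
  end function mae_derivative

end module custom_loss

Because the deferred bindings are nopass, the concrete procedures only need to match the abstract interfaces above; no instance argument is involved.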

src/nf/nf_loss_submodule.f90 (+18 −4)

@@ -4,12 +4,12 @@
 
 contains
 
-  pure module function quadratic(true, predicted) result(res)
+  pure module function quadratic_eval(true, predicted) result(res)
     real, intent(in) :: true(:)
     real, intent(in) :: predicted(:)
-    real :: res(size(true))
-    res = (predicted - true)**2 / 2
-  end function quadratic
+    real :: res
+    res = sum((predicted - true)**2) / 2
+  end function quadratic_eval
 
   pure module function quadratic_derivative(true, predicted) result(res)
     real, intent(in) :: true(:)
@@ -18,4 +18,18 @@ pure module function quadratic_derivative(true, predicted) result(res)
     res = predicted - true
   end function quadratic_derivative
 
+  pure module function mse_eval(true, predicted) result(res)
+    real, intent(in) :: true(:)
+    real, intent(in) :: predicted(:)
+    real :: res
+    res = sum((predicted - true)**2) / size(true)
+  end function mse_eval
+
+  pure module function mse_derivative(true, predicted) result(res)
+    real, intent(in) :: true(:)
+    real, intent(in) :: predicted(:)
+    real :: res(size(true))
+    res = 2 * (predicted - true) / size(true)
+  end function mse_derivative
+
 end submodule nf_loss_submodule
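As a quick numerical sanity check on these implementations, the derivative of mse should match a central finite difference of its eval. A sketch, not part of the commit; the program name and sample values are made up:

program check_loss_gradient
  ! Hypothetical finite-difference check of mse % derivative against mse % eval
  use nf_loss, only: mse
  implicit none
  type(mse) :: loss
  real, parameter :: h = 1e-3
  real :: t(3), p(3), pp(3), pm(3), fd, d(3)
  integer :: i
  t = [0., 1., 2.]
  p = [0.5, 1.5, 3.0]
  d = loss % derivative(t, p)
  do i = 1, 3
    ! Perturb one component of the prediction in each direction
    pp = p; pm = p
    pp(i) = pp(i) + h
    pm(i) = pm(i) - h
    fd = (loss % eval(t, pp) - loss % eval(t, pm)) / (2 * h)
    ! The two columns should agree to within O(h**2)
    print '(a,i0,2(1x,f10.6))', 'component ', i, fd, d(i)
  end do
end program check_loss_gradient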

src/nf/nf_network.f90 (+8 −2)

@@ -3,6 +3,7 @@ module nf_network
   !! This module provides the network type to create new models.
 
   use nf_layer, only: layer
+  use nf_loss, only: loss_type
   use nf_optimizers, only: optimizer_base_type
 
   implicit none
@@ -13,6 +14,7 @@ module nf_network
   type :: network
 
     type(layer), allocatable :: layers(:)
+    class(loss_type), allocatable :: loss
     class(optimizer_base_type), allocatable :: optimizer
 
   contains
@@ -138,7 +140,7 @@ end function predict_batch_3d
 
   interface
 
-    pure module subroutine backward(self, output)
+    pure module subroutine backward(self, output, loss)
       !! Apply one backward pass through the network.
       !! This changes the state of layers on the network.
       !! Typically used only internally from the `train` method,
@@ -147,6 +149,8 @@ pure module subroutine backward(self, output)
         !! Network instance
       real, intent(in) :: output(:)
         !! Output data
+      class(loss_type), intent(in), optional :: loss
+        !! Loss instance to use. If not provided, the default is quadratic().
     end subroutine backward
 
     pure module integer function get_num_params(self)
@@ -185,7 +189,7 @@ module subroutine print_info(self)
     end subroutine print_info
 
     module subroutine train(self, input_data, output_data, batch_size, &
-      epochs, optimizer)
+      epochs, optimizer, loss)
       class(network), intent(in out) :: self
         !! Network instance
       real, intent(in) :: input_data(:,:)
@@ -204,6 +208,8 @@ module subroutine train(self, input_data, output_data, batch_size, &
         !! Number of epochs to run
       class(optimizer_base_type), intent(in), optional :: optimizer
         !! Optimizer instance to use. If not provided, the default is sgd().
+      class(loss_type), intent(in), optional :: loss
+        !! Loss instance to use. If not provided, the default is quadratic().
     end subroutine train
 
     module subroutine update(self, optimizer, batch_size)
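With the new optional `loss` argument, the loss function can be chosen at the call site of train(). A minimal usage sketch, assuming the public API exported by src/nf.f90 above; the toy data, layer sizes, and hyperparameters are illustrative, not from the commit:

program train_with_mse
  use nf, only: dense, input, network, sgd, mse
  implicit none
  type(network) :: net
  real :: x(3, 100), y(1, 100)
  ! Toy regression data: 100 samples of 3 features, target is their sum
  call random_number(x)
  y(1,:) = sum(x, dim=1)
  net = network([input(3), dense(8), dense(1)])
  ! Select the MSE loss instead of the default quadratic()
  call net % train(x, y, batch_size=10, epochs=5, &
    optimizer=sgd(learning_rate=0.01), loss=mse())
end program train_with_mse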

src/nf/nf_network_submodule.f90 (+29 −4)

@@ -11,7 +11,7 @@
   use nf_keras, only: get_keras_h5_layers, keras_layer
   use nf_layer, only: layer
   use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d, reshape
-  use nf_loss, only: quadratic_derivative
+  use nf_loss, only: quadratic
   use nf_optimizers, only: optimizer_base_type, sgd
   use nf_parallel, only: tile_indices
   use nf_activation, only: activation_function, &
@@ -280,11 +280,27 @@ pure function get_activation_by_name(activation_name) result(res)
 
   end function get_activation_by_name
 
-  pure module subroutine backward(self, output)
+  pure module subroutine backward(self, output, loss)
     class(network), intent(in out) :: self
     real, intent(in) :: output(:)
+    class(loss_type), intent(in), optional :: loss
     integer :: n, num_layers
 
+    ! Passing the loss instance is optional. If it is not provided, and if the
+    ! loss instance has not already been set, we default to quadratic. The
+    ! instantiation and initialization of the loss instance below is normally
+    ! done at the beginning of the network % train() method. However, if the
+    ! user wants to call network % backward() directly, for example with their
+    ! own custom mini-batching routine, we initialize the loss instance here as
+    ! well. If it's already initialized, this step is a cheap no-op.
+    if (.not. allocated(self % loss)) then
+      if (present(loss)) then
+        self % loss = loss
+      else
+        self % loss = quadratic()
+      end if
+    end if
+
     num_layers = size(self % layers)
 
     ! Iterate backward over layers, from the output layer
@@ -297,7 +313,7 @@ pure module subroutine backward(self, output)
         type is(dense_layer)
           call self % layers(n) % backward( &
             self % layers(n - 1), &
-            quadratic_derivative(output, this_layer % output) &
+            self % loss % derivative(output, this_layer % output) &
           )
       end select
     else
@@ -542,13 +558,14 @@ end subroutine set_params
 
 
   module subroutine train(self, input_data, output_data, batch_size, &
-    epochs, optimizer)
+    epochs, optimizer, loss)
     class(network), intent(in out) :: self
     real, intent(in) :: input_data(:,:)
     real, intent(in) :: output_data(:,:)
     integer, intent(in) :: batch_size
     integer, intent(in) :: epochs
     class(optimizer_base_type), intent(in), optional :: optimizer
+    class(loss_type), intent(in), optional :: loss
    class(optimizer_base_type), allocatable :: optimizer_
 
     real :: pos
@@ -567,6 +584,14 @@ module subroutine train(self, input_data, output_data, batch_size, &
 
     call self % optimizer % init(self % get_num_params())
 
+    ! Passing the loss instance is optional.
+    ! If not provided, we default to quadratic().
+    if (present(loss)) then
+      self % loss = loss
+    else
+      self % loss = quadratic()
+    end if
+
     dataset_size = size(output_data, dim=2)
 
     epoch_loop: do n = 1, epochs
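The comment in backward() above targets users who drive mini-batching themselves. A hedged sketch of such a hand-rolled loop, assuming the forward/backward/update methods of the network type and that update() accepts its arguments as optional keywords; the data and epoch/batch bookkeeping are illustrative:

program custom_loop
  ! Hypothetical hand-rolled training loop calling backward() directly
  ! (not part of this commit)
  use nf, only: dense, input, network, mse
  implicit none
  type(network) :: net
  real :: x(3, 100), y(1, 100)
  integer :: i, j
  call random_number(x)
  y(1,:) = sum(x, dim=1)
  net = network([input(3), dense(8), dense(1)])
  do i = 1, 5  ! epochs
    do j = 1, 100  ! one pass per sample
      call net % forward(x(:,j))
      ! Sets self % loss on the first call; a cheap no-op afterwards
      call net % backward(y(:,j), loss=mse())
    end do
    ! Apply accumulated gradients; falls back to the default sgd() optimizer
    call net % update(batch_size=100)
  end do
end program custom_loop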

test/CMakeLists.txt (+1 −0)

@@ -16,6 +16,7 @@ foreach(execid
   cnn_from_keras
   conv2d_network
   optimizers
+  loss
   )
   add_executable(test_${execid} test_${execid}.f90)
   target_link_libraries(test_${execid} PRIVATE neural h5fortran::h5fortran jsonfortran::jsonfortran ${LIBS})

test/test_loss.f90 (+53 −0)

@@ -0,0 +1,53 @@
+program test_loss
+
+  use iso_fortran_env, only: stderr => error_unit
+  use nf, only: mse, quadratic
+
+  implicit none
+
+  logical :: ok = .true.
+
+  block
+
+    type(mse) :: loss
+    real :: true(2) = [1., 2.]
+    real :: pred(2) = [3., 4.]
+
+    if (.not. loss % eval(true, pred) == 4) then
+      write(stderr, '(a)') 'expected output of mse % eval().. failed'
+      ok = .false.
+    end if
+
+    if (.not. all(loss % derivative(true, pred) == [2, 2])) then
+      write(stderr, '(a)') 'expected output of mse % derivative().. failed'
+      ok = .false.
+    end if
+
+  end block
+
+  block
+
+    type(quadratic) :: loss
+    real :: true(4) = [1., 2., 3., 4.]
+    real :: pred(4) = [3., 4., 5., 6.]
+
+    if (.not. loss % eval(true, pred) == 8) then
+      write(stderr, '(a)') 'expected output of quadratic % eval().. failed'
+      ok = .false.
+    end if
+
+    if (.not. all(loss % derivative(true, pred) == [2, 2, 2, 2])) then
+      write(stderr, '(a)') 'expected output of quadratic % derivative().. failed'
+      ok = .false.
+    end if
+
+  end block
+
+  if (ok) then
+    print '(a)', 'test_loss: All tests passed.'
+  else
+    write(stderr, '(a)') 'test_loss: One or more tests failed.'
+    stop 1
+  end if
+
+end program test_loss
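The expected values follow directly from the definitions above: in the mse block, eval = ((3 - 1)**2 + (4 - 2)**2) / 2 = 4 and derivative = 2 * (pred - true) / 2 = [2, 2]; in the quadratic block, eval = (2**2 + 2**2 + 2**2 + 2**2) / 2 = 8 and derivative = pred - true = [2, 2, 2, 2].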
