
Commit e9bfbd6

Added Momentum and Nesterov modifications (#148)

* Added Momentum and Nesterov modifications
* Resolved dummy argument error
* Changes in update formulas
* Corrected formulae, velocity allocation changes
* Added concrete implementation of RMSProp
* Report RMSE every 10% of num_epochs; Fix xtest calculation
* Initialize networks with same weights; larger batch; larger test array
* Start putting RMS and velocity structures in place; yet to be allocated and initialized
* WIP: SGD and RMSprop optimizers plumbing at the network % update level
* Added get_gradients() method (draft)
* Clean up formatting and docstrings
* Flush gradients to zero; code compiles but segfaults
* Set default value for batch_size; tests pass in debug mode but segfault in optimized mode
* Update learning rates in simple and sine examples because the default changed
* Added draft test suite for optimizers
* Store the optimizer as a member of the network type
* Don't print to stdout; indentation
* Added convergence tests
* Resolved comments
* Clean up
* Import RMSProp
* Remove old code
* Add optimizer support notes

Co-authored-by: milancurcic <[email protected]>

1 parent 6bbc28d · commit e9bfbd6

16 files changed, +470 -151 lines
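The examples touched by this commit exercise the new optimizer-aware network % update API. As a quick orientation, here is a minimal sketch that only reuses calls and keyword arguments appearing in the diffs below; the sketch itself is not part of the commit:

program optimizer_api_sketch
  ! Minimal sketch of the updated API: construct an optimizer value and pass
  ! it to network % update. The sgd keyword arguments shown here
  ! (learning_rate, momentum, nesterov) all appear in example/quadratic.f90
  ! in this diff.
  use nf, only: dense, input, network, sgd
  implicit none
  type(network) :: net

  net = network([input(1), dense(3), dense(1)])

  ! One forward/backward pass on a single sample, then three update variants
  ! (in practice you would pick one of them)
  call net % forward([0.5])
  call net % backward([0.75])

  ! Classic SGD
  call net % update(optimizer=sgd(learning_rate=0.01))

  ! SGD with classic momentum
  call net % update(optimizer=sgd(learning_rate=0.01, momentum=0.9))

  ! SGD with Nesterov momentum
  call net % update(optimizer=sgd(learning_rate=0.01, momentum=0.9, nesterov=.true.))

end program optimizer_api_sketch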

README.md (+7 -3)

@@ -17,22 +17,26 @@ Read the paper [here](https://arxiv.org/abs/1902.06714).
 
 * Training and inference of dense (fully connected) and convolutional neural
   networks
+* Stochastic gradient descent optimizers: Classic, momentum, Nesterov momentum,
+  and RMSProp
+* More than a dozen activation functions and their derivatives
 * Loading dense and convolutional models from Keras HDF5 (.h5) files
-* Stochastic and mini-batch gradient descent for back-propagation
 * Data-based parallelism
-* Several activation functions and their derivatives
 
 ### Available layers
 
 | Layer type | Constructor name | Supported input layers | Rank of output array | Forward pass | Backward pass |
 |------------|------------------|------------------------|----------------------|--------------|---------------|
 | Input | `input` | n/a | 1, 3 | n/a | n/a |
 | Dense (fully-connected) | `dense` | `input1d`, `flatten` | 1 |||
-| Convolutional (2-d) | `conv2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 || |
+| Convolutional (2-d) | `conv2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 || |
 | Max-pooling (2-d) | `maxpool2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 |||
 | Flatten | `flatten` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 1 |||
 | Reshape (1-d to 3-d) | `reshape` | `input1d`, `dense`, `flatten` | 3 |||
 
+**Note:** The training of convolutional layers has been discovered to be broken
+as of release 0.13.0. This will be fixed in a future (hopefully next) release.
+
 ## Getting started
 
 Get the code:
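For reference, the four optimizer variants named in the new README bullet follow the standard textbook update rules below; the exact formulation used in nf_optimizers is not shown in this diff, so treat these as the generic forms. Here $w$ are the weights, $g$ the gradient, $\eta$ the learning rate, $\beta$ the momentum, $\rho$ the decay rate, and $\epsilon$ a small constant (the quadratic example uses 1e-8):

$$
\begin{aligned}
\text{Classic SGD:} \quad & w \leftarrow w - \eta g \\
\text{Momentum:} \quad & v \leftarrow \beta v - \eta g, \qquad w \leftarrow w + v \\
\text{Nesterov momentum:} \quad & v \leftarrow \beta v - \eta g, \qquad w \leftarrow w + \beta v - \eta g \\
\text{RMSProp:} \quad & s \leftarrow \rho s + (1 - \rho)\, g^{2}, \qquad w \leftarrow w - \frac{\eta}{\sqrt{s} + \epsilon}\, g
\end{aligned}
$$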

example/quadratic.f90 (+89 -37)

@@ -7,28 +7,27 @@ program quadratic_fit
   use nf_optimizers, only: sgd
 
   implicit none
-  type(network) :: net_sgd, net_batch_sgd, net_minibatch_sgd, net_rms_prop
+  type(network) :: net(6)
 
   ! Training parameters
   integer, parameter :: num_epochs = 1000
   integer, parameter :: train_size = 1000
-  integer, parameter :: test_size = 30
-  integer, parameter :: batch_size = 10
+  integer, parameter :: test_size = 100
+  integer, parameter :: batch_size = 100
   real, parameter :: learning_rate = 0.01
   real, parameter :: decay_rate = 0.9
 
   ! Input and output data
   real, allocatable :: x(:), y(:) ! training data
   real, allocatable :: xtest(:), ytest(:) ! testing data
-  real, allocatable :: ypred_sgd(:), ypred_batch_sgd(:), ypred_minibatch_sgd(:), ypred_rms_prop(:)
 
   integer :: i, n
 
   print '("Fitting quadratic function")'
   print '(60("="))'
 
   allocate(xtest(test_size), ytest(test_size))
-  xtest = [((i - 1) * 2 / test_size, i = 1, test_size)]
+  xtest = [(real(i - 1) * 2 / test_size, i = 1, test_size)]
   ytest = quadratic(xtest)
 
   ! x and y as 1-D arrays
@@ -41,38 +40,30 @@ program quadratic_fit
   end do
   y = quadratic(x)
 
-  ! Instantiate a separate network for each optimization method.
-  net_sgd = network([input(1), dense(3), dense(1)])
-  net_batch_sgd = network([input(1), dense(3), dense(1)])
-  net_minibatch_sgd = network([input(1), dense(3), dense(1)])
-  net_rms_prop = network([input(1), dense(3), dense(1)])
+  ! Instantiate a network and copy an instance to the rest of the array
+  net(1) = network([input(1), dense(3), dense(1)])
+  net(2:) = net(1)
 
   ! Print network info to stdout; this will be the same for all three networks.
-  call net_sgd % print_info()
+  call net(1) % print_info()
 
-  ! SGD optimizer
-  call sgd_optimizer(net_sgd, x, y, learning_rate, num_epochs)
+  ! SGD, no momentum
+  call sgd_optimizer(net(1), x, y, xtest, ytest, learning_rate, num_epochs)
+
+  ! SGD, momentum
+  call sgd_optimizer(net(2), x, y, xtest, ytest, learning_rate, num_epochs, momentum=0.9)
+
+  ! SGD, momentum with Nesterov
+  call sgd_optimizer(net(3), x, y, xtest, ytest, learning_rate, num_epochs, momentum=0.9, nesterov=.true.)
 
   ! Batch SGD optimizer
-  call batch_gd_optimizer(net_batch_sgd, x, y, learning_rate, num_epochs)
+  call batch_gd_optimizer(net(4), x, y, xtest, ytest, learning_rate, num_epochs)
 
   ! Mini-batch SGD optimizer
-  call minibatch_gd_optimizer(net_minibatch_sgd, x, y, learning_rate, num_epochs, batch_size)
+  call minibatch_gd_optimizer(net(5), x, y, xtest, ytest, learning_rate, num_epochs, batch_size)
 
   ! RMSProp optimizer
-  call rmsprop_optimizer(net_rms_prop, x, y, learning_rate, num_epochs, decay_rate)
-
-  ! Calculate predictions on the test set
-  ypred_sgd = [(net_sgd % predict([xtest(i)]), i = 1, test_size)]
-  ypred_batch_sgd = [(net_batch_sgd % predict([xtest(i)]), i = 1, test_size)]
-  ypred_minibatch_sgd = [(net_minibatch_sgd % predict([xtest(i)]), i = 1, test_size)]
-  ypred_rms_prop = [(net_rms_prop % predict([xtest(i)]), i = 1, test_size)]
-
-  ! Print the mean squared error
-  print '("Stochastic gradient descent MSE:", f9.6)', sum((ypred_sgd - ytest)**2) / size(ytest)
-  print '(" Batch gradient descent MSE: ", f9.6)', sum((ypred_batch_sgd - ytest)**2) / size(ytest)
-  print '(" Minibatch gradient descent MSE: ", f9.6)', sum((ypred_minibatch_sgd - ytest)**2) / size(ytest)
-  print '(" RMSProp MSE: ", f9.6)', sum((ypred_rms_prop - ytest)**2) / size(ytest)
+  call rmsprop_optimizer(net(6), x, y, xtest, ytest, learning_rate, num_epochs, decay_rate)
 
 contains
 
@@ -82,65 +73,107 @@ real elemental function quadratic(x) result(y)
     y = (x**2 / 2 + x / 2 + 1) / 2
   end function quadratic
 
-  subroutine sgd_optimizer(net, x, y, learning_rate, num_epochs)
+  subroutine sgd_optimizer(net, x, y, xtest, ytest, learning_rate, num_epochs, momentum, nesterov)
     ! In the stochastic gradient descent (SGD) optimizer, we run the forward
     ! and backward passes and update the weights for each training sample,
     ! one at a time.
     type(network), intent(inout) :: net
     real, intent(in) :: x(:), y(:)
+    real, intent(in) :: xtest(:), ytest(:)
     real, intent(in) :: learning_rate
     integer, intent(in) :: num_epochs
+    real, intent(in), optional :: momentum
+    logical, intent(in), optional :: nesterov
+    real, allocatable :: ypred(:)
+    real :: momentum_value
+    logical :: nesterov_value
     integer :: i, n
 
-    print *, "Running SGD optimizer..."
+    print '(a)', 'Stochastic gradient descent'
+    print '(34("-"))'
+
+    ! Set default values for momentum and nesterov
+    if (.not. present(momentum)) then
+      momentum_value = 0.0
+    else
+      momentum_value = momentum
+    end if
+
+    if (.not. present(nesterov)) then
+      nesterov_value = .false.
+    else
+      nesterov_value = nesterov
+    end if
 
     do n = 1, num_epochs
       do i = 1, size(x)
         call net % forward([x(i)])
         call net % backward([y(i)])
-        call net % update(sgd(learning_rate=learning_rate))
+        call net % update(sgd(learning_rate=learning_rate, momentum=momentum_value, nesterov=nesterov_value))
       end do
+
+      if (mod(n, num_epochs / 10) == 0) then
+        ypred = [(net % predict([xtest(i)]), i = 1, size(xtest))]
+        print '("Epoch: ", i4,"/",i4,", RMSE = ", f9.6)', n, num_epochs, sum((ypred - ytest)**2) / size(ytest)
+      end if
+
     end do
 
+    print *, ''
+
   end subroutine sgd_optimizer
 
-  subroutine batch_gd_optimizer(net, x, y, learning_rate, num_epochs)
+  subroutine batch_gd_optimizer(net, x, y, xtest, ytest, learning_rate, num_epochs)
     ! Like the stochastic gradient descent (SGD) optimizer, except that here we
     ! accumulate the weight gradients for all training samples and update the
     ! weights once per epoch.
     type(network), intent(inout) :: net
     real, intent(in) :: x(:), y(:)
+    real, intent(in) :: xtest(:), ytest(:)
    real, intent(in) :: learning_rate
     integer, intent(in) :: num_epochs
+    real, allocatable :: ypred(:)
     integer :: i, n
 
-    print *, "Running batch GD optimizer..."
+    print '(a)', 'Batch gradient descent'
+    print '(34("-"))'
 
     do n = 1, num_epochs
       do i = 1, size(x)
         call net % forward([x(i)])
         call net % backward([y(i)])
       end do
       call net % update(sgd(learning_rate=learning_rate / size(x)))
+
+      if (mod(n, num_epochs / 10) == 0) then
+        ypred = [(net % predict([xtest(i)]), i = 1, size(xtest))]
+        print '("Epoch: ", i4,"/",i4,", RMSE = ", f9.6)', n, num_epochs, sum((ypred - ytest)**2) / size(ytest)
+      end if
+
     end do
 
+    print *, ''
+
   end subroutine batch_gd_optimizer
 
-  subroutine minibatch_gd_optimizer(net, x, y, learning_rate, num_epochs, batch_size)
+  subroutine minibatch_gd_optimizer(net, x, y, xtest, ytest, learning_rate, num_epochs, batch_size)
     ! Like the batch SGD optimizer, except that here we accumulate the weight
     ! over a number of mini batches and update the weights once per mini batch.
     !
     ! Note: -O3 on GFortran must be accompanied with -fno-frontend-optimize for
     ! this subroutine to converge to a solution.
     type(network), intent(inout) :: net
     real, intent(in) :: x(:), y(:)
+    real, intent(in) :: xtest(:), ytest(:)
     real, intent(in) :: learning_rate
     integer, intent(in) :: num_epochs, batch_size
     integer :: i, j, n, num_samples, num_batches, start_index, end_index
     real, allocatable :: batch_x(:), batch_y(:)
     integer, allocatable :: batch_indices(:)
+    real, allocatable :: ypred(:)
 
-    print *, "Running mini-batch GD optimizer..."
+    print '(a)', 'Minibatch gradient descent'
+    print '(34("-"))'
 
     num_samples = size(x)
     num_batches = num_samples / batch_size
@@ -167,17 +200,28 @@ subroutine minibatch_gd_optimizer(net, x, y, learning_rate, num_epochs, batch_si
 
         call net % update(sgd(learning_rate=learning_rate / batch_size))
       end do
+
+      if (mod(n, num_epochs / 10) == 0) then
+        ypred = [(net % predict([xtest(i)]), i = 1, size(xtest))]
+        print '("Epoch: ", i4,"/",i4,", RMSE = ", f9.6)', n, num_epochs, sum((ypred - ytest)**2) / size(ytest)
+      end if
+
     end do
+
+    print *, ''
+
   end subroutine minibatch_gd_optimizer
 
-  subroutine rmsprop_optimizer(net, x, y, learning_rate, num_epochs, decay_rate)
+  subroutine rmsprop_optimizer(net, x, y, xtest, ytest, learning_rate, num_epochs, decay_rate)
     ! RMSprop optimizer for updating weights using root mean square
     type(network), intent(inout) :: net
     real, intent(in) :: x(:), y(:)
+    real, intent(in) :: xtest(:), ytest(:)
     real, intent(in) :: learning_rate, decay_rate
     integer, intent(in) :: num_epochs
     integer :: i, j, n
     real, parameter :: epsilon = 1e-8 ! Small constant to avoid division by zero
+    real, allocatable :: ypred(:)
 
     ! Define a dedicated type to store the RMSprop gradients.
     ! This is needed because array sizes vary between layers and we need to
@@ -191,7 +235,8 @@ subroutine rmsprop_optimizer(net, x, y, learning_rate, num_epochs, decay_rate)
 
     type(rms_gradient_dense), allocatable :: rms(:)
 
-    print *, "Running RMSprop optimizer..."
+    print '(a)', 'RMSProp optimizer'
+    print '(34("-"))'
 
     ! Here we allocate the array or RMS gradient derived types.
     ! We need one for each dense layer, however we will allocate it to the
@@ -237,8 +282,15 @@ subroutine rmsprop_optimizer(net, x, y, learning_rate, num_epochs, decay_rate)
         end select
       end do
 
+      if (mod(n, num_epochs / 10) == 0) then
+        ypred = [(net % predict([xtest(i)]), i = 1, size(xtest))]
+        print '("Epoch: ", i4,"/",i4,", RMSE = ", f9.6)', n, num_epochs, sum((ypred - ytest)**2) / size(ytest)
+      end if
+
     end do
 
+    print *, ''
+
   end subroutine rmsprop_optimizer
 
   subroutine shuffle(arr)

example/simple.f90 (+2 -2)

@@ -1,5 +1,5 @@
 program simple
-  use nf, only: dense, input, network
+  use nf, only: dense, input, network, sgd
   implicit none
   type(network) :: net
   real, allocatable :: x(:), y(:)
@@ -24,7 +24,7 @@ program simple
 
     call net % forward(x)
     call net % backward(y)
-    call net % update()
+    call net % update(optimizer=sgd(learning_rate=1.))
 
     if (mod(n, 50) == 0) &
       print '(i4,2(3x,f8.6))', n, net % predict(x)

example/sine.f90 (+2 -2)

@@ -1,5 +1,5 @@
 program sine
-  use nf, only: dense, input, network
+  use nf, only: dense, input, network, sgd
   implicit none
   type(network) :: net
   real :: x(1), y(1)
@@ -31,7 +31,7 @@ program sine
 
     call net % forward(x)
     call net % backward(y)
-    call net % update()
+    call net % update(optimizer=sgd(learning_rate=1.))
 
     if (mod(n, 10000) == 0) then
       ypred = [(net % predict([xtest(i)]), i = 1, test_size)]

src/nf.f90 (+1 -1)

@@ -5,7 +5,7 @@ module nf
   use nf_layer_constructors, only: &
     conv2d, dense, flatten, input, maxpool2d, reshape
   use nf_network, only: network
-  use nf_optimizers, only: sgd
+  use nf_optimizers, only: sgd, rmsprop
   use nf_activation, only: activation_function, elu, exponential, &
                            gaussian, linear, relu, leaky_relu, &
                            sigmoid, softmax, softplus, step, tanhf, &
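With rmsprop now exported from the top-level nf module, it can presumably be passed to network % update just like sgd. A hypothetical sketch; the rmsprop constructor arguments used here (learning_rate, decay_rate) are an assumption, since the type's definition in nf_optimizers is not part of this diff:

program rmsprop_sketch
  ! Hypothetical usage: assumes rmsprop(learning_rate, decay_rate) is the
  ! constructor exported by nf_optimizers; its definition is not shown in
  ! this commit.
  use nf, only: dense, input, network, rmsprop
  implicit none
  type(network) :: net

  net = network([input(1), dense(3), dense(1)])
  call net % forward([0.5])
  call net % backward([0.75])
  call net % update(optimizer=rmsprop(learning_rate=0.01, decay_rate=0.9))

end program rmsprop_sketch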

src/nf/nf_conv2d_layer.f90 (+13 -2)

@@ -30,11 +30,12 @@ module nf_conv2d_layer
 
   contains
 
-    procedure :: init
     procedure :: forward
     procedure :: backward
+    procedure :: get_gradients
     procedure :: get_num_params
     procedure :: get_params
+    procedure :: init
     procedure :: set_params
 
   end type conv2d_layer
@@ -89,13 +90,23 @@ pure module function get_num_params(self) result(num_params)
     end function get_num_params
 
     pure module function get_params(self) result(params)
-      !! Get the parameters of the layer.
+      !! Return the parameters (weights and biases) of this layer.
+      !! The parameters are ordered as weights first, biases second.
       class(conv2d_layer), intent(in) :: self
         !! A `conv2d_layer` instance
       real, allocatable :: params(:)
         !! Parameters to get
     end function get_params
 
+    pure module function get_gradients(self) result(gradients)
+      !! Return the gradients of this layer.
+      !! The gradients are ordered as weights first, biases second.
+      class(conv2d_layer), intent(in) :: self
+        !! A `conv2d_layer` instance
+      real, allocatable :: gradients(:)
+        !! Gradients to get
+    end function get_gradients
+
     module subroutine set_params(self, params)
       !! Set the parameters of the layer.
       class(conv2d_layer), intent(in out) :: self

src/nf/nf_conv2d_layer_submodule.f90 (+12)

@@ -202,6 +202,18 @@ pure module function get_params(self) result(params)
   end function get_params
 
 
+  pure module function get_gradients(self) result(gradients)
+    class(conv2d_layer), intent(in) :: self
+    real, allocatable :: gradients(:)
+
+    gradients = [ &
+      pack(self % dw, .true.), &
+      pack(self % db, .true.) &
+    ]
+
+  end function get_gradients
+
+
   module subroutine set_params(self, params)
     class(conv2d_layer), intent(in out) :: self
     real, intent(in) :: params(:)
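The commit message mentions plumbing SGD and RMSProp at the network % update level, with get_gradients() as the new accessor that makes this possible. A rough sketch of how the three layer accessors could compose into a plain SGD step; this is illustrative only and does not reproduce the actual network % update implementation, which is not shown in this excerpt:

subroutine manual_sgd_step(layer, learning_rate)
  ! Illustrative only: a hand-rolled SGD step built from the layer accessors
  ! documented in this commit (get_params, get_gradients, set_params).
  use nf_conv2d_layer, only: conv2d_layer
  implicit none
  type(conv2d_layer), intent(in out) :: layer
  real, intent(in) :: learning_rate
  real, allocatable :: params(:), gradients(:)

  params = layer % get_params()        ! weights first, then biases
  gradients = layer % get_gradients()  ! dw first, then db, same ordering
  call layer % set_params(params - learning_rate * gradients)

end subroutine manual_sgd_step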
