
Commit 9bbd70f

CNN backward pass (#99)
* Concrete backward pass for the maxpool2d layer
* Declare and allocate internal gradients; backward pass in progress
* Tidy up dw calculation
* 3-d activation functions for the conv2d layer
* Backward pass for the conv2d layer, first implementation
* Consistent notation in comments
* Make maxpool2d backward pass pure
* conv2d % update() method and integrate with the layer type backward pass
* Begin work on CNN training test
* Set the layer_shape attribute of the reshape layer
* Add a TODO comment for the reshape layer error checking
* Add an example for training a CNN on MNIST data
* Reorganize examples
* Update the README
* Clean up README
* Bump version to 0.9.0
* Wrap the backward pass for maxpool2d in the high-level layer backward method
* Add test for maxpool2d backward pass
1 parent e9af5b4 commit 9bbd70f

20 files changed: +644 −176 lines
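The central piece of this commit is the concrete backward pass for maxpool2d and its integration into the layer-type backward method. Conceptually, max-pooling passes each incoming gradient back only to the input element that produced the maximum in the forward pass; every other element in the window receives zero. The sketch below illustrates that idea only; the procedure name, argument names, channels-first layout, and the assumption that the input extents are divisible by pool_size are illustrative, not the library's actual interface.

pure subroutine maxpool2d_backward_sketch(input, gradient, pool_size, grad_input)
  ! Illustrative only: route each pooled-output gradient to the position of
  ! the maximum within its pooling window; all other positions get zero.
  ! Assumes (channels, width, height) layout and extents divisible by pool_size.
  real, intent(in) :: input(:,:,:)     ! forward-pass input to the maxpool layer
  real, intent(in) :: gradient(:,:,:)  ! gradient w.r.t. the pooled output
  integer, intent(in) :: pool_size
  real, intent(out) :: grad_input(size(input,1), size(input,2), size(input,3))
  integer :: k, i, j, loc(2)

  grad_input = 0

  do k = 1, size(input, 1)
    do j = 1, size(input, 3), pool_size
      do i = 1, size(input, 2), pool_size
        ! Position of the maximum within this pool_size x pool_size window
        loc = maxloc(input(k, i:i+pool_size-1, j:j+pool_size-1))
        grad_input(k, i + loc(1) - 1, j + loc(2) - 1) = &
          gradient(k, (i - 1) / pool_size + 1, (j - 1) / pool_size + 1)
      end do
    end do
  end do

end subroutine maxpool2d_backward_sketch

The conv2d % update() method mentioned in the commit message then consumes the accumulated weight gradients; a plain SGD step of the form kernel = kernel - learning_rate * dw, averaged over the batch, is the usual choice, though the exact update used by the library is not shown on this page.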

README.md

+23 −17
@@ -16,23 +16,22 @@ Read the paper [here](https://arxiv.org/abs/1902.06714).
 
 ## Features
 
-* Dense, fully connected neural layers
-* Convolutional and max-pooling layers (experimental, forward propagation only)
-* Flatten and reshape layers (forward and backward passes)
-* Loading dense and convolutional models from Keras h5 files
+* Training and inference of dense (fully connected) and convolutional neural
+  networks
+* Loading dense and convolutional models from Keras HDF5 (.h5) files
 * Stochastic and mini-batch gradient descent for back-propagation
 * Data-based parallelism
 * Several activation functions and their derivatives
 
-### Available layer types
+### Available layers
 
 | Layer type | Constructor name | Supported input layers | Rank of output array | Forward pass | Backward pass |
 |------------|------------------|------------------------|----------------------|--------------|---------------|
-| Input (1-d and 3-d) | `input` | n/a | 1, 3 | n/a | n/a |
-| Dense (fully-connected) | `dense` | `input1d` | 1 | ✓ | ✓ |
-| Convolutional (2-d) | `conv2d` | `input3d`, `conv2d`, `maxpool2d` | 3 | ✓ | |
-| Max-pooling (2-d) | `maxpool2d` | `input3d`, `conv2d`, `maxpool2d` | 3 | ✓ | |
-| Flatten | `flatten` | `input3d`, `conv2d`, `maxpool2d` | 1 | ✓ | ✓ |
+| Input | `input` | n/a | 1, 3 | n/a | n/a |
+| Dense (fully-connected) | `dense` | `input1d`, `flatten` | 1 | ✓ | ✓ |
+| Convolutional (2-d) | `conv2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 | ✓ | ✓ |
+| Max-pooling (2-d) | `maxpool2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 | ✓ | ✓ |
+| Flatten | `flatten` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 1 | ✓ | ✓ |
 | Reshape (1-d to 3-d) | `reshape` | `input1d`, `dense`, `flatten` | 3 | ✓ | ✓ |
 
 ## Getting started
@@ -201,10 +200,9 @@ examples, in increasing level of complexity:
 1. [simple](example/simple.f90): Approximating a simple, constant data
    relationship
 2. [sine](example/sine.f90): Approximating a sine function
-3. [mnist](example/mnist.f90): Hand-written digit recognition using the MNIST
-   dataset
-4. [cnn](example/cnn.f90): Creating and running forward a simple CNN using
-   `input`, `conv2d`, `maxpool2d`, `flatten`, and `dense` layers.
+3. [dense_mnist](example/dense_mnist.f90): Hand-written digit recognition
+   (MNIST dataset) using a dense (fully-connected) network
+4. [cnn_mnist](example/cnn_mnist.f90): Training a CNN on the MNIST dataset
 5. [dense_from_keras](example/dense_from_keras.f90): Creating a pre-trained
    dense model from a Keras HDF5 file and running the inference.
 6. [cnn_from_keras](example/cnn_from_keras.f90): Creating a pre-trained
@@ -247,10 +245,18 @@ Thanks to all open-source contributors to neural-fortran:
 [@rouson](https://github.com/rouson),
 and [@scivision](https://github.com/scivision).
 
-Development of convolutional networks in neural-fortran was funded by a
-contract from NASA Goddard Space Flight Center to the University of Miami.
+Development of convolutional networks and Keras HDF5 adapters in
+neural-fortran was funded by a contract from NASA Goddard Space Flight Center
+to the University of Miami.
 
 ## Related projects
 
 * [Fortran Keras Bridge (FKB)](https://github.com/scientific-computing/FKB)
-* [rte-rrtmgp](https://github.com/peterukk/rte-rrtmgp)
+  by Jordan Ott provides a Python bridge between old (v0.1.0) neural-fortran
+  style save files and Keras's HDF5 models. As of v0.9.0, neural-fortran
+  implements the full feature set of FKB in pure Fortran, and in addition
+  supports training and inference of convolutional networks.
+* [rte-rrtmgp-nn](https://github.com/peterukk/rte-rrtmgp-nn) by Peter Ukkonen
+  is an implementation based on the old (v0.1.0) neural-fortran that optimizes
+  the memory layout and the forward and backward passes of dense layers for
+  speed and for running on GPUs.
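The "Supported input layers" column in the table above defines which layer orderings are valid, e.g. dense now accepts a flatten input, and conv2d and maxpool2d accept a reshape input. A minimal stack that satisfies those rules, assembled from the same constructors (the hyperparameter values are illustrative and mirror the cnn_mnist example further down this page):

program layer_stack_sketch
  use nf, only: network, input, reshape, conv2d, maxpool2d, flatten, dense
  implicit none
  type(network) :: net

  net = network([ &
    input(784), &               ! 1-d input
    reshape([1, 28, 28]), &     ! 1-d -> 3-d; a valid input to conv2d per the table
    conv2d(filters=8, kernel_size=3, activation='relu'), &
    maxpool2d(pool_size=2), &
    flatten(), &                ! 3-d -> 1-d; a valid input to dense per the table
    dense(10, activation='softmax') &
  ])

  call net % print_info()

end program layer_stack_sketch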

example/CMakeLists.txt

+2 −2
@@ -1,8 +1,8 @@
 foreach(execid
-  cnn
+  cnn_mnist
   cnn_from_keras
+  dense_mnist
   dense_from_keras
-  mnist
   simple
   sine
 )

example/cnn.f90

-32
This file was deleted.

example/cnn_mnist.f90

+71
@@ -0,0 +1,71 @@
program cnn_mnist

  use nf, only: network, sgd, &
    input, conv2d, maxpool2d, flatten, dense, reshape, &
    load_mnist, label_digits

  implicit none

  type(network) :: net

  real, allocatable :: training_images(:,:), training_labels(:)
  real, allocatable :: validation_images(:,:), validation_labels(:)
  real, allocatable :: testing_images(:,:), testing_labels(:)
  real, allocatable :: input_reshaped(:,:,:,:)
  real :: acc
  logical :: ok
  integer :: n
  integer, parameter :: num_epochs = 10

  call load_mnist(training_images, training_labels, &
                  validation_images, validation_labels, &
                  testing_images, testing_labels)

  net = network([ &
    input(784), &
    reshape([1,28,28]), &
    conv2d(filters=8, kernel_size=3, activation='relu'), &
    maxpool2d(pool_size=2), &
    conv2d(filters=16, kernel_size=3, activation='relu'), &
    maxpool2d(pool_size=2), &
    flatten(), &
    dense(10, activation='softmax') &
  ])

  call net % print_info()

  epochs: do n = 1, num_epochs

    call net % train( &
      training_images, &
      label_digits(training_labels), &
      batch_size=128, &
      epochs=1, &
      optimizer=sgd(learning_rate=3.) &
    )

    if (this_image() == 1) &
      print '(a,i2,a,f5.2,a)', 'Epoch ', n, ' done, Accuracy: ', accuracy( &
        net, validation_images, label_digits(validation_labels)) * 100, ' %'

  end do epochs

  print '(a,f5.2,a)', 'Testing accuracy: ', &
    accuracy(net, testing_images, label_digits(testing_labels)) * 100, '%'

contains

  real function accuracy(net, x, y)
    type(network), intent(in out) :: net
    real, intent(in) :: x(:,:), y(:,:)
    integer :: i, good
    good = 0
    do i = 1, size(x, dim=2)
      if (all(maxloc(net % predict(x(:,i))) == maxloc(y(:,i)))) then
        good = good + 1
      end if
    end do
    accuracy = real(good) / size(x, dim=2)
  end function accuracy

end program cnn_mnist
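The train and accuracy calls above rely on label_digits to turn the scalar digit labels returned by load_mnist into one-hot vectors that match the 10-output softmax layer; the accuracy function then compares the argmax of the prediction with the argmax of the label. A sketch of that presumed one-hot behavior (illustrative, not the library's source):

pure function one_hot_digits(labels) result(res)
  ! Map digit labels 0..9 to one-hot columns, e.g. 3. -> [0,0,0,1,0,0,0,0,0,0].
  real, intent(in) :: labels(:)
  real :: res(10, size(labels))
  integer :: i
  res = 0
  do i = 1, size(labels)
    res(int(labels(i)) + 1, i) = 1
  end do
end function one_hot_digits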

example/mnist.f90 renamed to example/dense_mnist.f90

+2 −2
@@ -1,4 +1,4 @@
-program mnist
+program dense_mnist
 
 use nf, only: dense, input, network, sgd, label_digits, load_mnist
 
@@ -59,4 +59,4 @@ real function accuracy(net, x, y)
   accuracy = real(good) / size(x, dim=2)
 end function accuracy
 
-end program mnist
+end program dense_mnist

fpm.toml

+1 −1
@@ -1,5 +1,5 @@
 name = "neural-fortran"
-version = "0.8.0"
+version = "0.9.0"
 license = "MIT"
 author = "Milan Curcic"
 maintainer = "[email protected]"

src/nf/nf_activation.f90 renamed to src/nf/nf_activation_1d.f90

+2 −2
@@ -1,4 +1,4 @@
-module nf_activation
+module nf_activation_1d
 
 ! A collection of activation functions and their derivatives.
 
@@ -168,4 +168,4 @@ pure function tanh_prime(x) result(res)
   res = 1 - tanh(x)**2
 end function tanh_prime
 
-end module nf_activation
+end module nf_activation_1d

src/nf/nf_activation_3d.f90

+171
@@ -0,0 +1,171 @@
module nf_activation_3d

  ! A collection of activation functions and their derivatives.

  implicit none

  private

  public :: activation_function
  public :: elu, elu_prime
  public :: exponential
  public :: gaussian, gaussian_prime
  public :: relu, relu_prime
  public :: sigmoid, sigmoid_prime
  public :: softmax, softmax_prime
  public :: softplus, softplus_prime
  public :: step, step_prime
  public :: tanhf, tanh_prime

  interface
    pure function activation_function(x) result(res)
      real, intent(in) :: x(:,:,:)
      real :: res(size(x,1),size(x,2),size(x,3))
    end function activation_function
  end interface

contains

  pure function elu(x, alpha) result(res)
    ! Exponential Linear Unit (ELU) activation function.
    real, intent(in) :: x(:,:,:)
    real, intent(in) :: alpha
    real :: res(size(x,1),size(x,2),size(x,3))
    where (x >= 0)
      res = x
    elsewhere
      res = alpha * (exp(x) - 1)
    end where
  end function elu

  pure function elu_prime(x, alpha) result(res)
    ! First derivative of the Exponential Linear Unit (ELU)
    ! activation function.
    real, intent(in) :: x(:,:,:)
    real, intent(in) :: alpha
    real :: res(size(x,1),size(x,2),size(x,3))
    where (x >= 0)
      res = 1
    elsewhere
      res = alpha * exp(x)
    end where
  end function elu_prime

  pure function exponential(x) result(res)
    ! Exponential activation function.
    real, intent(in) :: x(:,:,:)
    real :: res(size(x,1),size(x,2),size(x,3))
    res = exp(x)
  end function exponential

  pure function gaussian(x) result(res)
    ! Gaussian activation function.
    real, intent(in) :: x(:,:,:)
    real :: res(size(x,1),size(x,2),size(x,3))
    res = exp(-x**2)
  end function gaussian

  pure function gaussian_prime(x) result(res)
    ! First derivative of the Gaussian activation function.
    real, intent(in) :: x(:,:,:)
    real :: res(size(x,1),size(x,2),size(x,3))
    res = -2 * x * gaussian(x)
  end function gaussian_prime

  pure function relu(x) result(res)
    !! Rectified Linear Unit (ReLU) activation function.
    real, intent(in) :: x(:,:,:)
    real :: res(size(x,1),size(x,2),size(x,3))
    res = max(0., x)
  end function relu

  pure function relu_prime(x) result(res)
    ! First derivative of the Rectified Linear Unit (ReLU) activation function.
    real, intent(in) :: x(:,:,:)
    real :: res(size(x,1),size(x,2),size(x,3))
    where (x > 0)
      res = 1
    elsewhere
      res = 0
    end where
  end function relu_prime

  pure function sigmoid(x) result(res)
    ! Sigmoid activation function.
    real, intent(in) :: x(:,:,:)
    real :: res(size(x,1),size(x,2),size(x,3))
    res = 1 / (1 + exp(-x))
  end function sigmoid

  pure function sigmoid_prime(x) result(res)
    ! First derivative of the sigmoid activation function.
    real, intent(in) :: x(:,:,:)
    real :: res(size(x,1),size(x,2),size(x,3))
    res = sigmoid(x) * (1 - sigmoid(x))
  end function sigmoid_prime

  pure function softmax(x) result(res)
    !! Softmax activation function
    real, intent(in) :: x(:,:,:)
    real :: res(size(x,1),size(x,2),size(x,3))
    res = exp(x - maxval(x))
    res = res / sum(res)
  end function softmax

  pure function softmax_prime(x) result(res)
    !! Derivative of the softmax activation function.
    real, intent(in) :: x(:,:,:)
    real :: res(size(x,1),size(x,2),size(x,3))
    res = softmax(x) * (1 - softmax(x))
  end function softmax_prime

  pure function softplus(x) result(res)
    ! Softplus activation function.
    real, intent(in) :: x(:,:,:)
    real :: res(size(x,1),size(x,2),size(x,3))
    res = log(exp(x) + 1)
  end function softplus

  pure function softplus_prime(x) result(res)
    ! First derivative of the softplus activation function.
    real, intent(in) :: x(:,:,:)
    real :: res(size(x,1),size(x,2),size(x,3))
    res = exp(x) / (exp(x) + 1)
  end function softplus_prime

  pure function step(x) result(res)
    ! Step activation function.
    real, intent(in) :: x(:,:,:)
    real :: res(size(x,1),size(x,2),size(x,3))
    where (x > 0)
      res = 1
    elsewhere
      res = 0
    end where
  end function step

  pure function step_prime(x) result(res)
    ! First derivative of the step activation function.
    real, intent(in) :: x(:,:,:)
    real :: res(size(x,1),size(x,2),size(x,3))
    res = 0
  end function step_prime

  pure function tanhf(x) result(res)
    ! Tangent hyperbolic activation function.
    ! Same as the intrinsic tanh, but must be
    ! defined here so that we can use procedure
    ! pointer with it.
    real, intent(in) :: x(:,:,:)
    real :: res(size(x,1),size(x,2),size(x,3))
    res = tanh(x)
  end function tanhf

  pure function tanh_prime(x) result(res)
    ! First derivative of the tanh activation function.
    real, intent(in) :: x(:,:,:)
    real :: res(size(x,1),size(x,2),size(x,3))
    res = 1 - tanh(x)**2
  end function tanh_prime

end module nf_activation_3d
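The activation_function interface above exists so that a conv2d layer can hold a procedure pointer to whichever 3-d activation was requested at construction time; tanhf is defined only because a procedure pointer cannot be associated directly with the intrinsic tanh. A small usage sketch (illustrative, not library code):

program activation_pointer_demo
  use nf_activation_3d, only: activation_function, relu, tanhf
  implicit none
  procedure(activation_function), pointer :: activation => null()
  real :: x(2,4,4), y(2,4,4)

  call random_number(x)
  x = x - 0.5

  activation => relu    ! select ReLU at run time
  y = activation(x)

  activation => tanhf   ! works because tanhf has an explicit, matching interface
  y = activation(x)

end program activation_pointer_demo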
