Skip to content

Commit e9af5b4

Browse files
authored
Implement the reshape layer (#97)
* Interface and constructor for the reshape3d layer * Use reshape for the constructor function and reshape3d for the internal layer implementation * Add the submodule for the concrete reshape3d_layer * Forward and backward passes for the * Add type guards for reshape layer to forward and backward subroutines * Test that the resulting shape and values of a reshape layer are correct * Bump version to 0.8.0 (unreleased) * Add reshape layer to list of features * Update CMake build for the reshape layer * Ignore submodule files * Enable reading reshape layers from Keras h5
1 parent 956c28a commit e9af5b4

17 files changed

+350
-25
lines changed

.gitignore

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
*.gz
22
*.o
33
*.mod
4+
*.smod
45
*.dat
56
*.h5
6-
build
7-
doc
7+
/build
8+
/doc

CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ add_library(neural
5858
src/nf/nf_parallel_submodule.f90
5959
src/nf/nf_random.f90
6060
src/nf/nf_random_submodule.f90
61+
src/nf/nf_reshape_layer.f90
62+
src/nf/nf_reshape_layer_submodule.f90
6163
src/nf/io/nf_io_binary.f90
6264
src/nf/io/nf_io_binary_submodule.f90
6365
src/nf/io/nf_io_hdf5.f90

README.md

+6-5
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ Read the paper [here](https://arxiv.org/abs/1902.06714).
1818

1919
* Dense, fully connected neural layers
2020
* Convolutional and max-pooling layers (experimental, forward propagation only)
21-
* Flatten layers (forward and backward pass)
21+
* Flatten and reshape layers (forward and backward passes)
2222
* Loading dense and convolutional models from Keras h5 files
2323
* Stochastic and mini-batch gradient descent for back-propagation
2424
* Data-based parallelism
@@ -29,10 +29,11 @@ Read the paper [here](https://arxiv.org/abs/1902.06714).
2929
| Layer type | Constructor name | Supported input layers | Rank of output array | Forward pass | Backward pass |
3030
|------------|------------------|------------------------|----------------------|--------------|---------------|
3131
| Input (1-d and 3-d) | `input` | n/a | 1, 3 | n/a | n/a |
32-
| Dense (fully-connected) | `dense` | `input` (1-d) | 1 |||
33-
| Convolutional (2-d) | `conv2d` | `input` (3-d), `conv2d`, `maxpool2d` | 3 |||
34-
| Max-pooling (2-d) | `maxpool2d` | `input` (3-d), `conv2d`, `maxpool2d` | 3 |||
35-
| Flatten | `flatten` | `input` (3-d), `conv2d`, `maxpool2d` | 1 |||
32+
| Dense (fully-connected) | `dense` | `input1d` | 1 |||
33+
| Convolutional (2-d) | `conv2d` | `input3d`, `conv2d`, `maxpool2d` | 3 |||
34+
| Max-pooling (2-d) | `maxpool2d` | `input3d`, `conv2d`, `maxpool2d` | 3 |||
35+
| Flatten | `flatten` | `input3d`, `conv2d`, `maxpool2d` | 1 |||
36+
| Reshape (1-d to 3-d) | `reshape` | `input1d`, `dense`, `flatten` | 3 |||
3637

3738
## Getting started
3839

fpm.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
name = "neural-fortran"
2-
version = "0.7.0"
2+
version = "0.8.0"
33
license = "MIT"
44
author = "Milan Curcic"
55
maintainer = "[email protected]"

src/nf.f90

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ module nf
22
!! User API: everything an application needs to reference directly
33
use nf_datasets_mnist, only: label_digits, load_mnist
44
use nf_layer, only: layer
5-
use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d
5+
use nf_layer_constructors, only: &
6+
conv2d, dense, flatten, input, maxpool2d, reshape
67
use nf_network, only: network
78
use nf_optimizers, only: sgd
89
end module nf

src/nf/nf_datasets.f90

+3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ module nf_datasets
1212
download_and_unpack, &
1313
keras_cnn_mnist_url, &
1414
keras_dense_mnist_url, &
15+
keras_reshape_url, &
1516
mnist_url
1617

1718
character(*), parameter :: keras_snippets_baseurl = &
@@ -22,6 +23,8 @@ module nf_datasets
2223
keras_snippets_baseurl // '/8892585/keras_cnn_mnist.tar.gz'
2324
character(*), parameter :: keras_dense_mnist_url = &
2425
keras_snippets_baseurl // '/8788739/keras_dense_mnist.tar.gz'
26+
character(*), parameter :: keras_reshape_url = &
27+
keras_snippets_baseurl // '/9667603/keras_reshape.tar.gz'
2528
character(*), parameter :: mnist_url = &
2629
neural_fortran_baseurl // '/8498876/mnist.tar.gz'
2730

src/nf/nf_keras.f90

+3
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ module nf_keras
2929
integer, allocatable :: pool_size(:)
3030
integer, allocatable :: strides(:)
3131

32+
! Reshape
33+
integer, allocatable :: target_shape(:)
34+
3235
end type keras_layer
3336

3437
interface

src/nf/nf_keras_submodule.f90

+7
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,13 @@ module function get_keras_h5_layers(filename) result(res)
8282
res(n) % pool_size = reverse(res(n) % pool_size)
8383
res(n) % strides = reverse(res(n) % strides)
8484

85+
case('Reshape')
86+
! Only read target shape
87+
call json % get(layer_config_json, &
88+
'target_shape', res(n) % target_shape, found)
89+
! Reverse to account for C -> Fortran order
90+
res(n) % target_shape = reverse(res(n) % target_shape)
91+
8592
case default
8693
error stop 'This Keras layer is not supported'
8794

src/nf/nf_layer.f90

+24-6
Original file line numberDiff line numberDiff line change
@@ -24,24 +24,25 @@ module nf_layer
2424

2525
contains
2626

27-
procedure :: backward
2827
procedure :: forward
2928
procedure :: init
3029
procedure :: print_info
3130
procedure :: update
3231

33-
! Specific output subroutines for different array ranks,
34-
! available via generic `get_output`.
32+
! Specific subroutines for different array ranks
33+
procedure, private :: backward_1d
34+
procedure, private :: backward_3d
3535
procedure, private :: get_output_1d
3636
procedure, private :: get_output_3d
3737

38+
generic :: backward => backward_1d, backward_3d
3839
generic :: get_output => get_output_1d, get_output_3d
3940

4041
end type layer
4142

42-
interface
43+
interface backward
4344

44-
pure module subroutine backward(self, previous, gradient)
45+
pure module subroutine backward_1d(self, previous, gradient)
4546
!! Apply a backward pass on the layer.
4647
!! This changes the internal state of the layer.
4748
!! This is normally called internally by the `network % backward`
@@ -52,7 +53,24 @@ pure module subroutine backward(self, previous, gradient)
5253
!! Previous layer instance
5354
real, intent(in) :: gradient(:)
5455
!! Array of gradient values from the next layer
55-
end subroutine backward
56+
end subroutine backward_1d
57+
58+
pure module subroutine backward_3d(self, previous, gradient)
59+
!! Apply a backward pass on the layer.
60+
!! This changes the internal state of the layer.
61+
!! This is normally called internally by the `network % backward`
62+
!! method.
63+
class(layer), intent(in out) :: self
64+
!! Layer instance
65+
class(layer), intent(in) :: previous
66+
!! Previous layer instance
67+
real, intent(in) :: gradient(:,:,:)
68+
!! Array of gradient values from the next layer
69+
end subroutine backward_3d
70+
71+
end interface backward
72+
73+
interface
5674

5775
pure module subroutine forward(self, input)
5876
!! Apply a forward pass on the layer.

src/nf/nf_layer_constructors.f90

+12-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ module nf_layer_constructors
77
implicit none
88

99
private
10-
public :: conv2d, dense, flatten, input, maxpool2d
10+
public :: conv2d, dense, flatten, input, maxpool2d, reshape
1111

1212
interface input
1313

@@ -154,6 +154,17 @@ pure module function maxpool2d(pool_size, stride) result(res)
154154
!! Resulting layer instance
155155
end function maxpool2d
156156

157+
pure module function reshape(output_shape) result(res)
158+
!! Rank-1 to rank-any reshape layer constructor.
159+
!! Currently implemented is only rank-3 for the output of the reshape.
160+
!!
161+
!! This layer is for connecting 1-d inputs to conv2d or similar layers.
162+
integer, intent(in) :: output_shape(:)
163+
!! Shape of the output
164+
type(layer) :: res
165+
!! Resulting layer instance
166+
end function reshape
167+
157168
end interface
158169

159170
end module nf_layer_constructors

src/nf/nf_layer_constructors_submodule.f90

+15
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
use nf_input1d_layer, only: input1d_layer
88
use nf_input3d_layer, only: input3d_layer
99
use nf_maxpool2d_layer, only: maxpool2d_layer
10+
use nf_reshape_layer, only: reshape3d_layer
1011

1112
implicit none
1213

@@ -109,4 +110,18 @@ pure module function maxpool2d(pool_size, stride) result(res)
109110

110111
end function maxpool2d
111112

113+
pure module function reshape(output_shape) result(res)
114+
integer, intent(in) :: output_shape(:)
115+
type(layer) :: res
116+
117+
res % name = 'reshape'
118+
119+
if (size(output_shape) == 3) then
120+
allocate(res % p, source=reshape3d_layer(output_shape))
121+
else
122+
error stop 'size(output_shape) of the reshape layer must == 3'
123+
end if
124+
125+
end function reshape
126+
112127
end submodule nf_layer_constructors_submodule

src/nf/nf_layer_submodule.f90

+57-8
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,18 @@
66
use nf_input1d_layer, only: input1d_layer
77
use nf_input3d_layer, only: input3d_layer
88
use nf_maxpool2d_layer, only: maxpool2d_layer
9+
use nf_reshape_layer, only: reshape3d_layer
910

1011
contains
1112

12-
pure module subroutine backward(self, previous, gradient)
13+
pure module subroutine backward_1d(self, previous, gradient)
1314
implicit none
1415
class(layer), intent(in out) :: self
1516
class(layer), intent(in) :: previous
1617
real, intent(in) :: gradient(:)
1718

18-
! Backward pass currently implemented only for dense and flatten layers
19+
! Backward pass from a 1-d layer downstream currently implemented
20+
! only for dense and flatten layers
1921
select type(this_layer => self % p)
2022

2123
type is(dense_layer)
@@ -32,7 +34,7 @@ pure module subroutine backward(self, previous, gradient)
3234

3335
type is(flatten_layer)
3436

35-
! Downstream layers permitted: input3d, conv2d, maxpool2d
37+
! Upstream layers permitted: input3d, conv2d, maxpool2d
3638
select type(prev_layer => previous % p)
3739
type is(input3d_layer)
3840
call this_layer % backward(prev_layer % output, gradient)
@@ -44,7 +46,34 @@ pure module subroutine backward(self, previous, gradient)
4446

4547
end select
4648

47-
end subroutine backward
49+
end subroutine backward_1d
50+
51+
52+
pure module subroutine backward_3d(self, previous, gradient)
53+
implicit none
54+
class(layer), intent(in out) :: self
55+
class(layer), intent(in) :: previous
56+
real, intent(in) :: gradient(:,:,:)
57+
58+
! Backward pass from a 3-d layer downstream currently implemented
59+
! only for reshape3d layer
60+
select type(this_layer => self % p)
61+
62+
type is(reshape3d_layer)
63+
64+
! Upstream layers permitted: input1d, dense, flatten
65+
select type(prev_layer => previous % p)
66+
type is(input1d_layer)
67+
call this_layer % backward(prev_layer % output, gradient)
68+
type is(dense_layer)
69+
call this_layer % backward(prev_layer % output, gradient)
70+
type is(flatten_layer)
71+
call this_layer % backward(prev_layer % output, gradient)
72+
end select
73+
74+
end select
75+
76+
end subroutine backward_3d
4877

4978

5079
pure module subroutine forward(self, input)
@@ -68,38 +97,56 @@ pure module subroutine forward(self, input)
6897

6998
type is(conv2d_layer)
7099

71-
! Upstream layers permitted: input3d, conv2d, maxpool2d
100+
! Upstream layers permitted: input3d, conv2d, maxpool2d, reshape3d
72101
select type(prev_layer => input % p)
73102
type is(input3d_layer)
74103
call this_layer % forward(prev_layer % output)
75104
type is(conv2d_layer)
76105
call this_layer % forward(prev_layer % output)
77106
type is(maxpool2d_layer)
78107
call this_layer % forward(prev_layer % output)
108+
type is(reshape3d_layer)
109+
call this_layer % forward(prev_layer % output)
79110
end select
80111

81112
type is(maxpool2d_layer)
82113

83-
! Upstream layers permitted: input3d, conv2d, maxpool2d
114+
! Upstream layers permitted: input3d, conv2d, maxpool2d, reshape3d
84115
select type(prev_layer => input % p)
85116
type is(input3d_layer)
86117
call this_layer % forward(prev_layer % output)
87118
type is(conv2d_layer)
88119
call this_layer % forward(prev_layer % output)
89120
type is(maxpool2d_layer)
90121
call this_layer % forward(prev_layer % output)
122+
type is(reshape3d_layer)
123+
call this_layer % forward(prev_layer % output)
91124
end select
92125

93126
type is(flatten_layer)
94127

95-
! Upstream layers permitted: input3d, conv2d, maxpool2d
128+
! Upstream layers permitted: input3d, conv2d, maxpool2d, reshape3d
96129
select type(prev_layer => input % p)
97130
type is(input3d_layer)
98131
call this_layer % forward(prev_layer % output)
99132
type is(conv2d_layer)
100133
call this_layer % forward(prev_layer % output)
101134
type is(maxpool2d_layer)
102135
call this_layer % forward(prev_layer % output)
136+
type is(reshape3d_layer)
137+
call this_layer % forward(prev_layer % output)
138+
end select
139+
140+
type is(reshape3d_layer)
141+
142+
! Upstream layers permitted: input1d, dense, flatten
143+
select type(prev_layer => input % p)
144+
type is(input1d_layer)
145+
call this_layer % forward(prev_layer % output)
146+
type is(dense_layer)
147+
call this_layer % forward(prev_layer % output)
148+
type is(flatten_layer)
149+
call this_layer % forward(prev_layer % output)
103150
end select
104151

105152
end select
@@ -141,8 +188,10 @@ pure module subroutine get_output_3d(self, output)
141188
allocate(output, source=this_layer % output)
142189
type is(maxpool2d_layer)
143190
allocate(output, source=this_layer % output)
191+
type is(reshape3d_layer)
192+
allocate(output, source=this_layer % output)
144193
class default
145-
error stop '3-d output can only be read from an input3d, conv2d, or maxpool2d layer.'
194+
error stop '3-d output can only be read from a conv2d, input3d, maxpool2d, or reshape3d layer.'
146195

147196
end select
148197

src/nf/nf_network_submodule.f90

+9-1
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66
use nf_input1d_layer, only: input1d_layer
77
use nf_input3d_layer, only: input3d_layer
88
use nf_maxpool2d_layer, only: maxpool2d_layer
9+
use nf_reshape_layer, only: reshape3d_layer
910
use nf_io_hdf5, only: get_hdf5_dataset
1011
use nf_keras, only: get_keras_h5_layers, keras_layer
1112
use nf_layer, only: layer
12-
use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d
13+
use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d, reshape
1314
use nf_loss, only: quadratic_derivative
1415
use nf_optimizers, only: sgd
1516
use nf_parallel, only: tile_indices
@@ -117,6 +118,9 @@ module function network_from_keras(filename) result(res)
117118
keras_layers(n) % strides(1) &
118119
)
119120

121+
case('Reshape')
122+
layers(n) = reshape(keras_layers(n) % target_shape)
123+
120124
case default
121125
error stop 'This Keras layer is not supported'
122126

@@ -165,6 +169,10 @@ module function network_from_keras(filename) result(res)
165169
! Nothing to do
166170
continue
167171

172+
type is(reshape3d_layer)
173+
! Nothing to do
174+
continue
175+
168176
class default
169177
error stop 'Internal error in network_from_keras(); ' &
170178
// 'mismatch in layer types between the Keras and ' &

0 commit comments

Comments
 (0)