Skip to content

Commit c78c078

Browse files
authored
Merge pull request #75 from milancurcic/flatten-layer
Implement a flatten layer
2 parents 6fda1a5 + fa73fb7 commit c78c078

13 files changed

+439
-80
lines changed

CMakeLists.txt

+4-2
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ add_library(neural
7474
src/nf/nf_datasets_mnist_submodule.f90
7575
src/nf/nf_dense_layer.f90
7676
src/nf/nf_dense_layer_submodule.f90
77+
src/nf/nf_flatten_layer.f90
78+
src/nf/nf_flatten_layer_submodule.f90
7779
src/nf/nf_input1d_layer.f90
7880
src/nf/nf_input1d_layer_submodule.f90
7981
src/nf/nf_input3d_layer.f90
@@ -102,13 +104,13 @@ string(REGEX REPLACE "^ | $" "" LIBS "${LIBS}")
102104

103105
# tests
104106
enable_testing()
105-
foreach(execid input1d_layer input3d_layer dense_layer conv2d_layer maxpool2d_layer dense_network conv2d_network)
107+
foreach(execid input1d_layer input3d_layer dense_layer conv2d_layer maxpool2d_layer flatten_layer dense_network conv2d_network)
106108
add_executable(test_${execid} test/test_${execid}.f90)
107109
target_link_libraries(test_${execid} neural ${LIBS})
108110
add_test(test_${execid} bin/test_${execid})
109111
endforeach()
110112

111-
foreach(execid mnist simple sine)
113+
foreach(execid cnn mnist simple sine)
112114
add_executable(${execid} example/${execid}.f90)
113115
target_link_libraries(${execid} neural ${LIBS})
114116
endforeach()

README.md

+7-2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ Read the paper [here](https://arxiv.org/abs/1902.06714).
3030
| Dense (fully-connected) | `dense` | `input` (1-d) | 1 |||
3131
| Convolutional (2-d) | `conv2d` | `input` (3-d), `conv2d`, `maxpool2d` | 3 |||
3232
| Max-pooling (2-d) | `maxpool2d` | `input` (3-d), `conv2d`, `maxpool2d` | 3 |||
33+
| Flatten | `flatten` | `input` (3-d), `conv2d`, `maxpool2d` | 1 |||
3334

3435
## Getting started
3536

@@ -172,9 +173,13 @@ to run the tests.
172173
The easiest way to get a sense of how to use neural-fortran is to look at
173174
examples, in increasing level of complexity:
174175

175-
1. [simple](example/simple.f90): Approximating a simple, constant data relationship
176+
1. [simple](example/simple.f90): Approximating a simple, constant data
177+
relationship
176178
2. [sine](example/sine.f90): Approximating a sine function
177-
3. [mnist](example/mnist.f90): Hand-written digit recognition using the MNIST dataset
179+
3. [mnist](example/mnist.f90): Hand-written digit recognition using the MNIST
180+
dataset
181+
4. [cnn](example/cnn.f90): Creating and running forward a simple CNN using
182+
`input`, `conv2d`, `maxpool2d`, `flatten`, and `dense` layers.
178183

179184
The examples also show you the extent of the public API that's meant to be
180185
used in applications, i.e. anything from the `nf` module.

example/cnn.f90

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
program cnn

  !! Example: build a small convolutional neural network and run a single
  !! forward pass on random input data.
  !! The backward pass for conv2d/maxpool2d is not implemented yet, so this
  !! example only demonstrates network construction and inference.

  use nf, only: conv2d, dense, flatten, input, maxpool2d, network

  implicit none
  type(network) :: net
  real, allocatable :: x(:,:,:)

  print '("Creating a CNN and doing a forward pass")'
  print '("(backward pass not implemented yet)")'
  print '(60("="))'

  ! The output shape after each layer is noted in the trailing comments.
  net = network([ &
    input([3, 32, 32]), &
    conv2d(filters=16, kernel_size=3, activation='relu'), & ! (16, 30, 30)
    maxpool2d(pool_size=2), & ! (16, 15, 15)
    conv2d(filters=32, kernel_size=3, activation='relu'), & ! (32, 13, 13)
    maxpool2d(pool_size=2), & ! (32, 6, 6)
    flatten(), &
    dense(10) &
  ])

  ! Print a network summary to the screen
  call net % print_info()

  ! Random input in [0, 1) with the shape expected by the input layer
  allocate(x(3,32,32))
  call random_number(x)

  print *, 'Output:', net % output(x)

end program cnn

src/nf.f90

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module nf
22
!! User API: everything an application needs to reference directly
33
use nf_datasets_mnist, only: label_digits, load_mnist
44
use nf_layer, only: layer
5-
use nf_layer_constructors, only: conv2d, dense, input, maxpool2d
5+
use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d
66
use nf_network, only: network
77
use nf_optimizers, only: sgd
88
end module nf

src/nf/nf_flatten_layer.f90

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
module nf_flatten_layer

  !! This module provides the concrete flatten layer type.
  !! It is used internally by the layer type.
  !! It is not intended to be used directly by the user.

  use nf_base_layer, only: base_layer

  implicit none

  private
  public :: flatten_layer

  type, extends(base_layer) :: flatten_layer

    !! Concrete implementation of a flatten (3-d to 1-d) layer.

    integer, allocatable :: input_shape(:)
    integer :: output_size

    real, allocatable :: gradient(:,:,:)
    real, allocatable :: output(:)

  contains

    procedure :: backward
    procedure :: forward
    procedure :: init

  end type flatten_layer

  interface flatten_layer
    elemental module function flatten_layer_cons() result(res)
      !! This function returns the `flatten_layer` instance.
      type(flatten_layer) :: res
        !! `flatten_layer` instance
    end function flatten_layer_cons
  end interface flatten_layer

  interface

    pure module subroutine backward(self, input, gradient)
      !! Apply the backward pass to the flatten layer.
      !! This is a reshape operation from 1-d gradient to 3-d input.
      class(flatten_layer), intent(in out) :: self
        !! Flatten layer instance
      real, intent(in) :: input(:,:,:)
        !! Input from the previous layer
      real, intent(in) :: gradient(:)
        !! Gradient from the next layer
    end subroutine backward

    pure module subroutine forward(self, input)
      !! Propagate forward the layer.
      !! Calling this subroutine updates the values of a few data components
      !! of `flatten_layer` that are needed for the backward pass.
      class(flatten_layer), intent(in out) :: self
        !! Flatten layer instance
      real, intent(in) :: input(:,:,:)
        !! Input from the previous layer
    end subroutine forward

    module subroutine init(self, input_shape)
      !! Initialize the layer data structures.
      !!
      !! This is a deferred procedure from the `base_layer` abstract type.
      class(flatten_layer), intent(in out) :: self
        !! Flatten layer instance
      integer, intent(in) :: input_shape(:)
        !! Shape of the input layer
    end subroutine init

  end interface

end module nf_flatten_layer

src/nf/nf_flatten_layer_submodule.f90

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
submodule(nf_flatten_layer) nf_flatten_layer_submodule

  !! Implementations of the flatten layer procedures whose interfaces
  !! are declared in the parent module `nf_flatten_layer`.

  use nf_base_layer, only: base_layer

  implicit none

contains

  elemental module function flatten_layer_cons() result(res)
    !! Return a default-initialized `flatten_layer` instance.
    type(flatten_layer) :: res
  end function flatten_layer_cons


  pure module subroutine backward(self, input, gradient)
    !! Fold the incoming 1-d gradient back into the 3-d shape of the input.
    class(flatten_layer), intent(in out) :: self
    real, intent(in) :: input(:,:,:)
    real, intent(in) :: gradient(:)
    self % gradient = reshape(gradient, shape(input))
  end subroutine backward


  pure module subroutine forward(self, input)
    !! Flatten the 3-d input into the layer's 1-d output array
    !! (elements taken in array element, i.e. column-major, order).
    class(flatten_layer), intent(in out) :: self
    real, intent(in) :: input(:,:,:)
    self % output = reshape(input, [size(input)])
  end subroutine forward


  module subroutine init(self, input_shape)
    !! Store the expected input shape and zero-allocate the output
    !! and gradient buffers accordingly.
    class(flatten_layer), intent(in out) :: self
    integer, intent(in) :: input_shape(:)

    self % input_shape = input_shape
    self % output_size = product(input_shape)

    allocate(self % output(self % output_size))
    self % output = 0

    allocate(self % gradient(input_shape(1), input_shape(2), input_shape(3)))
    self % gradient = 0

  end subroutine init

end submodule nf_flatten_layer_submodule

src/nf/nf_layer_constructors.f90

+22-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ module nf_layer_constructors
77
implicit none
88

99
private
10-
public :: conv2d, dense, input, maxpool2d
10+
public :: conv2d, dense, flatten, input, maxpool2d
1111

1212
interface input
1313

@@ -84,6 +84,27 @@ pure module function dense(layer_size, activation) result(res)
8484
!! Resulting layer instance
8585
end function dense
8686

87+
pure module function flatten() result(res)
88+
!! Flatten (3-d -> 1-d) layer constructor.
89+
!!
90+
!! Use this layer to chain layers with 3-d outputs to layers with 1-d
91+
!! inputs. For example, to chain a `conv2d` or a `maxpool2d` layer
92+
!! with a `dense` layer for a CNN for classification, place a `flatten`
93+
!! layer between them.
94+
!!
95+
!! A flatten layer must not be the first layer in the network.
96+
!!
97+
!! Example:
98+
!!
99+
!! ```
100+
!! use nf, only: flatten, layer
101+
!! type(layer) :: flatten_layer
102+
!! flatten_layer = flatten()
103+
!! ```
104+
type(layer) :: res
105+
!! Resulting layer instance
106+
end function flatten
107+
87108
pure module function conv2d(filters, kernel_size, activation) result(res)
88109
!! 2-d convolutional layer constructor.
89110
!!

src/nf/nf_layer_constructors_submodule.f90

+40-33
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
use nf_layer, only: layer
44
use nf_conv2d_layer, only: conv2d_layer
55
use nf_dense_layer, only: dense_layer
6+
use nf_flatten_layer, only: flatten_layer
67
use nf_input1d_layer, only: input1d_layer
78
use nf_input3d_layer, only: input3d_layer
89
use nf_maxpool2d_layer, only: maxpool2d_layer
@@ -11,26 +12,26 @@
1112

1213
contains
1314

14-
pure module function input1d(layer_size) result(res)
15-
integer, intent(in) :: layer_size
15+
pure module function conv2d(filters, kernel_size, activation) result(res)
16+
integer, intent(in) :: filters
17+
integer, intent(in) :: kernel_size
18+
character(*), intent(in), optional :: activation
1619
type(layer) :: res
17-
res % name = 'input'
18-
res % layer_shape = [layer_size]
19-
res % input_layer_shape = [integer ::]
20-
allocate(res % p, source=input1d_layer(layer_size))
21-
res % initialized = .true.
22-
end function input1d
2320

21+
res % name = 'conv2d'
2422

25-
pure module function input3d(layer_shape) result(res)
26-
integer, intent(in) :: layer_shape(3)
27-
type(layer) :: res
28-
res % name = 'input'
29-
res % layer_shape = layer_shape
30-
res % input_layer_shape = [integer ::]
31-
allocate(res % p, source=input3d_layer(layer_shape))
32-
res % initialized = .true.
33-
end function input3d
23+
if (present(activation)) then
24+
res % activation = activation
25+
else
26+
res % activation = 'sigmoid'
27+
end if
28+
29+
allocate( &
30+
res % p, &
31+
source=conv2d_layer(filters, kernel_size, res % activation) &
32+
)
33+
34+
end function conv2d
3435

3536

3637
pure module function dense(layer_size, activation) result(res)
@@ -52,27 +53,33 @@ pure module function dense(layer_size, activation) result(res)
5253
end function dense
5354

5455

55-
pure module function conv2d(filters, kernel_size, activation) result(res)
56-
integer, intent(in) :: filters
57-
integer, intent(in) :: kernel_size
58-
character(*), intent(in), optional :: activation
56+
pure module function flatten() result(res)
5957
type(layer) :: res
58+
res % name = 'flatten'
59+
allocate(res % p, source=flatten_layer())
60+
end function flatten
6061

61-
res % name = 'conv2d'
6262

63-
if (present(activation)) then
64-
res % activation = activation
65-
else
66-
res % activation = 'sigmoid'
67-
end if
68-
69-
allocate( &
70-
res % p, &
71-
source=conv2d_layer(filters, kernel_size, res % activation) &
72-
)
63+
pure module function input1d(layer_size) result(res)
64+
integer, intent(in) :: layer_size
65+
type(layer) :: res
66+
res % name = 'input'
67+
res % layer_shape = [layer_size]
68+
res % input_layer_shape = [integer ::]
69+
allocate(res % p, source=input1d_layer(layer_size))
70+
res % initialized = .true.
71+
end function input1d
7372

74-
end function conv2d
7573

74+
pure module function input3d(layer_shape) result(res)
75+
integer, intent(in) :: layer_shape(3)
76+
type(layer) :: res
77+
res % name = 'input'
78+
res % layer_shape = layer_shape
79+
res % input_layer_shape = [integer ::]
80+
allocate(res % p, source=input3d_layer(layer_shape))
81+
res % initialized = .true.
82+
end function input3d
7683

7784
pure module function maxpool2d(pool_size, stride) result(res)
7885
integer, intent(in) :: pool_size

0 commit comments

Comments
 (0)