6
6
use nf_input1d_layer, only: input1d_layer
7
7
use nf_input3d_layer, only: input3d_layer
8
8
use nf_maxpool2d_layer, only: maxpool2d_layer
9
+ use nf_reshape_layer, only: reshape3d_layer
9
10
10
11
contains
11
12
12
- pure module subroutine backward (self, previous, gradient)
13
+ pure module subroutine backward_1d (self, previous, gradient)
13
14
implicit none
14
15
class(layer), intent (in out ) :: self
15
16
class(layer), intent (in ) :: previous
16
17
real , intent (in ) :: gradient(:)
17
18
18
- ! Backward pass currently implemented only for dense and flatten layers
19
+ ! Backward pass from a 1-d layer downstream currently implemented
20
+ ! only for dense and flatten layers
19
21
select type (this_layer = > self % p)
20
22
21
23
type is (dense_layer)
@@ -32,7 +34,7 @@ pure module subroutine backward(self, previous, gradient)
32
34
33
35
type is (flatten_layer)
34
36
35
- ! Downstream layers permitted: input3d, conv2d, maxpool2d
37
+ ! Upstream layers permitted: input3d, conv2d, maxpool2d
36
38
select type (prev_layer = > previous % p)
37
39
type is (input3d_layer)
38
40
call this_layer % backward(prev_layer % output, gradient)
@@ -44,7 +46,34 @@ pure module subroutine backward(self, previous, gradient)
44
46
45
47
end select
46
48
47
- end subroutine backward
49
+ end subroutine backward_1d
50
+
51
+
52
+ pure module subroutine backward_3d(self, previous, gradient)
53
+ implicit none
54
+ class(layer), intent (in out ) :: self
55
+ class(layer), intent (in ) :: previous
56
+ real , intent (in ) :: gradient(:,:,:)
57
+
58
+ ! Backward pass from a 3-d layer downstream currently implemented
59
+ ! only for reshape3d layer
60
+ select type (this_layer = > self % p)
61
+
62
+ type is (reshape3d_layer)
63
+
64
+ ! Upstream layers permitted: input1d, dense, flatten
65
+ select type (prev_layer = > previous % p)
66
+ type is (input1d_layer)
67
+ call this_layer % backward(prev_layer % output, gradient)
68
+ type is (dense_layer)
69
+ call this_layer % backward(prev_layer % output, gradient)
70
+ type is (flatten_layer)
71
+ call this_layer % backward(prev_layer % output, gradient)
72
+ end select
73
+
74
+ end select
75
+
76
+ end subroutine backward_3d
48
77
49
78
50
79
pure module subroutine forward(self, input)
@@ -68,38 +97,56 @@ pure module subroutine forward(self, input)
68
97
69
98
type is (conv2d_layer)
70
99
71
- ! Upstream layers permitted: input3d, conv2d, maxpool2d
100
+ ! Upstream layers permitted: input3d, conv2d, maxpool2d, reshape3d
72
101
select type (prev_layer = > input % p)
73
102
type is (input3d_layer)
74
103
call this_layer % forward(prev_layer % output)
75
104
type is (conv2d_layer)
76
105
call this_layer % forward(prev_layer % output)
77
106
type is (maxpool2d_layer)
78
107
call this_layer % forward(prev_layer % output)
108
+ type is (reshape3d_layer)
109
+ call this_layer % forward(prev_layer % output)
79
110
end select
80
111
81
112
type is (maxpool2d_layer)
82
113
83
- ! Upstream layers permitted: input3d, conv2d, maxpool2d
114
+ ! Upstream layers permitted: input3d, conv2d, maxpool2d, reshape3d
84
115
select type (prev_layer = > input % p)
85
116
type is (input3d_layer)
86
117
call this_layer % forward(prev_layer % output)
87
118
type is (conv2d_layer)
88
119
call this_layer % forward(prev_layer % output)
89
120
type is (maxpool2d_layer)
90
121
call this_layer % forward(prev_layer % output)
122
+ type is (reshape3d_layer)
123
+ call this_layer % forward(prev_layer % output)
91
124
end select
92
125
93
126
type is (flatten_layer)
94
127
95
- ! Upstream layers permitted: input3d, conv2d, maxpool2d
128
+ ! Upstream layers permitted: input3d, conv2d, maxpool2d, reshape3d
96
129
select type (prev_layer = > input % p)
97
130
type is (input3d_layer)
98
131
call this_layer % forward(prev_layer % output)
99
132
type is (conv2d_layer)
100
133
call this_layer % forward(prev_layer % output)
101
134
type is (maxpool2d_layer)
102
135
call this_layer % forward(prev_layer % output)
136
+ type is (reshape3d_layer)
137
+ call this_layer % forward(prev_layer % output)
138
+ end select
139
+
140
+ type is (reshape3d_layer)
141
+
142
+ ! Upstream layers permitted: input1d, dense, flatten
143
+ select type (prev_layer = > input % p)
144
+ type is (input1d_layer)
145
+ call this_layer % forward(prev_layer % output)
146
+ type is (dense_layer)
147
+ call this_layer % forward(prev_layer % output)
148
+ type is (flatten_layer)
149
+ call this_layer % forward(prev_layer % output)
103
150
end select
104
151
105
152
end select
@@ -141,8 +188,10 @@ pure module subroutine get_output_3d(self, output)
141
188
allocate (output, source= this_layer % output)
142
189
type is (maxpool2d_layer)
143
190
allocate (output, source= this_layer % output)
191
+ type is (reshape3d_layer)
192
+ allocate (output, source= this_layer % output)
144
193
class default
145
- error stop ' 3-d output can only be read from an input3d, conv2d , or maxpool2d layer.'
194
+ error stop ' 3-d output can only be read from a conv2d, input3d, maxpool2d , or reshape3d layer.'
146
195
147
196
end select
148
197
0 commit comments