Skip to content

Commit 51aad71

Browse files
authored
Merge pull request #92 from milancurcic/rename-output-to-predict
Rename output -> predict for consistency with Keras
2 parents 20ec349 + 78f762f commit 51aad71

11 files changed

+43
-43
lines changed

example/cnn.f90

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,6 @@ program cnn
2727
allocate(x(3,32,32))
2828
call random_number(x)
2929

30-
print *, 'Output:', net % output(x)
30+
print *, 'Output:', net % predict(x)
3131

3232
end program cnn

example/cnn_from_keras.f90

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ real function accuracy(net, x, y)
4848
integer :: i, good
4949
good = 0
5050
do i = 1, size(x, dim=4)
51-
if (all(maxloc(net % output(x(:,:,:,i))) == maxloc(y(:,i)))) then
51+
if (all(maxloc(net % predict(x(:,:,:,i))) == maxloc(y(:,i)))) then
5252
good = good + 1
5353
end if
5454
end do

example/dense_from_keras.f90

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ real function accuracy(net, x, y)
4242
integer :: i, good
4343
good = 0
4444
do i = 1, size(x, dim=2)
45-
if (all(maxloc(net % output(x(:,i))) == maxloc(y(:,i)))) then
45+
if (all(maxloc(net % predict(x(:,i))) == maxloc(y(:,i)))) then
4646
good = good + 1
4747
end if
4848
end do

example/mnist.f90

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ real function accuracy(net, x, y)
5252
integer :: i, good
5353
good = 0
5454
do i = 1, size(x, dim=2)
55-
if (all(maxloc(net % output(x(:,i))) == maxloc(y(:,i)))) then
55+
if (all(maxloc(net % predict(x(:,i))) == maxloc(y(:,i)))) then
5656
good = good + 1
5757
end if
5858
end do

example/simple.f90

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ program simple
2727
call net % update(1.)
2828

2929
if (mod(n, 50) == 0) &
30-
print '(i4,2(3x,f8.6))', n, net % output(x)
30+
print '(i4,2(3x,f8.6))', n, net % predict(x)
3131

3232
end do
3333

example/sine.f90

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ program sine
3434
call net % update(1.)
3535

3636
if (mod(n, 10000) == 0) then
37-
ypred = [(net % output([xtest(i)]), i = 1, test_size)]
37+
ypred = [(net % predict([xtest(i)]), i = 1, test_size)]
3838
print '(i0,1x,f9.6)', n, sum((ypred - ytest)**2) / size(ypred)
3939
end if
4040

src/nf/nf_network.f90

+13-13
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@ module nf_network
2323

2424
procedure, private :: forward_1d
2525
procedure, private :: forward_3d
26-
procedure, private :: output_1d
27-
procedure, private :: output_3d
28-
procedure, private :: output_batch_1d
29-
procedure, private :: output_batch_3d
26+
procedure, private :: predict_1d
27+
procedure, private :: predict_3d
28+
procedure, private :: predict_batch_1d
29+
procedure, private :: predict_batch_3d
3030

3131
generic :: forward => forward_1d, forward_3d
32-
generic :: output => output_1d, output_3d, output_batch_1d, output_batch_3d
32+
generic :: predict => predict_1d, predict_3d, predict_batch_1d, predict_batch_3d
3333

3434
end type network
3535

@@ -89,45 +89,45 @@ end subroutine forward_3d
8989

9090
interface output
9191

92-
module function output_1d(self, input) result(res)
92+
module function predict_1d(self, input) result(res)
9393
!! Return the output of the network given the input 1-d array.
9494
class(network), intent(in out) :: self
9595
!! Network instance
9696
real, intent(in) :: input(:)
9797
!! Input data
9898
real, allocatable :: res(:)
9999
!! Output of the network
100-
end function output_1d
100+
end function predict_1d
101101

102-
module function output_3d(self, input) result(res)
102+
module function predict_3d(self, input) result(res)
103103
!! Return the output of the network given the input 3-d array.
104104
class(network), intent(in out) :: self
105105
!! Network instance
106106
real, intent(in) :: input(:,:,:)
107107
!! Input data
108108
real, allocatable :: res(:)
109109
!! Output of the network
110-
end function output_3d
110+
end function predict_3d
111111

112-
module function output_batch_1d(self, input) result(res)
112+
module function predict_batch_1d(self, input) result(res)
113113
!! Return the output of the network given an input batch of 3-d data.
114114
class(network), intent(in out) :: self
115115
!! Network instance
116116
real, intent(in) :: input(:,:)
117117
!! Input data; the last dimension is the batch
118118
real, allocatable :: res(:,:)
119119
!! Output of the network; the last dimension is the batch
120-
end function output_batch_1d
120+
end function predict_batch_1d
121121

122-
module function output_batch_3d(self, input) result(res)
122+
module function predict_batch_3d(self, input) result(res)
123123
!! Return the output of the network given an input batch of 3-d data.
124124
class(network), intent(in out) :: self
125125
!! Network instance
126126
real, intent(in) :: input(:,:,:,:)
127127
!! Input data; the last dimension is the batch
128128
real, allocatable :: res(:,:)
129129
!! Output of the network; the last dimension is the batch
130-
end function output_batch_3d
130+
end function predict_batch_3d
131131

132132
end interface output
133133

src/nf/nf_network_submodule.f90

+20-20
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ pure module subroutine forward_3d(self, input)
244244
end subroutine forward_3d
245245

246246

247-
module function output_1d(self, input) result(res)
247+
module function predict_1d(self, input) result(res)
248248
class(network), intent(in out) :: self
249249
real, intent(in) :: input(:)
250250
real, allocatable :: res(:)
@@ -263,10 +263,10 @@ module function output_1d(self, input) result(res)
263263
error stop 'network % output not implemented for this output layer'
264264
end select
265265

266-
end function output_1d
266+
end function predict_1d
267267

268268

269-
module function output_3d(self, input) result(res)
269+
module function predict_3d(self, input) result(res)
270270
class(network), intent(in out) :: self
271271
real, intent(in) :: input(:,:,:)
272272
real, allocatable :: res(:)
@@ -288,10 +288,10 @@ module function output_3d(self, input) result(res)
288288
error stop 'network % output not implemented for this output layer'
289289
end select
290290

291-
end function output_3d
291+
end function predict_3d
292292

293293

294-
module function output_batch_1d(self, input) result(res)
294+
module function predict_batch_1d(self, input) result(res)
295295
class(network), intent(in out) :: self
296296
real, intent(in) :: input(:,:)
297297
real, allocatable :: res(:,:)
@@ -318,10 +318,10 @@ module function output_batch_1d(self, input) result(res)
318318

319319
end do batch
320320

321-
end function output_batch_1d
321+
end function predict_batch_1d
322322

323323

324-
module function output_batch_3d(self, input) result(res)
324+
module function predict_batch_3d(self, input) result(res)
325325
class(network), intent(in out) :: self
326326
real, intent(in) :: input(:,:,:,:)
327327
real, allocatable :: res(:,:)
@@ -335,23 +335,23 @@ module function output_batch_3d(self, input) result(res)
335335

336336
batch: do concurrent(i = 1:batch_size)
337337

338-
call self % forward(input(:,:,:,i))
338+
call self % forward(input(:,:,:,i))
339339

340-
select type(output_layer => self % layers(num_layers) % p)
341-
type is(conv2d_layer)
342-
!FIXME flatten the result for now; find a better solution
343-
res(:,i) = pack(output_layer % output, .true.)
344-
type is(dense_layer)
345-
res(:,i) = output_layer % output
346-
type is(flatten_layer)
347-
res(:,i) = output_layer % output
348-
class default
349-
error stop 'network % output not implemented for this output layer'
350-
end select
340+
select type(output_layer => self % layers(num_layers) % p)
341+
type is(conv2d_layer)
342+
!FIXME flatten the result for now; find a better solution
343+
res(:,i) = pack(output_layer % output, .true.)
344+
type is(dense_layer)
345+
res(:,i) = output_layer % output
346+
type is(flatten_layer)
347+
res(:,i) = output_layer % output
348+
class default
349+
error stop 'network % output not implemented for this output layer'
350+
end select
351351

352352
end do batch
353353

354-
end function output_batch_3d
354+
end function predict_batch_3d
355355

356356

357357
module subroutine print_info(self)

test/test_cnn_from_keras.f90

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ real function accuracy(net, x, y)
5959
integer :: i, good
6060
good = 0
6161
do i = 1, size(x, dim=4)
62-
if (all(maxloc(net % output(x(:,:,:,i))) == maxloc(y(:,i)))) then
62+
if (all(maxloc(net % predict(x(:,:,:,i))) == maxloc(y(:,i)))) then
6363
good = good + 1
6464
end if
6565
end do

test/test_dense_network.f90

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ program test_dense_network
1616
ok = .false.
1717
end if
1818

19-
if (.not. all(net % output([0.]) == 0.5)) then
19+
if (.not. all(net % predict([0.]) == 0.5)) then
2020
write(stderr, '(a)') &
2121
'dense network should output exactly 0.5 for input 0.. failed'
2222
ok = .false.
@@ -35,7 +35,7 @@ program test_dense_network
3535
call net % forward(x)
3636
call net % backward(y)
3737
call net % update(1.)
38-
if (all(abs(net % output(x) - y) < tolerance)) exit
38+
if (all(abs(net % predict(x) - y) < tolerance)) exit
3939
end do
4040

4141
if (.not. n <= num_iterations) then

test/test_dense_network_from_keras.f90

+1-1
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ real function accuracy(net, x, y)
9090
integer :: i, good
9191
good = 0
9292
do i = 1, size(x, dim=2)
93-
if (all(maxloc(net % output(x(:,i))) == maxloc(y(:,i)))) then
93+
if (all(maxloc(net % predict(x(:,i))) == maxloc(y(:,i)))) then
9494
good = good + 1
9595
end if
9696
end do

0 commit comments

Comments
 (0)