Skip to content

Commit 20ec349

Browse files
authored
Merge pull request #90 from milancurcic/batch-inference
Batch inference
2 parents 75666fe + b710872 commit 20ec349

File tree

2 files changed

+86
-1
lines changed

2 files changed

+86
-1
lines changed

src/nf/nf_network.f90

+23-1
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,11 @@ module nf_network
2525
procedure, private :: forward_3d
2626
procedure, private :: output_1d
2727
procedure, private :: output_3d
28+
procedure, private :: output_batch_1d
29+
procedure, private :: output_batch_3d
2830

2931
generic :: forward => forward_1d, forward_3d
30-
generic :: output => output_1d, output_3d
32+
generic :: output => output_1d, output_3d, output_batch_1d, output_batch_3d
3133

3234
end type network
3335

@@ -107,6 +109,26 @@ module function output_3d(self, input) result(res)
107109
!! Output of the network
108110
end function output_3d
109111

112+
module function output_batch_1d(self, input) result(res)
113+
!! Return the output of the network given an input batch of 3-d data.
114+
class(network), intent(in out) :: self
115+
!! Network instance
116+
real, intent(in) :: input(:,:)
117+
!! Input data; the last dimension is the batch
118+
real, allocatable :: res(:,:)
119+
!! Output of the network; the last dimension is the batch
120+
end function output_batch_1d
121+
122+
module function output_batch_3d(self, input) result(res)
123+
!! Return the output of the network given an input batch of 3-d data.
124+
class(network), intent(in out) :: self
125+
!! Network instance
126+
real, intent(in) :: input(:,:,:,:)
127+
!! Input data; the last dimension is the batch
128+
real, allocatable :: res(:,:)
129+
!! Output of the network; the last dimension is the batch
130+
end function output_batch_3d
131+
110132
end interface output
111133

112134
interface

src/nf/nf_network_submodule.f90

+63
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,69 @@ module function output_3d(self, input) result(res)
291291
end function output_3d
292292

293293

294+
module function output_batch_1d(self, input) result(res)
295+
class(network), intent(in out) :: self
296+
real, intent(in) :: input(:,:)
297+
real, allocatable :: res(:,:)
298+
integer :: i, batch_size, num_layers, output_size
299+
300+
num_layers = size(self % layers)
301+
batch_size = size(input, dim=rank(input))
302+
output_size = product(self % layers(num_layers) % layer_shape)
303+
304+
allocate(res(output_size, batch_size))
305+
306+
batch: do concurrent(i = 1:size(res, dim=2))
307+
308+
call self % forward(input(:,i))
309+
310+
select type(output_layer => self % layers(num_layers) % p)
311+
type is(dense_layer)
312+
res(:,i) = output_layer % output
313+
type is(flatten_layer)
314+
res(:,i) = output_layer % output
315+
class default
316+
error stop 'network % output not implemented for this output layer'
317+
end select
318+
319+
end do batch
320+
321+
end function output_batch_1d
322+
323+
324+
module function output_batch_3d(self, input) result(res)
325+
class(network), intent(in out) :: self
326+
real, intent(in) :: input(:,:,:,:)
327+
real, allocatable :: res(:,:)
328+
integer :: i, batch_size, num_layers, output_size
329+
330+
num_layers = size(self % layers)
331+
batch_size = size(input, dim=rank(input))
332+
output_size = product(self % layers(num_layers) % layer_shape)
333+
334+
allocate(res(output_size, batch_size))
335+
336+
batch: do concurrent(i = 1:batch_size)
337+
338+
call self % forward(input(:,:,:,i))
339+
340+
select type(output_layer => self % layers(num_layers) % p)
341+
type is(conv2d_layer)
342+
!FIXME flatten the result for now; find a better solution
343+
res(:,i) = pack(output_layer % output, .true.)
344+
type is(dense_layer)
345+
res(:,i) = output_layer % output
346+
type is(flatten_layer)
347+
res(:,i) = output_layer % output
348+
class default
349+
error stop 'network % output not implemented for this output layer'
350+
end select
351+
352+
end do batch
353+
354+
end function output_batch_3d
355+
356+
294357
module subroutine print_info(self)
295358
class(network), intent(in) :: self
296359
call self % layers % print_info()

0 commit comments

Comments
 (0)