@@ -291,6 +291,69 @@ module function output_3d(self, input) result(res)
291
291
end function output_3d
292
292
293
293
294
+ module function output_batch_1d (self , input ) result(res)
295
+ class(network), intent (in out ) :: self
296
+ real , intent (in ) :: input(:,:)
297
+ real , allocatable :: res(:,:)
298
+ integer :: i, batch_size, num_layers, output_size
299
+
300
+ num_layers = size (self % layers)
301
+ batch_size = size (input, dim= rank(input))
302
+ output_size = product (self % layers(num_layers) % layer_shape)
303
+
304
+ allocate (res(output_size, batch_size))
305
+
306
+ batch: do concurrent(i = 1 :size (res, dim= 2 ))
307
+
308
+ call self % forward(input(:,i))
309
+
310
+ select type (output_layer = > self % layers(num_layers) % p)
311
+ type is (dense_layer)
312
+ res(:,i) = output_layer % output
313
+ type is (flatten_layer)
314
+ res(:,i) = output_layer % output
315
+ class default
316
+ error stop ' network % output not implemented for this output layer'
317
+ end select
318
+
319
+ end do batch
320
+
321
+ end function output_batch_1d
322
+
323
+
324
+ module function output_batch_3d (self , input ) result(res)
325
+ class(network), intent (in out ) :: self
326
+ real , intent (in ) :: input(:,:,:,:)
327
+ real , allocatable :: res(:,:)
328
+ integer :: i, batch_size, num_layers, output_size
329
+
330
+ num_layers = size (self % layers)
331
+ batch_size = size (input, dim= rank(input))
332
+ output_size = product (self % layers(num_layers) % layer_shape)
333
+
334
+ allocate (res(output_size, batch_size))
335
+
336
+ batch: do concurrent(i = 1 :batch_size)
337
+
338
+ call self % forward(input(:,:,:,i))
339
+
340
+ select type (output_layer = > self % layers(num_layers) % p)
341
+ type is (conv2d_layer)
342
+ ! FIXME flatten the result for now; find a better solution
343
+ res(:,i) = pack (output_layer % output, .true. )
344
+ type is (dense_layer)
345
+ res(:,i) = output_layer % output
346
+ type is (flatten_layer)
347
+ res(:,i) = output_layer % output
348
+ class default
349
+ error stop ' network % output not implemented for this output layer'
350
+ end select
351
+
352
+ end do batch
353
+
354
+ end function output_batch_3d
355
+
356
+
294
357
module subroutine print_info (self )
295
358
class(network), intent (in ) :: self
296
359
call self % layers % print_info()
0 commit comments