@@ -244,7 +244,7 @@ pure module subroutine forward_3d(self, input)
244
244
end subroutine forward_3d
245
245
246
246
247
- module function output_1d (self , input ) result(res)
247
+ module function predict_1d (self , input ) result(res)
248
248
class(network), intent (in out ) :: self
249
249
real , intent (in ) :: input(:)
250
250
real , allocatable :: res(:)
@@ -263,10 +263,10 @@ module function output_1d(self, input) result(res)
263
263
error stop ' network % output not implemented for this output layer'
264
264
end select
265
265
266
- end function output_1d
266
+ end function predict_1d
267
267
268
268
269
- module function output_3d (self , input ) result(res)
269
+ module function predict_3d (self , input ) result(res)
270
270
class(network), intent (in out ) :: self
271
271
real , intent (in ) :: input(:,:,:)
272
272
real , allocatable :: res(:)
@@ -288,10 +288,10 @@ module function output_3d(self, input) result(res)
288
288
error stop ' network % output not implemented for this output layer'
289
289
end select
290
290
291
- end function output_3d
291
+ end function predict_3d
292
292
293
293
294
- module function output_batch_1d (self , input ) result(res)
294
+ module function predict_batch_1d (self , input ) result(res)
295
295
class(network), intent (in out ) :: self
296
296
real , intent (in ) :: input(:,:)
297
297
real , allocatable :: res(:,:)
@@ -318,10 +318,10 @@ module function output_batch_1d(self, input) result(res)
318
318
319
319
end do batch
320
320
321
- end function output_batch_1d
321
+ end function predict_batch_1d
322
322
323
323
324
- module function output_batch_3d (self , input ) result(res)
324
+ module function predict_batch_3d (self , input ) result(res)
325
325
class(network), intent (in out ) :: self
326
326
real , intent (in ) :: input(:,:,:,:)
327
327
real , allocatable :: res(:,:)
@@ -335,23 +335,23 @@ module function output_batch_3d(self, input) result(res)
335
335
336
336
batch: do concurrent(i = 1 :batch_size)
337
337
338
- call self % forward(input(:,:,:,i))
338
+ call self % forward(input(:,:,:,i))
339
339
340
- select type (output_layer = > self % layers(num_layers) % p)
341
- type is (conv2d_layer)
342
- ! FIXME flatten the result for now; find a better solution
343
- res(:,i) = pack (output_layer % output, .true. )
344
- type is (dense_layer)
345
- res(:,i) = output_layer % output
346
- type is (flatten_layer)
347
- res(:,i) = output_layer % output
348
- class default
349
- error stop ' network % output not implemented for this output layer'
350
- end select
340
+ select type (output_layer = > self % layers(num_layers) % p)
341
+ type is (conv2d_layer)
342
+ ! FIXME flatten the result for now; find a better solution
343
+ res(:,i) = pack (output_layer % output, .true. )
344
+ type is (dense_layer)
345
+ res(:,i) = output_layer % output
346
+ type is (flatten_layer)
347
+ res(:,i) = output_layer % output
348
+ class default
349
+ error stop ' network % output not implemented for this output layer'
350
+ end select
351
351
352
352
end do batch
353
353
354
- end function output_batch_3d
354
+ end function predict_batch_3d
355
355
356
356
357
357
module subroutine print_info (self )
0 commit comments