   use nf_keras, only: get_keras_h5_layers, keras_layer
   use nf_layer, only: layer
   use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d, reshape
-  use nf_loss, only: quadratic_derivative
+  use nf_loss, only: quadratic
   use nf_optimizers, only: optimizer_base_type, sgd
   use nf_parallel, only: tile_indices
   use nf_activation, only: activation_function, &
@@ -280,11 +280,27 @@ pure function get_activation_by_name(activation_name) result(res)
 
   end function get_activation_by_name
 
-  pure module subroutine backward(self, output)
+  pure module subroutine backward(self, output, loss)
     class(network), intent(in out) :: self
     real, intent(in) :: output(:)
+    class(loss_type), intent(in), optional :: loss
     integer :: n, num_layers
 
+    ! Passing the loss instance is optional. If not provided, and if the loss
+    ! instance has not already been set, we default to quadratic. The
+    ! instantiation and initialization of the loss instance below is normally
+    ! done at the beginning of the network % train() method. However, if the
+    ! user wants to call network % backward() directly, for example if they
+    ! use their own custom mini-batching routine, we initialize the loss here
+    ! as well. If it's initialized already, this step is a cheap no-op.
+    if (.not. allocated(self % loss)) then
+      if (present(loss)) then
+        self % loss = loss
+      else
+        self % loss = quadratic()
+      end if
+    end if
+
     num_layers = size(self % layers)
 
     ! Iterate backward over layers, from the output layer
@@ -297,7 +313,7 @@ pure module subroutine backward(self, output)
           type is (dense_layer)
             call self % layers(n) % backward( &
               self % layers(n - 1), &
-              quadratic_derivative(output, this_layer % output) &
+              self % loss % derivative(output, this_layer % output) &
             )
         end select
       else
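Note on the two hunks above: the initialization block in backward() exists so that users who write their own mini-batching can call network % backward() directly instead of going through train(). Below is a minimal sketch of that pattern. Only the optional loss argument to backward() comes from this change; the nf module name, the network([...]) constructor, forward(), and the layer constructors used are assumptions about the surrounding library API.

  ! Hypothetical sketch: driving backward() from a custom per-sample loop.
  ! Only backward()'s optional loss argument is introduced by this diff;
  ! everything else here is an assumed part of the public network API.
  program custom_batch_sketch
    use nf, only: network, dense, input
    use nf_loss, only: quadratic
    implicit none
    type(network) :: net
    real :: x(3, 100), y(2, 100)
    integer :: j
    net = network([input(3), dense(2)])
    call random_number(x)
    call random_number(y)
    do j = 1, size(x, dim=2)
      call net % forward(x(:,j))
      ! loss is optional; if omitted, backward() falls back to quadratic()
      call net % backward(y(:,j), loss=quadratic())
      ! ...apply the accumulated gradients here, e.g. with an optimizer update...
    end do
  end program custom_batch_sketch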
@@ -542,13 +558,14 @@ end subroutine set_params
 
 
   module subroutine train(self, input_data, output_data, batch_size, &
-    epochs, optimizer)
+    epochs, optimizer, loss)
     class(network), intent(in out) :: self
     real, intent(in) :: input_data(:,:)
     real, intent(in) :: output_data(:,:)
     integer, intent(in) :: batch_size
     integer, intent(in) :: epochs
     class(optimizer_base_type), intent(in), optional :: optimizer
+    class(loss_type), intent(in), optional :: loss
     class(optimizer_base_type), allocatable :: optimizer_
 
     real :: pos
@@ -567,6 +584,14 @@ module subroutine train(self, input_data, output_data, batch_size, &
 
     call self % optimizer % init(self % get_num_params())
 
+    ! Passing the loss instance is optional.
+    ! If not provided, we default to quadratic().
+    if (present(loss)) then
+      self % loss = loss
+    else
+      self % loss = quadratic()
+    end if
+
     dataset_size = size(output_data, dim=2)
 
     epoch_loop: do n = 1, epochs
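With both train() hunks applied, the loss is selected at the call site the same way as the optimizer, and omitting it is equivalent to passing quadratic(). A minimal usage sketch follows; the nf module name, the network([...]) constructor, and sgd(learning_rate=...) are assumptions about the surrounding library API, while the loss keyword is what this change adds.

  ! Hypothetical usage of the new optional loss argument to train().
  program train_with_loss_sketch
    use nf, only: network, dense, input, sgd
    use nf_loss, only: quadratic
    implicit none
    type(network) :: net
    real :: x(3, 200), y(2, 200)
    net = network([input(3), dense(2)])
    call random_number(x)
    call random_number(y)
    call net % train(x, y, batch_size=20, epochs=5, &
      optimizer=sgd(learning_rate=0.01), loss=quadratic())
  end program train_with_loss_sketch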