Skip to content

Commit d052bce

Browse files
ggoymanmilancurcic
andauthored
Class-based activation functions (#126)
* Implementation of activation_function class for 1d activations * 3d activations implemented using activation_function type * get_name function added to the activation_function type * Activation_function instances are now passed to contructors * removal of redundant use statements * Small fix to make the test build * Tidy up and formatting * Formatting * Set alpha defaults from Keras * Enable leaky ReLU * Add tests for setting alpha values to parametric activations (ELU and leaky ReLU) * Bump version --------- Co-authored-by: milancurcic <[email protected]>
1 parent f328c8d commit d052bce

17 files changed

+768
-685
lines changed

CMakeLists.txt

+1-2
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ include(cmake/json.cmake)
2424
# library to archive (libneural.a)
2525
add_library(neural
2626
src/nf.f90
27-
src/nf/nf_activation_1d.f90
28-
src/nf/nf_activation_3d.f90
27+
src/nf/nf_activation.f90
2928
src/nf/nf_base_layer.f90
3029
src/nf/nf_conv2d_layer.f90
3130
src/nf/nf_conv2d_layer_submodule.f90

example/cnn_mnist.f90

+5-5
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ program cnn_mnist
22

33
use nf, only: network, sgd, &
44
input, conv2d, maxpool2d, flatten, dense, reshape, &
5-
load_mnist, label_digits
5+
load_mnist, label_digits, softmax, relu
66

77
implicit none
88

@@ -24,11 +24,11 @@ program cnn_mnist
2424
net = network([ &
2525
input(784), &
2626
reshape([1,28,28]), &
27-
conv2d(filters=8, kernel_size=3, activation='relu'), &
27+
conv2d(filters=8, kernel_size=3, activation=relu()), &
2828
maxpool2d(pool_size=2), &
29-
conv2d(filters=16, kernel_size=3, activation='relu'), &
29+
conv2d(filters=16, kernel_size=3, activation=relu()), &
3030
maxpool2d(pool_size=2), &
31-
dense(10, activation='softmax') &
31+
dense(10, activation=softmax()) &
3232
])
3333

3434
call net % print_info()
@@ -67,4 +67,4 @@ real function accuracy(net, x, y)
6767
accuracy = real(good) / size(x, dim=2)
6868
end function accuracy
6969

70-
end program cnn_mnist
70+
end program cnn_mnist

fpm.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
name = "neural-fortran"
2-
version = "0.11.0"
2+
version = "0.12.0"
33
license = "MIT"
44
author = "Milan Curcic"
55
maintainer = "[email protected]"

src/nf.f90

+3
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,7 @@ module nf
66
conv2d, dense, flatten, input, maxpool2d, reshape
77
use nf_network, only: network
88
use nf_optimizers, only: sgd
9+
use nf_activation, only: activation_function, elu, exponential, &
10+
gaussian, linear, relu, leaky_relu, &
11+
sigmoid, softmax, softplus, step, tanhf
912
end module nf

0 commit comments

Comments
 (0)