diff --git a/CMakeLists.txt b/CMakeLists.txt
index bd48e993..0fd1be6d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -26,6 +26,8 @@ add_library(neural
   src/nf.f90
   src/nf/nf_activation.f90
   src/nf/nf_base_layer.f90
+  src/nf/nf_batchnorm_layer.f90
+  src/nf/nf_batchnorm_layer_submodule.f90
   src/nf/nf_conv2d_layer.f90
   src/nf/nf_conv2d_layer_submodule.f90
   src/nf/nf_datasets.f90
diff --git a/src/nf.f90 b/src/nf.f90
index eb2a903a..b6b90b4f 100644
--- a/src/nf.f90
+++ b/src/nf.f90
@@ -3,7 +3,7 @@ module nf
   use nf_datasets_mnist, only: label_digits, load_mnist
   use nf_layer, only: layer
   use nf_layer_constructors, only: &
-    conv2d, dense, flatten, input, maxpool2d, reshape
+    batchnorm, conv2d, dense, flatten, input, maxpool2d, reshape
   use nf_network, only: network
   use nf_optimizers, only: sgd, rmsprop, adam, adagrad
   use nf_activation, only: activation_function, elu, exponential,  &
diff --git a/src/nf/nf_batchnorm_layer.f90 b/src/nf/nf_batchnorm_layer.f90
new file mode 100644
index 00000000..193d5ef3
--- /dev/null
+++ b/src/nf/nf_batchnorm_layer.f90
@@ -0,0 +1,109 @@
+module nf_batchnorm_layer
+
+  !! This module provides a batch normalization `batchnorm_layer` type.
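+  !!
+  !! Batch normalization normalizes each feature of its input and then applies
+  !! a learnable scale (`gamma`) and shift (`beta`):
+  !!
+  !!   y = gamma * (x - mean) / sqrt(var + epsilon) + beta
+  !!
+  !! where `mean` and `var` are per-feature estimates of the input statistics.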
+
+  use nf_base_layer, only: base_layer
+  implicit none
+
+  private
+  public :: batchnorm_layer
+
+  type, extends(base_layer) :: batchnorm_layer
+
+    integer :: num_features
+    real, allocatable :: gamma(:)
+    real, allocatable :: beta(:)
+    real, allocatable :: running_mean(:)
+    real, allocatable :: running_var(:)
+    real, allocatable :: input(:,:)
+    real, allocatable :: output(:,:)
+    real, allocatable :: gamma_grad(:)
+    real, allocatable :: beta_grad(:)
+    real, allocatable :: input_grad(:,:)
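+    ! Small constant added to the variance for numerical stability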
+    real :: epsilon = 1e-5
+
+  contains
+
+    procedure :: forward
+    procedure :: backward
+    procedure :: get_gradients
+    procedure :: get_num_params
+    procedure :: get_params
+    procedure :: init
+    procedure :: set_params
+
+  end type batchnorm_layer
+
+  interface batchnorm_layer
+    pure module function batchnorm_layer_cons(num_features) result(res)
+      !! `batchnorm_layer` constructor function
+      integer, intent(in) :: num_features
+      type(batchnorm_layer) :: res
+    end function batchnorm_layer_cons
+  end interface batchnorm_layer
+
+  interface
+
+    module subroutine init(self, input_shape)
+      !! Initialize the layer data structures.
+      !!
+      !! This is a deferred procedure from the `base_layer` abstract type.
+      class(batchnorm_layer), intent(in out) :: self
+        !! A `batchnorm_layer` instance
+      integer, intent(in) :: input_shape(:)
+        !! Input layer dimensions
+    end subroutine init
+
+    pure module subroutine forward(self, input)
+      !! Apply a forward pass on the `batchnorm_layer`.
+      class(batchnorm_layer), intent(in out) :: self
+        !! A `batchnorm_layer` instance
+      real, intent(in) :: input(:,:)
+        !! Input data
+    end subroutine forward
+
+    pure module subroutine backward(self, input, gradient)
+      !! Apply a backward pass on the `batchnorm_layer`.
+      class(batchnorm_layer), intent(in out) :: self
+        !! A `batchnorm_layer` instance
+      real, intent(in) :: input(:,:)
+        !! Input data (previous layer)
+      real, intent(in) :: gradient(:,:)
+        !! Gradient (next layer)
+    end subroutine backward
+
+    pure module function get_num_params(self) result(num_params)
+      !! Get the number of parameters in the layer.
+      class(batchnorm_layer), intent(in) :: self
+        !! A `batchnorm_layer` instance
+      integer :: num_params
+        !! Number of parameters
+    end function get_num_params
+
+    pure module function get_params(self) result(params)
+      !! Return the parameters (gamma and beta) of this layer.
+      class(batchnorm_layer), intent(in) :: self
+        !! A `batchnorm_layer` instance
+      real, allocatable :: params(:)
+        !! Parameters to get
+    end function get_params
+
+    pure module function get_gradients(self) result(gradients)
+      !! Return the gradients of this layer.
+      class(batchnorm_layer), intent(in) :: self
+        !! A `batchnorm_layer` instance
+      real, allocatable :: gradients(:)
+        !! Gradients to get
+    end function get_gradients
+
+    module subroutine set_params(self, params)
+      !! Set the parameters of the layer.
+      class(batchnorm_layer), intent(in out) :: self
+        !! A `batchnorm_layer` instance
+      real, intent(in) :: params(:)
+        !! Parameters to set
+    end subroutine set_params
+
+  end interface
+
+end module nf_batchnorm_layer
diff --git a/src/nf/nf_batchnorm_layer_submodule.f90 b/src/nf/nf_batchnorm_layer_submodule.f90
new file mode 100644
index 00000000..9f3d2a82
--- /dev/null
+++ b/src/nf/nf_batchnorm_layer_submodule.f90
@@ -0,0 +1,105 @@
+submodule(nf_batchnorm_layer) nf_batchnorm_layer_submodule
+
+  implicit none
+
+contains
+
+  pure module function batchnorm_layer_cons(num_features) result(res)
+    implicit none
+    integer, intent(in) :: num_features
+    type(batchnorm_layer) :: res
+
+    res % num_features = num_features
+    allocate(res % gamma(num_features), source=1.0)
+    allocate(res % beta(num_features), source=0.0)
+    allocate(res % running_mean(num_features), source=0.0)
+    allocate(res % running_var(num_features), source=1.0)
+    allocate(res % input(num_features, num_features))
+    allocate(res % output(num_features, num_features))
+    allocate(res % gamma_grad(num_features), source=0.0)
+    allocate(res % beta_grad(num_features), source=0.0)
+    allocate(res % input_grad(num_features, num_features))
+
+  end function batchnorm_layer_cons
+
+  module subroutine init(self, input_shape)
+    implicit none
+    class(batchnorm_layer), intent(in out) :: self
+    integer, intent(in) :: input_shape(:)
+
+    self % input = 0
+    self % output = 0
+
+    ! Initialize gamma, beta, running_mean, and running_var
+    self % gamma = 1.0
+    self % beta = 0.0
+    self % running_mean = 0.0
+    self % running_var = 1.0
+
+  end subroutine init
+
+  pure module subroutine forward(self, input)
+    implicit none
+    class(batchnorm_layer), intent(in out) :: self
+    real, intent(in) :: input(:,:)
+
+    ! Store input for backward pass
+    self % input = input
+
+    associate( &
+      ! Normalize the input with the running statistics, broadcasting the
+      ! per-feature mean and variance across the batch dimension
+      normalized_input => (input - spread(self % running_mean, dim=2, ncopies=size(input, 2))) &
+        / sqrt(spread(self % running_var, dim=2, ncopies=size(input, 2)) + self % epsilon) &
+    )
+
+      ! Scale and shift by the learnable parameters gamma and beta
+      self % output = spread(self % gamma, dim=2, ncopies=size(input, 2)) * normalized_input &
+        + spread(self % beta, dim=2, ncopies=size(input, 2))
+
+    end associate
+
+  end subroutine forward
+
+  pure module subroutine backward(self, input, gradient)
+    implicit none
+    class(batchnorm_layer), intent(in out) :: self
+    real, intent(in) :: input(:,:)
+    real, intent(in) :: gradient(:,:)
+
+    ! Gradients w.r.t. gamma and beta, summed over the batch dimension
+    self % gamma_grad = sum(gradient * (input - spread(self % running_mean, dim=2, ncopies=size(input, 2))) &
+      / sqrt(spread(self % running_var, dim=2, ncopies=size(input, 2)) + self % epsilon), dim=2)
+    self % beta_grad = sum(gradient, dim=2)
+
+    ! Gradient w.r.t. the input, treating the running statistics as constants
+    self % input_grad = gradient * spread(self % gamma, dim=2, ncopies=size(input, 2)) &
+      / sqrt(spread(self % running_var, dim=2, ncopies=size(input, 2)) + self % epsilon)
+
+  end subroutine backward
+
+  pure module function get_num_params(self) result(num_params)
+    class(batchnorm_layer), intent(in) :: self
+    integer :: num_params
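+    ! gamma and beta are the trainable parameters, num_features values each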
+    num_params = 2 * self % num_features
+  end function get_num_params
+
+  pure module function get_params(self) result(params)
+    class(batchnorm_layer), intent(in) :: self
+    real, allocatable :: params(:)
+    params = [self % gamma, self % beta]
+  end function get_params
+
+  pure module function get_gradients(self) result(gradients)
+    class(batchnorm_layer), intent(in) :: self
+    real, allocatable :: gradients(:)
+    gradients = [self % gamma_grad, self % beta_grad]
+  end function get_gradients
+
+  module subroutine set_params(self, params)
+    class(batchnorm_layer), intent(in out) :: self
+    real, intent(in) :: params(:)
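+    ! The first num_features values are gamma and the rest are beta,
+    ! in the same order as returned by get_params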
+    self % gamma = params(1:self % num_features)
+    self % beta = params(self % num_features+1:2*self % num_features)
+  end subroutine set_params
+
+end submodule nf_batchnorm_layer_submodule
diff --git a/src/nf/nf_layer_constructors.f90 b/src/nf/nf_layer_constructors.f90
index ce9a7244..b036f1bd 100644
--- a/src/nf/nf_layer_constructors.f90
+++ b/src/nf/nf_layer_constructors.f90
@@ -8,7 +8,7 @@ module nf_layer_constructors
   implicit none
 
   private
-  public :: conv2d, dense, flatten, input, maxpool2d, reshape
+  public :: batchnorm, conv2d, dense, flatten, input, maxpool2d, reshape
 
   interface input
 
@@ -106,6 +106,25 @@ pure module function flatten() result(res)
         !! Resulting layer instance
     end function flatten
 
+    pure module function batchnorm(num_features) result(res)
+      !! Batch normalization layer constructor.
+      !!
+      !! This layer is for adding batch normalization to the network.
+      !! A batch normalization layer can be used after conv2d or dense layers.
+      !!
+      !! Example:
+      !!
+      !! ```
+      !! use nf, only: batchnorm, layer
+      !! type(layer) :: batchnorm_layer
+      !! batchnorm_layer = batchnorm(num_features = 64)
+      !! ```
+      integer, intent(in) :: num_features
+        !! Number of features in the layer
+      type(layer) :: res
+        !! Resulting layer instance
+    end function batchnorm
+
     pure module function conv2d(filters, kernel_size, activation) result(res)
       !! 2-d convolutional layer constructor.
       !!
diff --git a/src/nf/nf_layer_constructors_submodule.f90 b/src/nf/nf_layer_constructors_submodule.f90
index 002a83ba..914df2f7 100644
--- a/src/nf/nf_layer_constructors_submodule.f90
+++ b/src/nf/nf_layer_constructors_submodule.f90
@@ -1,6 +1,7 @@
 submodule(nf_layer_constructors) nf_layer_constructors_submodule
 
   use nf_layer, only: layer
+  use nf_batchnorm_layer, only: batchnorm_layer
   use nf_conv2d_layer, only: conv2d_layer
   use nf_dense_layer, only: dense_layer
   use nf_flatten_layer, only: flatten_layer
@@ -14,6 +15,13 @@
 
 contains
 
+  pure module function batchnorm(num_features) result(res)
+    integer, intent(in) :: num_features
+    type(layer) :: res
+    res % name = 'batchnorm'
+    allocate(res % p, source=batchnorm_layer(num_features))
+  end function batchnorm
+
   pure module function conv2d(filters, kernel_size, activation) result(res)
     integer, intent(in) :: filters
     integer, intent(in) :: kernel_size
diff --git a/src/nf/nf_layer_submodule.f90 b/src/nf/nf_layer_submodule.f90
index 07467643..94d9d17e 100644
--- a/src/nf/nf_layer_submodule.f90
+++ b/src/nf/nf_layer_submodule.f90
@@ -1,6 +1,7 @@
 submodule(nf_layer) nf_layer_submodule
 
   use iso_fortran_env, only: stderr => error_unit
+  use nf_batchnorm_layer, only: batchnorm_layer
   use nf_conv2d_layer, only: conv2d_layer
   use nf_dense_layer, only: dense_layer
   use nf_flatten_layer, only: flatten_layer
diff --git a/src/nf/nf_network_submodule.f90 b/src/nf/nf_network_submodule.f90
index 5bafb7cf..ecff74d2 100644
--- a/src/nf/nf_network_submodule.f90
+++ b/src/nf/nf_network_submodule.f90
@@ -10,7 +10,7 @@
   use nf_io_hdf5, only: get_hdf5_dataset
   use nf_keras, only: get_keras_h5_layers, keras_layer
   use nf_layer, only: layer
-  use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d, reshape
+  use nf_layer_constructors, only: batchnorm, conv2d, dense, flatten, input, maxpool2d, reshape
   use nf_loss, only: quadratic_derivative
   use nf_optimizers, only: optimizer_base_type, sgd
   use nf_parallel, only: tile_indices
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 26646ec1..b4ee8202 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -16,6 +16,7 @@ foreach(execid
   cnn_from_keras
   conv2d_network
   optimizers
+  batchnorm_layer
   )
   add_executable(test_${execid} test_${execid}.f90)
   target_link_libraries(test_${execid} PRIVATE neural h5fortran::h5fortran jsonfortran::jsonfortran ${LIBS})
diff --git a/test/test_batchnorm_layer.f90 b/test/test_batchnorm_layer.f90
new file mode 100644
index 00000000..473b22de
--- /dev/null
+++ b/test/test_batchnorm_layer.f90
@@ -0,0 +1,65 @@
+program test_batchnorm_layer
+
+  use iso_fortran_env, only: stderr => error_unit
+  use nf, only: batchnorm, layer
+  use nf_batchnorm_layer, only: batchnorm_layer
+
+  implicit none
+
+  type(layer) :: bn_layer
+  integer, parameter :: num_features = 64
+  real, allocatable :: sample_input(:,:)
+  real, allocatable :: output(:,:)
+  real, allocatable :: gradient(:,:)
+  integer, parameter :: input_shape(1) = [num_features]
+  real, allocatable :: gamma_grad(:), beta_grad(:)
+  real, parameter :: tolerance = 1e-7
+  logical :: ok = .true.
+
+  bn_layer = batchnorm(num_features)
+
+  if (.not. bn_layer % name == 'batchnorm') then
+    ok = .false.
+    write(stderr, '(a)') 'batchnorm layer has its name set correctly.. failed'
+  end if
+
+  if (bn_layer % initialized) then
+    ok = .false.
+    write(stderr, '(a)') 'batchnorm layer should not be marked as initialized yet.. failed'
+  end if
+
+  ! Initialize sample input and gradient
+  allocate(sample_input(num_features, 1))
+  allocate(gradient(num_features, 1))
+  sample_input = 1.0
+  gradient = 2.0
+
+  !TODO run forward and backward passes directly on the batchnorm_layer instance
+  !TODO since we don't yet support tying it in with the input layer.
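+  !TODO For example (a sketch, assuming a local `type(batchnorm_layer) :: bn`):
+  !bn = batchnorm_layer(num_features)
+  !call bn % forward(sample_input)
+  !call bn % backward(sample_input, gradient)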
+
+  !TODO Retrieve output and check normalization
+  !call bn_layer % get_output(output)
+  !if (.not. all(abs(output - sample_input) < tolerance)) then
+  !  ok = .false.
+  !  write(stderr, '(a)') 'batchnorm layer output should be close to input.. failed'
+  !end if
+
+  !TODO Retrieve gamma and beta gradients
+  !allocate(gamma_grad(num_features))
+  !allocate(beta_grad(num_features))
+  !call bn_layer % get_gradients(gamma_grad, beta_grad)
+
+  !if (.not. all(abs(beta_grad - sum(gradient, dim=2)) < tolerance)) then
+  !  ok = .false.
+  !  write(stderr, '(a)') 'batchnorm layer beta gradients are incorrect.. failed'
+  !end if
+
+  ! Report test results
+  if (ok) then
+    print '(a)', 'test_batchnorm_layer: All tests passed.'
+  else
+    write(stderr, '(a)') 'test_batchnorm_layer: One or more tests failed.'
+    stop 1
+  end if
+
+end program test_batchnorm_layer