Batch BatchNorm #948

Open · wants to merge 1 commit into base: `main`
195 changes: 153 additions & 42 deletions equinox/nn/_batch_norm.py
@@ -1,10 +1,10 @@
from collections.abc import Hashable, Sequence
from typing import Optional, Union
from typing import Literal, Optional, Union

import jax
import jax.lax as lax
import jax.numpy as jnp
from jaxtyping import Array, Bool, Float, PRNGKeyArray
from jaxtyping import Array, Bool, Float, Int, PRNGKeyArray

from .._misc import default_floating_dtype
from .._module import field
@@ -40,25 +40,69 @@ class BatchNorm(StatefulLayer, strict=True):
statistics updated. During inference then just the running statistics are used.
Whether the model is in training or inference mode should be toggled using
[`equinox.nn.inference_mode`][].

With `mode = "batch"`, during training the batch mean and variance are used
for normalization. For inference the exponential running mean and unbiased
variance are used for normalization. This is in line with how other JAX
packages (e.g. haiku, flax) implement batch norm.

With `mode = "ema"`, exponential running means and variances are kept. During
training the batch statistics are used to fill in the running statistics until
they are populated. During inference the running statistics are used for
normalization.

??? cite

[Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift](https://arxiv.org/abs/1502.03167)

```bibtex
@article{DBLP:journals/corr/IoffeS15,
author = {Sergey Ioffe and
Christian Szegedy},
title = {Batch Normalization: Accelerating Deep Network Training
by Reducing Internal Covariate Shift},
journal = {CoRR},
volume = {abs/1502.03167},
year = {2015},
url = {http://arxiv.org/abs/1502.03167},
eprinttype = {arXiv},
eprint = {1502.03167},
timestamp = {Mon, 13 Aug 2018 16:47:06 +0200},
biburl = {https://dblp.org/rec/journals/corr/IoffeS15.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
```
""" # noqa: E501

weight: Optional[Float[Array, "input_size"]]
bias: Optional[Float[Array, "input_size"]]
first_time_index: StateIndex[Bool[Array, ""]]
state_index: StateIndex[
tuple[Float[Array, "input_size"], Float[Array, "input_size"]]
ema_first_time_index: Optional[StateIndex[Bool[Array, ""]]]
ema_state_index: Optional[
StateIndex[tuple[Float[Array, "input_size"], Float[Array, "input_size"]]]
]
batch_counter: Optional[StateIndex[Int[Array, ""]]]
batch_state_index: Optional[
StateIndex[
tuple[
tuple[Float[Array, "input_size"], Float[Array, "input_size"]],
tuple[Float[Array, "input_size"], Float[Array, "input_size"]],
],
]
]
axis_name: Union[Hashable, Sequence[Hashable]]
inference: bool
input_size: int = field(static=True)
eps: float = field(static=True)
channelwise_affine: bool = field(static=True)
momentum: float = field(static=True)
mode: Literal["ema", "batch"] = field(static=True)

def __init__(
self,
input_size: int,
axis_name: Union[Hashable, Sequence[Hashable]],
mode: str = "ema",
eps: float = 1e-5,
channelwise_affine: bool = True,
momentum: float = 0.99,
@@ -71,6 +115,7 @@ def __init__(
- `axis_name`: The name of the batch axis to compute statistics over, as passed
to `axis_name` in `jax.vmap` or `jax.pmap`. Can also be a sequence (e.g. a
tuple or a list) of names, to compute statistics over multiple named axes.
- `mode`: The variant of batch norm to use, either 'ema' or 'batch'.
- `eps`: Value added to the denominator for numerical stability.
- `channelwise_affine`: Whether the module has learnable channel-wise affine
parameters.
@@ -86,19 +131,38 @@ def __init__(
`jax.numpy.float32` or `jax.numpy.float64` depending on whether JAX is in
64-bit mode.
"""
if mode not in ("ema", "batch"):
raise ValueError("Invalid mode, must be 'ema' or 'batch'.")
self.mode = mode
dtype = default_floating_dtype() if dtype is None else dtype
if channelwise_affine:
self.weight = jnp.ones((input_size,), dtype=dtype)
self.bias = jnp.zeros((input_size,), dtype=dtype)
else:
self.weight = None
self.bias = None
self.first_time_index = StateIndex(jnp.array(True))
init_buffers = (
jnp.empty((input_size,), dtype=dtype),
jnp.empty((input_size,), dtype=dtype),
)
self.state_index = StateIndex(init_buffers)
if mode == "ema":
self.ema_first_time_index = StateIndex(jnp.array(True))
init_buffers = (
jnp.empty((input_size,), dtype=dtype),
jnp.empty((input_size,), dtype=dtype),
)
self.ema_state_index = StateIndex(init_buffers)
self.batch_counter = None
self.batch_state_index = None
else:
self.batch_counter = StateIndex(jnp.array(0))
init_hidden = (
jnp.zeros((input_size,), dtype=dtype),
jnp.zeros((input_size,), dtype=dtype),
)
init_avg = (
jnp.zeros((input_size,), dtype=dtype),
jnp.zeros((input_size,), dtype=dtype),
)
self.batch_state_index = StateIndex((init_hidden, init_avg))
self.ema_first_time_index = None
self.ema_state_index = None
self.inference = inference
self.axis_name = axis_name
self.input_size = input_size
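
For orientation (not part of the diff): a minimal sketch of the state entries this constructor creates in each mode, assuming this branch is installed. The `input_size=5` and axis name mirror the tests below; variable names are chosen here for illustration.

```python
import equinox as eqx

# "ema" mode: a first-time flag plus one (mean, var) pair of empty buffers.
bn_ema = eqx.nn.BatchNorm(5, "batch", mode="ema")
state = eqx.nn.State(bn_ema)
first_time = state.get(bn_ema.ema_first_time_index)  # bool scalar, starts True
mean, var = state.get(bn_ema.ema_state_index)         # two uninitialised (5,) arrays

# "batch" mode: a step counter plus ((hidden_mean, hidden_var),
# (running_mean, running_var)), all zero-initialised.
bn_batch = eqx.nn.BatchNorm(5, "batch", mode="batch")
state = eqx.nn.State(bn_batch)
counter = state.get(bn_batch.batch_counter)           # int scalar, starts at 0
(hidden_mean, hidden_var), (avg_mean, avg_var) = state.get(bn_batch.batch_state_index)
```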
@@ -138,38 +202,85 @@ def __call__(
A `NameError` if no `vmap`s are placed around this operation, or if this vmap
does not have a matching `axis_name`.
"""

if inference is None:
inference = self.inference
if inference:
running_mean, running_var = state.get(self.state_index)

def _stats(y):
mean = jnp.mean(y)
mean = lax.pmean(mean, self.axis_name)
var = jnp.mean((y - mean) * jnp.conj(y - mean))
var = lax.pmean(var, self.axis_name)
var = jnp.maximum(0.0, var)
return mean, var

if self.mode == "ema":
assert (
self.ema_first_time_index is not None
and self.ema_state_index is not None
)
if inference:
running_mean, running_var = state.get(self.ema_state_index)
else:
first_time = state.get(self.ema_first_time_index)
state = state.set(self.ema_first_time_index, jnp.array(False))

batch_mean, batch_var = jax.vmap(_stats)(x)
running_mean, running_var = state.get(self.ema_state_index)
momentum = self.momentum
running_mean = (1 - momentum) * batch_mean + momentum * running_mean
running_var = (1 - momentum) * batch_var + momentum * running_var
# since jnp.array(0) == False
running_mean = lax.select(first_time, batch_mean, running_mean)
running_var = lax.select(first_time, batch_var, running_var)
state = state.set(self.ema_state_index, (running_mean, running_var))

def _norm(y, m, v, w, b):
out = (y - m) / jnp.sqrt(v + self.eps)
if self.channelwise_affine:
out = out * w + b
return out

out = jax.vmap(_norm)(x, running_mean, running_var, self.weight, self.bias)
return out, state
else:
assert self.batch_state_index is not None and self.batch_counter is not None
if inference:
_, (mean, var) = state.get(self.batch_state_index)
else:
batch_mean, batch_var = jax.vmap(_stats)(x)
counter = state.get(self.batch_counter)
(hidden_mean, hidden_var), (running_mean, running_var) = state.get(
self.batch_state_index
)

decay = self.momentum
one = jnp.array(1.0, dtype=x.dtype)

# Update hidden_{mean,var}
new_hidden_mean = hidden_mean * decay + batch_mean * (one - decay)
new_hidden_var = hidden_var * decay + batch_var * (one - decay)

# Zero-debias approach: average_ = hidden_ / (1 - decay^counter)
# For simplicity we do the minimal version here (no warmup).
new_counter = counter + 1
decay_power = decay**new_counter
new_running_mean = new_hidden_mean / (one - decay_power)
new_running_var = new_hidden_var / (one - decay_power)

state = state.set(self.batch_counter, new_counter)
new_state_data = (
(new_hidden_mean, new_hidden_var),
(new_running_mean, new_running_var),
)
state = state.set(self.batch_state_index, new_state_data)

mean, var = (batch_mean, batch_var)

def _norm(y, m, v, w, b):
out = (y - m) / jnp.sqrt(v + self.eps)
if self.channelwise_affine:
out = out * w + b
return out

def _stats(y):
mean = jnp.mean(y)
mean = lax.pmean(mean, self.axis_name)
var = jnp.mean((y - mean) * jnp.conj(y - mean))
var = lax.pmean(var, self.axis_name)
var = jnp.maximum(0.0, var)
return mean, var

first_time = state.get(self.first_time_index)
state = state.set(self.first_time_index, jnp.array(False))

batch_mean, batch_var = jax.vmap(_stats)(x)
running_mean, running_var = state.get(self.state_index)
momentum = self.momentum
running_mean = (1 - momentum) * batch_mean + momentum * running_mean
running_var = (1 - momentum) * batch_var + momentum * running_var
running_mean = lax.select(first_time, batch_mean, running_mean)
running_var = lax.select(first_time, batch_var, running_var)
state = state.set(self.state_index, (running_mean, running_var))

def _norm(y, m, v, w, b):
out = (y - m) / jnp.sqrt(v + self.eps)
if self.channelwise_affine:
out = out * w + b
return out

out = jax.vmap(_norm)(x, running_mean, running_var, self.weight, self.bias)
return out, state
out = jax.vmap(_norm)(x, mean, var, self.weight, self.bias)
return out, state
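
Not part of the diff, but for reference: a minimal usage sketch of the new `mode` argument, assuming this branch is installed. The data and shapes are made up; the call pattern follows the tests below.

```python
import equinox as eqx
import jax
import jax.random as jrandom

# Construct a BatchNorm layer using the "batch" mode added in this PR.
bn = eqx.nn.BatchNorm(input_size=5, axis_name="batch", mode="batch")
state = eqx.nn.State(bn)

# BatchNorm must be vmap'd over the batch axis with a matching axis_name.
vbn = jax.vmap(bn, axis_name="batch", in_axes=(0, None), out_axes=(0, None))

x = jrandom.normal(jrandom.PRNGKey(0), (10, 5))  # (batch, channels)
out, state = vbn(x, state)  # training: normalised with batch statistics

# Switch to inference: the zero-debiased running statistics are used instead.
ibn = eqx.nn.inference_mode(bn, value=True)
vibn = jax.vmap(ibn, axis_name="batch", in_axes=(0, None), out_axes=(0, None))
out, state = vibn(x, state)
```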
51 changes: 41 additions & 10 deletions tests/test_nn.py
@@ -940,22 +940,28 @@ def test_group_norm(getkey):
gn = eqx.nn.GroupNorm(groups=4, channels=None, channelwise_affine=True)


def test_batch_norm(getkey):
@pytest.mark.parametrize("mode", ("ema", "batch"))
def test_batch_norm(getkey, mode):
x0 = jrandom.uniform(getkey(), (5,))
x1 = jrandom.uniform(getkey(), (10, 5))
x2 = jrandom.uniform(getkey(), (10, 5, 6))
x3 = jrandom.uniform(getkey(), (10, 5, 7, 8))

# Test that it works with a single vmap'd axis_name

bn = eqx.nn.BatchNorm(5, "batch")
bn = eqx.nn.BatchNorm(5, "batch", mode=mode)
state = eqx.nn.State(bn)
vbn = jax.vmap(bn, axis_name="batch", in_axes=(0, None), out_axes=(0, None))

for x in (x1, x2, x3):
out, state = vbn(x, state)
assert out.shape == x.shape
running_mean, running_var = state.get(bn.state_index)
if mode == "ema":
assert bn.ema_state_index is not None
running_mean, running_var = state.get(bn.ema_state_index)
else:
assert bn.batch_state_index is not None
_, (running_mean, running_var) = state.get(bn.batch_state_index)
assert running_mean.shape == (5,)
assert running_var.shape == (5,)

@@ -976,28 +982,38 @@ def test_batch_norm(getkey):
in_axes=(0, None),
)(x2, state)
assert out.shape == x2.shape
running_mean, running_var = state.get(bn.state_index)
if mode == "ema":
assert bn.ema_state_index is not None
running_mean, running_var = state.get(bn.ema_state_index)
else:
assert bn.batch_state_index is not None
_, (running_mean, running_var) = state.get(bn.batch_state_index)
assert running_mean.shape == (10, 5)
assert running_var.shape == (10, 5)

# Test that it handles multiple axis_names

vvbn = eqx.nn.BatchNorm(6, ("batch1", "batch2"))
vvbn = eqx.nn.BatchNorm(6, ("batch1", "batch2"), mode=mode)
vvstate = eqx.nn.State(vvbn)
for axis_name in ("batch1", "batch2"):
vvbn = jax.vmap(
vvbn, axis_name=axis_name, in_axes=(0, None), out_axes=(0, None)
)
out, out_vvstate = vvbn(x2, vvstate)
assert out.shape == x2.shape
running_mean, running_var = out_vvstate.get(vvbn.state_index)
if mode == "ema":
assert vvbn.ema_state_index is not None
running_mean, running_var = out_vvstate.get(vvbn.ema_state_index)
else:
assert vvbn.batch_state_index is not None
_, (running_mean, running_var) = out_vvstate.get(vvbn.batch_state_index)
assert running_mean.shape == (6,)
assert running_var.shape == (6,)

# Test that it normalises

x1alt = jrandom.normal(jrandom.PRNGKey(5678), (10, 5)) # avoid flakey test
bn = eqx.nn.BatchNorm(5, "batch", channelwise_affine=False)
bn = eqx.nn.BatchNorm(5, "batch", channelwise_affine=False, mode=mode)
state = eqx.nn.State(bn)
vbn = jax.vmap(bn, axis_name="batch", in_axes=(0, None), out_axes=(0, None))
out, state = vbn(x1alt, state)
@@ -1008,9 +1024,19 @@ def test_batch_norm(getkey):

# Test that the statistics update during training
out, state = vbn(x1, state)
running_mean, running_var = state.get(bn.state_index)
if mode == "ema":
assert bn.ema_state_index is not None
running_mean, running_var = state.get(bn.ema_state_index)
else:
assert bn.batch_state_index is not None
_, (running_mean, running_var) = state.get(bn.batch_state_index)
out, state = vbn(3 * x1 + 10, state)
running_mean2, running_var2 = state.get(bn.state_index)
if mode == "ema":
assert bn.ema_state_index is not None
running_mean2, running_var2 = state.get(bn.ema_state_index)
else:
assert bn.batch_state_index is not None
_, (running_mean2, running_var2) = state.get(bn.batch_state_index)
assert not jnp.allclose(running_mean, running_mean2)
assert not jnp.allclose(running_var, running_var2)

@@ -1019,7 +1045,12 @@ def test_batch_norm(getkey):
ibn = eqx.nn.inference_mode(bn, value=True)
vibn = jax.vmap(ibn, axis_name="batch", in_axes=(0, None), out_axes=(0, None))
out, state = vibn(4 * x1 + 20, state)
running_mean3, running_var3 = state.get(bn.state_index)
if mode == "ema":
assert bn.ema_state_index is not None
running_mean3, running_var3 = state.get(bn.ema_state_index)
else:
assert bn.batch_state_index is not None
_, (running_mean3, running_var3) = state.get(bn.batch_state_index)
assert jnp.array_equal(running_mean2, running_mean3)
assert jnp.array_equal(running_var2, running_var3)
