Batch BatchNorm #948

Open

wants to merge 1 commit into main
Conversation

@lockwo (Contributor) commented Feb 10, 2025

Revives #675, but as a smaller/simpler change (one that doesn't require any math) and one that matches the patterns in Flax/Haiku (see: https://github.com/google-deepmind/dm-haiku/blob/main/haiku/_src/batch_norm.py#L42%23L206, https://github.com/google-deepmind/dm-haiku/blob/main/haiku/_src/moving_averages.py#L41%23L139). It has some hardcoded defaults (e.g. warmup iterations are not supported at all), but it should help users of batch norm. In addition to the improved stability in the initial example, here is an example of an AlphaZero model playing 9x9 Go:

[Screenshot, 2025-02-09: training results for the AlphaZero model on 9x9 Go]

This screenshot (and code) also comes from this PR: sotetsuk/pgx#1300
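
For reference, a minimal sketch of what the new option looks like from the user's side, based on the comparison code below (the mode="batch" argument comes from this PR and is not part of released Equinox):

import jax.numpy as jnp
import equinox as eqx

# One BatchNorm layer over 3 channels, accumulating batch statistics
# (mode="batch" is the option added in this PR).
norm, state = eqx.nn.make_with_state(eqx.nn.BatchNorm)(3, "batch", momentum=0.9, mode="batch")

x = jnp.ones((8, 3, 2, 2))  # (batch, channels, height, width)
out, state = eqx.filter_vmap(
    norm, in_axes=(0, None), out_axes=(0, None), axis_name="batch"
)(x, state)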

@lockwo lockwo changed the title from "Batch norm changes" to "Batch BatchNorm" on Feb 10, 2025

@lockwo (Contributor, Author) commented Feb 10, 2025

Also, for anyone who wants to try out the comparison, here's some example code comparing Haiku and Equinox:

import jax
from jax import numpy as jnp
import equinox as eqx
import haiku as hk


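# Haiku reference model: a single BatchNorm layer applied to channel-last inputs.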
class Net(hk.Module):

    def __init__(
        self,
        name="net",
    ):
        super().__init__(name=name)

    def __call__(self, x, is_training):
        x = x.astype(jnp.float32)
        x = hk.BatchNorm(True, True, 0.9)(x, is_training)
        return x


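# Equinox equivalent: the same BatchNorm, using the new mode="batch" option from this PR.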
class Neteqx(eqx.Module):
    norm: eqx.nn.BatchNorm

    def __init__(
        self,
        output_channels: int = 64,
    ):
        self.norm = eqx.nn.BatchNorm(
            output_channels, "batch", momentum=0.9, mode="batch"
        )

    def __call__(self, x, state):
        x = x.astype(jnp.float32)
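        # eqx.nn.BatchNorm expects the channel axis first on each (unbatched) example,
        # so move the channel axis from last to first, normalise, then move it back.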
        x = jnp.moveaxis(x, -1, 0)
        x, state = self.norm(x, state)
        x = jnp.moveaxis(x, 0, -1)
        return x, state


def forward_fn(x):
    net = Net()
    v = net(x, True)
    return v


def inf_fn(x):
    net = Net()
    v = net(x, False)
    return v


forward = hk.without_apply_rng(hk.transform_with_state(forward_fn))
inf = hk.without_apply_rng(hk.transform_with_state(inf_fn))

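# A batch of 4 constant examples with channel-last shape (4, 2, 2, 3);
# initialise the Haiku params and state on zeros of the same shape.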
ins = jnp.array([1, 2, 3, 4]).reshape((4, 1, 1, 1)) * jnp.ones((4, 2, 2, 3))
params, state = forward.init(jax.random.key(0), jnp.zeros_like(ins))
# print(params)
# print(hk.data_structures.tree_size(params))
print(state)
jax.tree.map(lambda x: print(x), state)

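# Two "training" steps with Haiku: each apply updates the running statistics in `state`.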
for i in range(2):
    out, state = forward.apply(params, state, ins[2 * i : 2 * (i + 1)])
    jax.tree.map(lambda x: print(x), state)

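# Haiku inference pass, using the accumulated running statistics.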
inf.apply(params, state, ins)

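# Build the Equinox model/state and copy Haiku's scale/offset so the parameters match.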
eq, s = eqx.nn.make_with_state(Neteqx)(3)
eq = eqx.tree_at(
    lambda x: x.norm.weight, eq, params["net/batch_norm"]["scale"].squeeze()
)
eq = eqx.tree_at(
    lambda x: x.norm.bias, eq, params["net/batch_norm"]["offset"].squeeze()
)

# print(sum(x.size for x in jax.tree.leaves(eqx.filter(eq, eqx.is_inexact_array))))
jax.tree.map(lambda x: print(x), s)

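# Two "training" steps with Equinox; the running statistics are carried in `s`.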
for i in range(2):
    out_eqx, s = eqx.filter_vmap(
        eq, in_axes=(0, None), out_axes=(0, None), axis_name="batch"
    )(ins[2 * i : 2 * (i + 1)], s)
    jax.tree.map(lambda x: print(x), s)
    # print(out_eqx)

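# Equinox inference pass via inference_mode, using the accumulated statistics.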
eqx.filter_vmap(
    eqx.nn.inference_mode(eq), in_axes=(0, None), out_axes=(0, None), axis_name="batch"
)(ins, s)
