Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change examples away from Haiku #1300

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@
"python.linting.flake8Enabled": true,
"python.defaultInterpreterPath": "${workspace}/venv/bin/python3",
"python.terminal.activateEnvInCurrentTerminal": true
}
}
178 changes: 111 additions & 67 deletions examples/alphazero/network.py
Original file line number Diff line number Diff line change
@@ -1,93 +1,137 @@
# We referred to Haiku's ResNet implementation:
# https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/nets/resnet.py

import haiku as hk
import equinox as eqx
import jax
import jax.numpy as jnp


class BlockV1(hk.Module):
def __init__(self, num_channels, name="BlockV1"):
super(BlockV1, self).__init__(name=name)
self.num_channels = num_channels
class BlockV1(eqx.Module):
conv1: eqx.nn.Conv2d
conv2: eqx.nn.Conv2d
norm1: eqx.nn.BatchNorm
norm2: eqx.nn.BatchNorm

def __call__(self, x, is_training, test_local_stats):
def __init__(self, in_channels, out_channels, key):
keys = jax.random.split(key, 2)
self.conv1 = eqx.nn.Conv2d(in_channels, out_channels, padding="SAME", kernel_size=3, key=keys[0])
self.conv2 = eqx.nn.Conv2d(out_channels, out_channels, padding="SAME", kernel_size=3, key=keys[1])
self.norm1 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9, mode="batch")
self.norm2 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9, mode="batch")

def __call__(self, x, state):
i = x
x = hk.Conv2D(self.num_channels, kernel_shape=3)(x)
x = hk.BatchNorm(True, True, 0.9)(x, is_training, test_local_stats)
x = self.conv1(x)
x, state = self.norm1(x, state)
x = jax.nn.relu(x)
x = hk.Conv2D(self.num_channels, kernel_shape=3)(x)
x = hk.BatchNorm(True, True, 0.9)(x, is_training, test_local_stats)
return jax.nn.relu(x + i)
x = self.conv2(x)
x, state = self.norm2(x, state)
return jax.nn.relu(x + i), state


class BlockV2(eqx.Module):
conv1: eqx.nn.Conv2d
conv2: eqx.nn.Conv2d
norm1: eqx.nn.BatchNorm
norm2: eqx.nn.BatchNorm

class BlockV2(hk.Module):
def __init__(self, num_channels, name="BlockV2"):
super(BlockV2, self).__init__(name=name)
self.num_channels = num_channels
def __init__(self, in_channels, out_channels, key):
keys = jax.random.split(key, 2)
self.conv1 = eqx.nn.Conv2d(in_channels, out_channels, padding="SAME", kernel_size=3, key=keys[0])
self.conv2 = eqx.nn.Conv2d(out_channels, out_channels, padding="SAME", kernel_size=3, key=keys[1])
self.norm1 = eqx.nn.BatchNorm(in_channels, "batch", momentum=0.9, mode="batch")
self.norm2 = eqx.nn.BatchNorm(out_channels, "batch", momentum=0.9, mode="batch")

def __call__(self, x, is_training, test_local_stats):
def __call__(self, x, state):
i = x
x = hk.BatchNorm(True, True, 0.9)(x, is_training, test_local_stats)
x, state = self.norm1(x, state)
x = jax.nn.relu(x)
x = hk.Conv2D(self.num_channels, kernel_shape=3)(x)
x = hk.BatchNorm(True, True, 0.9)(x, is_training, test_local_stats)
x = self.conv1(x)
x, state = self.norm2(x, state)
x = jax.nn.relu(x)
x = hk.Conv2D(self.num_channels, kernel_shape=3)(x)
return x + i
x = self.conv2(x)
return x + i, state


class AZNet(eqx.Module):

class AZNet(hk.Module):
"""AlphaZero NN architecture."""
init_layers: list
resnet: list
post_resnet: list
policy_head: list
value_head: list

def __init__(
self,
num_actions,
num_channels: int = 64,
input_channels,
key,
output_channels: int = 64,
num_blocks: int = 5,
resnet_v2: bool = True,
name="az_net",
):
super().__init__(name=name)
self.num_actions = num_actions
self.num_channels = num_channels
self.num_blocks = num_blocks
self.resnet_v2 = resnet_v2
self.resnet_cls = BlockV2 if resnet_v2 else BlockV1

def __call__(self, x, is_training, test_local_stats):
resnet_cls = BlockV2 if resnet_v2 else BlockV1

keys = jax.random.split(key, num_blocks + 5)
self.init_layers = [eqx.nn.Conv2d(input_channels, output_channels, kernel_size=3, padding="SAME", key=keys[0])]
if not resnet_v2:
self.init_layers += [eqx.nn.BatchNorm(output_channels, "batch", momentum=0.9, mode="batch"), jax.nn.relu]
self.resnet = [resnet_cls(output_channels, output_channels, keys[i + 1]) for i in range(num_blocks)]
self.post_resnet = []
if resnet_v2:
self.post_resnet += [eqx.nn.BatchNorm(output_channels, "batch", momentum=0.9, mode="batch"), jax.nn.relu]
self.policy_head = [
eqx.nn.Conv2d(output_channels, 2, kernel_size=1, padding="SAME", key=keys[num_blocks + 1]),
eqx.nn.BatchNorm(2, "batch", momentum=0.9, mode="batch"),
jax.nn.relu,
lambda x: x.flatten(),
# TODO: infer from inputs
eqx.nn.Linear(162, num_actions, key=keys[num_blocks + 2]),
]

self.value_head = [
eqx.nn.Conv2d(output_channels, 1, kernel_size=1, padding="SAME", key=keys[num_blocks + 3]),
eqx.nn.BatchNorm(1, "batch", momentum=0.9, mode="batch"),
jax.nn.relu,
lambda x: x.flatten(),
eqx.nn.Linear(81, output_channels, key=keys[num_blocks + 2]),
jax.nn.relu,
eqx.nn.Linear(output_channels, 1, key=keys[num_blocks + 2]),
jnp.tanh,
jnp.squeeze,
]

def __call__(self, x, state):
x = x.astype(jnp.float32)
x = hk.Conv2D(self.num_channels, kernel_shape=3)(x)

if not self.resnet_v2:
x = hk.BatchNorm(True, True, 0.9)(x, is_training, test_local_stats)
x = jax.nn.relu(x)

for i in range(self.num_blocks):
x = self.resnet_cls(self.num_channels, name=f"block_{i}")(
x, is_training, test_local_stats
)

if self.resnet_v2:
x = hk.BatchNorm(True, True, 0.9)(x, is_training, test_local_stats)
x = jax.nn.relu(x)

# policy head
logits = hk.Conv2D(output_channels=2, kernel_shape=1)(x)
logits = hk.BatchNorm(True, True, 0.9)(logits, is_training, test_local_stats)
logits = jax.nn.relu(logits)
logits = hk.Flatten()(logits)
logits = hk.Linear(self.num_actions)(logits)

# value head
v = hk.Conv2D(output_channels=1, kernel_shape=1)(x)
v = hk.BatchNorm(True, True, 0.9)(v, is_training, test_local_stats)
v = jax.nn.relu(v)
v = hk.Flatten()(v)
v = hk.Linear(self.num_channels)(v)
v = jax.nn.relu(v)
v = hk.Linear(1)(v)
v = jnp.tanh(v)
v = v.reshape((-1,))

return logits, v
x = jnp.moveaxis(x, -1, 0)

for layer in self.init_layers:
if isinstance(layer, eqx.nn.StatefulLayer):
x, state = layer(x, state)
else:
x = layer(x)

for layer in self.resnet:
x, state = layer(x, state)

for layer in self.post_resnet:
if isinstance(layer, eqx.nn.StatefulLayer):
x, state = layer(x, state)
else:
x = layer(x)

logits = x.copy()
for layer in self.policy_head:
if isinstance(layer, eqx.nn.StatefulLayer):
logits, state = layer(logits, state)
else:
logits = layer(logits)

v = x.copy()
for layer in self.value_head:
if isinstance(layer, eqx.nn.StatefulLayer):
v, state = layer(v, state)
else:
v = layer(v)

return (logits, v), state
3 changes: 2 additions & 1 deletion examples/alphazero/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
pgx>=2.0.0
dm-haiku
equinox
mctx
optax
wandb
omegaconf
pydantic
cloundpickle
Loading