Merged
1 change: 1 addition & 0 deletions .gitignore
@@ -1,3 +1,4 @@
tmp/
checkpoints/

*.bin
179 changes: 179 additions & 0 deletions experiments/mini_transformer.py
@@ -0,0 +1,179 @@
# /// script
# dependencies = [
# "julax",
# ]
#
# [tool.uv.sources]
# julax = { path = "../", editable = true }
# ///

# Reproduce https://sdbuchanan.com/blog/jax-2/

from functools import partial
import grain
import jax
import jax.numpy as jnp
import numpy as np
import optax
from jax.nn.initializers import truncated_normal
from julax.core import Learner, Trainer
from julax.einops import Rearrange
from julax.experiment import Experiment
from julax.layers import (
    Chain,
    Linear,
    LayerNorm,
    Parallel,
    Repeated,
    RotaryEmbedding,
    SkipConnection,
    Embedding,
    Unembedding,
)
from julax.observers import default_observer
from julax.utils import identity


class FakeSource(grain.sources.RandomAccessDataSource):
Copilot AI Nov 28, 2025

[nitpick] Missing documentation: The FakeSource class lacks a docstring explaining its purpose as a simple synthetic data source for testing the transformer. Consider adding documentation to explain the repeating pattern in _data and why this particular sequence is used for training.

Suggested change
class FakeSource(grain.sources.RandomAccessDataSource):
class FakeSource(grain.sources.RandomAccessDataSource):
    """
    A simple synthetic data source for testing transformer models.
    The data consists of a repeating pattern:
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 3, 2, 1]
    repeated 1024 times to form a long sequence.
    This predictable, non-random sequence is useful for sequence modeling tasks,
    allowing the model to learn to predict the next token in a known pattern.
    """

    def __init__(self, seq_len: int = 256) -> None:
        self._seq_len = seq_len
        self._data = np.array(
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 3, 2, 1] * 1024
        )

    def __getitem__(self, index: int):
        return {
            "input_ids": self._data[index : index + self._seq_len],
            "target_labels": self._data[index + 1 : index + 1 + self._seq_len],
        }

    def __len__(self) -> int:
        return len(self._data) - self._seq_len


def main(
    seed: int = 5,
    seq_len: int = 256,
    global_batch_size: int = 128,
    num_steps: int = 1000,
    num_vocab: int = 10,
    dim: int = 768,
    num_heads: int = 12,
    head_dim: int = 64,
    num_layers: int = 2,
    param_std: float = 0.02,
):
    return Experiment(
        name="mini_transformer",
        trainer=Trainer(
            learner=Learner(
                feature_name="input_ids",
                label_name="target_labels",
                model=Chain(
                    emb=Embedding(
                        in_dim=num_vocab,
                        out_dim=dim,
                        w_init=truncated_normal(stddev=param_std),
                    ),
                    blocks=Repeated(
                        n=num_layers,
                        layer=Chain(
                            attn=SkipConnection(
                                layer=Chain(
                                    norm_attn=LayerNorm(dim=dim),
                                    attn=Chain(
                                        # qkv projection
                                        Linear(
                                            in_dim=dim,
                                            out_dim=3 * dim,
                                            w_init=truncated_normal(stddev=param_std),
                                            b_init=None,
                                        ),
                                        Rearrange(
                                            "B T (qkv N H) -> B T (qkv N) H",
                                            B=global_batch_size,
                                            T=seq_len,
                                            qkv=3,
                                            N=num_heads,
                                            H=head_dim,
                                        ),
                                        partial(
                                            jnp.split, indices_or_sections=3, axis=2
                                        ),
                                        Parallel(
                                            RotaryEmbedding(
                                                embedding_dims=head_dim,
                                                fprop_dtype=jnp.float32,
                                            ),
                                            RotaryEmbedding(
                                                embedding_dims=head_dim,
                                                fprop_dtype=jnp.float32,
                                            ),
                                            identity,
                                        ),
                                        lambda qkv: jax.nn.dot_product_attention(
                                            *qkv, is_causal=True
                                        ),
                                        Rearrange(
                                            "B T N H -> B T (N H)",
                                            B=global_batch_size,
                                            T=seq_len,
                                            N=num_heads,
                                            H=head_dim,
                                        ),
                                        Linear(
                                            in_dim=dim,
                                            out_dim=dim,
                                            w_init=truncated_normal(stddev=param_std),
                                            b_init=None,
                                        ),
                                    ),
                                )
                            ),
                            mlp=SkipConnection(
                                layer=Chain(
                                    norm_mlp=LayerNorm(dim=dim),
                                    mlp=Chain(
                                        up=Linear(
                                            in_dim=dim,
                                            out_dim=4 * dim,
                                            w_init=truncated_normal(stddev=param_std),
                                            b_init=None,
                                        ),
                                        act=jax.nn.gelu,
                                        down=Linear(
                                            in_dim=4 * dim,
                                            out_dim=dim,
                                            w_init=truncated_normal(stddev=param_std),
                                            b_init=None,
                                        ),
                                    ),
                                )
                            ),
                        ),
                    ),
                    unemb=Unembedding(
                        in_dim=dim,
                        out_dim=num_vocab,
                        w_init=truncated_normal(stddev=param_std),
                    ),
                ),
                loss_fn=optax.softmax_cross_entropy_with_integer_labels,
            ),
            optimizer=optax.sgd(0.01),
        ),
        dataset=(
            grain.MapDataset.source(FakeSource(seq_len))
            .shuffle(seed=seed)
            .repeat()
            .batch(batch_size=global_batch_size)
            .slice(slice(num_steps))
            .to_iter_dataset()
        ),
        observer=default_observer(),
    )


x = main()
x.run()
x.close()
12 changes: 2 additions & 10 deletions experiments/mnist.py
@@ -11,14 +11,11 @@

import logging

from datetime import datetime
import os

import grain
import jax
from jax.nn.initializers import truncated_normal
import optax
import orbax.checkpoint as ocp
import tensorflow_datasets as tfds

from julax import (
@@ -62,17 +59,11 @@ def evaluate(x: Experiment, p: Param, s: State):
        n_total += 32
    acc = n_correct / n_total

    logging.info(f"Accuracy at step {s['trainer']['step']}: {acc}")
    logging.info(f"Accuracy at step {s['step']}: {acc}")


E = Experiment(
    name="mnist",
    checkpoint_manager=ocp.CheckpointManager(
        directory=os.path.join(
            os.getcwd(), "checkpoints", datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        ),
        options=ocp.CheckpointManagerOptions(save_interval_steps=100),
    ),
    trainer=Trainer(
        learner=Learner(
            model=Chain(
@@ -114,3 +105,4 @@ def evaluate(x: Experiment, p: Param, s: State):
)

E.run()
E.close()
31 changes: 0 additions & 31 deletions experiments/transformer.py

This file was deleted.

5 changes: 4 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "julax"
version = "0.0.3-dev"
version = "0.0.3"
description = "Just Layers over JAX"
readme = "README.md"
authors = [
@@ -29,6 +29,9 @@ dev = [
"pytest>=8.3.2",
"pytest-cov>=5.0.0",
]
tpu = [
"jax[tpu]>=0.7.2",
]

[tool.ruff.lint]
ignore = ["E741", "F811"]
33 changes: 33 additions & 0 deletions src/julax/base.py
@@ -1,8 +1,41 @@
from typing import TypeAlias, Any
from pydantic import ConfigDict, RootModel
from jax import Array
from jax.sharding import PartitionSpec
import plum

PRNG: TypeAlias = Array
PyTree: TypeAlias = Any
OutShardingType: TypeAlias = PartitionSpec | None

# TODO: isinstance(jnp.dtype, jnp.float32) fails
Dtype: TypeAlias = Any

dispatch = plum.Dispatcher(warn_redefinition=True)


class FrozenDict(RootModel[dict]):
Copilot AI Nov 28, 2025

This class implements `__hash__`, but does not implement `__eq__`.

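For a concrete view of the pitfall, here is a minimal standalone sketch (a hypothetical `OnlyHash` class, not julax code): with `__hash__` defined but `__eq__` left at the default identity comparison, two value-identical instances share a hash yet never match in a set or dict. The `__eq__` added at the bottom of this file's diff closes exactly that gap.

```python
# Hypothetical minimal class: __hash__ without __eq__.
class OnlyHash:
    def __init__(self, root: dict):
        self.root = root

    def __hash__(self):
        return hash(frozenset(self.root.items()))
    # No __eq__, so comparison falls back to object identity.


a = OnlyHash({"x": 1})
b = OnlyHash({"x": 1})
assert hash(a) == hash(b)  # equal hashes...
assert a != b              # ...but unequal under the default identity __eq__,
assert b not in {a}        # so set/dict lookups miss equal-valued keys
```
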
Member Author

@copilot open a new pull request to apply changes based on this feedback

    model_config = ConfigDict(frozen=True)

    def __getitem__(self, item):
        return self.root[item]

    def __iter__(self):
        return iter(self.root)

    def keys(self):
        return self.root.keys()

    def values(self):
        return self.root.values()

    def items(self):
        return self.root.items()

    def __hash__(self):
        return hash(frozenset(self.root.items()))
Comment on lines +35 to +36
Copilot AI Nov 28, 2025

The FrozenDict.__hash__ implementation has a bug: it attempts to hash a dict's items directly via frozenset(self.root.items()), but dict items may contain unhashable values (like lists or other dicts). This will raise a TypeError when the dict contains unhashable values.

Since this is a Pydantic RootModel[dict] and dicts are generally unhashable, consider either:

  1. Removing the __hash__ method entirely (letting it raise TypeError when hashing is attempted)
  2. Implementing a more robust hash that handles unhashable values
  3. Using id(self) if object identity is sufficient
Suggested change
    def __hash__(self):
        return hash(frozenset(self.root.items()))
Member Author

> but dict items may contain unhashable values (like lists or other dicts)

It will trigger an unhashable error at runtime. Left it for users.

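For concreteness, a small repro sketch of that runtime behavior, assuming this PR's `FrozenDict` is importable from `julax.base`:

```python
from julax.base import FrozenDict

flat = FrozenDict({"lr": 0.01, "dim": 768})
print(hash(flat))  # works: every value is hashable

nested = FrozenDict({"opt": {"lr": 0.01}})  # a dict value is unhashable
try:
    hash(nested)
except TypeError as e:
    # Surfaces only when the instance is actually hashed, as noted above.
    print(f"TypeError: {e}")
```
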
Comment on lines +17 to +36
Copilot AI Nov 28, 2025

[nitpick] Missing __len__ method: The FrozenDict class implements dict-like methods but is missing __len__, which is commonly expected for dict-like objects. Consider adding def __len__(self): return len(self.root) for completeness.

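For reference, a sketch of how the suggested method would sit in the class (hypothetical; this change is not part of the diff):

```python
# Sketch of the reviewer's suggestion; not applied in this PR.
from pydantic import RootModel


class FrozenDict(RootModel[dict]):  # abbreviated: other methods as in the diff
    def __len__(self) -> int:
        return len(self.root)


assert len(FrozenDict({"a": 1, "b": 2})) == 2
assert not FrozenDict({})  # falsy when empty, like a plain dict
```
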

Comment on lines +36 to +37
Copilot AI Nov 28, 2025

The __hash__ implementation attempts to hash dictionary items directly using frozenset(self.root.items()). This will fail if any values in the dictionary are unhashable (e.g., lists, dicts, or mutable objects). Consider handling unhashable values or documenting this limitation. A safer approach might be to make values recursively frozen or to raise a more descriptive error.

Suggested change
        return hash(frozenset(self.root.items()))
        # Recursively freeze the dictionary to ensure all values are hashable
        return hash(self._recursive_freeze(self.root))

    @staticmethod
    def _recursive_freeze(obj):
        """Recursively convert obj to a hashable structure."""
        if isinstance(obj, dict):
            # Sort items to ensure consistent ordering
            return tuple(sorted((k, FrozenDict._recursive_freeze(v)) for k, v in obj.items()))
        elif isinstance(obj, (list, tuple)):
            return tuple(FrozenDict._recursive_freeze(v) for v in obj)
        elif isinstance(obj, set):
            return frozenset(FrozenDict._recursive_freeze(v) for v in obj)
        # Add more types here if needed (e.g., numpy arrays)
        # If the object is already hashable, return as is
        try:
            hash(obj)
        except TypeError:
            raise TypeError(f"Unhashable type encountered in FrozenDict: {type(obj)}")
        return obj

    def __eq__(self, other):
        if isinstance(other, FrozenDict):
            return self.root == other.root
        return self.root == other
19 changes: 11 additions & 8 deletions src/julax/core.py
@@ -11,11 +11,11 @@

import jax
import jax.numpy as jnp
from jax import jit, value_and_grad, Array
from jax import jit, value_and_grad

#####

from julax.base import PRNG, PyTree, dispatch
from julax.base import PRNG, Dtype, OutShardingType, PyTree, dispatch

# TODO: use RootModel[dict] for better customization
# Or maybe SimpleNamespace?
@@ -26,6 +26,10 @@


class LayerBase(BaseModel, ABC):
    param_dtype: Dtype | None = None
    param_sharding: OutShardingType = None
    out_sharding: OutShardingType = None

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        frozen=True,
@@ -134,7 +138,7 @@ def to_layer(x):


class Learner(LayerBase):
    loss_fn: Callable[[PyTree, PyTree], Array]
    loss_fn: Callable[[PyTree, PyTree], Any]
    model: LayerBase
    agg: Callable = jnp.mean
    feature_name: str = "feature"
@@ -154,7 +158,7 @@ class Trainer(LayerBase):
    optimizer: Any

    def state(self, rng: PRNG) -> State:
        return State(optimizer=None, step=0, loss=0.0)
        return State(optimizer=None, loss=0.0)

    @dispatch
    def init(
@@ -165,11 +169,9 @@ def init(

    def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]:
        loss, state = self.learner(x, p["learner"], s["learner"])
        return loss, State(
            learner=state, optimizer=s["optimizer"], step=s["step"] + 1, loss=loss
        )
        return loss, State(learner=state, optimizer=s["optimizer"], loss=loss)

    @partial(jit, static_argnums=0)
    @partial(jit, static_argnums=0, donate_argnames=("p", "s"))
    def forward_and_backward(
        self, x: PyTree, p: Param, s: State
    ) -> tuple[Param, State]:
@@ -178,5 +180,6 @@ def forward_and_backward(
        P = optax.apply_updates(p, updates)
        return P, S

    @dispatch
    def __call__(self, x: PyTree, p: Param, s: State) -> tuple[Param, State]:
        return self.forward_and_backward(x, p, s)
4 changes: 3 additions & 1 deletion src/julax/einops.py
@@ -6,6 +6,8 @@
import jax.numpy as jnp
from jax.nn.initializers import Initializer
from pydantic import computed_field

from julax.base import FrozenDict
from .core import LayerBase, Param, PyTree, State, PRNG


@@ -26,7 +28,7 @@ def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]:

class Rearrange(LayerBase):
    pattern: str
    sizes: dict
    sizes: FrozenDict

    def __init__(self, pattern: str, **kwargs):
        super().__init__(pattern=pattern, sizes=kwargs)