Add a mini-transformer example #1
Changes from all commits
```
@@ -1,3 +1,4 @@
tmp/
checkpoints/
*.bin
```
New file (`@@ -0,0 +1,179 @@`):

```python
# /// script
# dependencies = [
#     "julax",
# ]
#
# [tool.uv.sources]
# julax = { path = "../", editable = true }
# ///

# Reproduce https://sdbuchanan.com/blog/jax-2/

from functools import partial
import grain
import jax
import jax.numpy as jnp
import numpy as np
import optax
from jax.nn.initializers import truncated_normal
from julax.core import Learner, Trainer
from julax.einops import Rearrange
from julax.experiment import Experiment
from julax.layers import (
    Chain,
    Linear,
    LayerNorm,
    Parallel,
    Repeated,
    RotaryEmbedding,
    SkipConnection,
    Embedding,
    Unembedding,
)
from julax.observers import default_observer
from julax.utils import identity


class FakeSource(grain.sources.RandomAccessDataSource):
    def __init__(self, seq_len: int = 256) -> None:
        self._seq_len = seq_len
        self._data = np.array(
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 3, 2, 1] * 1024
        )

    def __getitem__(self, index: int):
        return {
            "input_ids": self._data[index : index + self._seq_len],
            "target_labels": self._data[index + 1 : index + 1 + self._seq_len],
        }

    def __len__(self) -> int:
        return len(self._data) - self._seq_len


def main(
    seed: int = 5,
    seq_len: int = 256,
    global_batch_size: int = 128,
    num_steps: int = 1000,
    num_vocab: int = 10,
    dim: int = 768,
    num_heads: int = 12,
    head_dim: int = 64,
    num_layers: int = 2,
    param_std: float = 0.02,
):
    return Experiment(
        name="mini_transformer",
        trainer=Trainer(
            learner=Learner(
                feature_name="input_ids",
                label_name="target_labels",
                model=Chain(
                    emb=Embedding(
                        in_dim=num_vocab,
                        out_dim=dim,
                        w_init=truncated_normal(stddev=param_std),
                    ),
                    blocks=Repeated(
                        n=num_layers,
                        layer=Chain(
                            attn=SkipConnection(
                                layer=Chain(
                                    norm_attn=LayerNorm(dim=dim),
                                    attn=Chain(
                                        # qkv projection
                                        Linear(
                                            in_dim=dim,
                                            out_dim=3 * dim,
                                            w_init=truncated_normal(stddev=param_std),
                                            b_init=None,
                                        ),
                                        Rearrange(
                                            "B T (qkv N H) -> B T (qkv N) H",
                                            B=global_batch_size,
                                            T=seq_len,
                                            qkv=3,
                                            N=num_heads,
                                            H=head_dim,
                                        ),
                                        partial(
                                            jnp.split, indices_or_sections=3, axis=2
                                        ),
                                        Parallel(
                                            RotaryEmbedding(
                                                embedding_dims=head_dim,
                                                fprop_dtype=jnp.float32,
                                            ),
                                            RotaryEmbedding(
                                                embedding_dims=head_dim,
                                                fprop_dtype=jnp.float32,
                                            ),
                                            identity,
                                        ),
                                        lambda qkv: jax.nn.dot_product_attention(
                                            *qkv, is_causal=True
                                        ),
                                        Rearrange(
                                            "B T N H -> B T (N H)",
                                            B=global_batch_size,
                                            T=seq_len,
                                            N=num_heads,
                                            H=head_dim,
                                        ),
                                        Linear(
                                            in_dim=dim,
                                            out_dim=dim,
                                            w_init=truncated_normal(stddev=param_std),
                                            b_init=None,
                                        ),
                                    ),
                                )
                            ),
                            mlp=SkipConnection(
                                layer=Chain(
                                    norm_mlp=LayerNorm(dim=dim),
                                    mlp=Chain(
                                        up=Linear(
                                            in_dim=dim,
                                            out_dim=4 * dim,
                                            w_init=truncated_normal(stddev=param_std),
                                            b_init=None,
                                        ),
                                        act=jax.nn.gelu,
                                        down=Linear(
                                            in_dim=4 * dim,
                                            out_dim=dim,
                                            w_init=truncated_normal(stddev=param_std),
                                            b_init=None,
                                        ),
                                    ),
                                )
                            ),
                        ),
                    ),
                    unemb=Unembedding(
                        in_dim=dim,
                        out_dim=num_vocab,
                        w_init=truncated_normal(stddev=param_std),
                    ),
                ),
                loss_fn=optax.softmax_cross_entropy_with_integer_labels,
            ),
            optimizer=optax.sgd(0.01),
        ),
        dataset=(
            grain.MapDataset.source(FakeSource(seq_len))
            .shuffle(seed=seed)
            .repeat()
            .batch(batch_size=global_batch_size)
            .slice(slice(num_steps))
            .to_iter_dataset()
        ),
        observer=default_observer(),
    )


x = main()
x.run()
x.close()
```
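For orientation, here is a rough, standalone sketch of the data flow that the attention sub-`Chain` above encodes: a fused QKV projection, a reshape into per-head tensors, causal `jax.nn.dot_product_attention`, and an output projection. This is plain JAX with illustrative shapes and randomly initialized weights, not the julax API, and rotary embeddings are omitted:

```python
import jax
import jax.numpy as jnp

B, T, N, H = 2, 16, 12, 64   # batch, sequence, heads, head_dim (dim = N * H = 768)
dim = N * H

key = jax.random.PRNGKey(0)
k_x, k_qkv, k_out = jax.random.split(key, 3)

x = jax.random.normal(k_x, (B, T, dim))                  # token activations after embedding
w_qkv = 0.02 * jax.random.normal(k_qkv, (dim, 3 * dim))  # fused QKV projection weights
w_out = 0.02 * jax.random.normal(k_out, (dim, dim))      # output projection weights

qkv = x @ w_qkv                                          # (B, T, 3 * N * H)
qkv = qkv.reshape(B, T, 3, N, H)                         # mirrors "B T (qkv N H) -> ..."
q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]       # each (B, T, N, H); RoPE would go here

# jax.nn.dot_product_attention expects (batch, seq, num_heads, head_dim).
attn = jax.nn.dot_product_attention(q, k, v, is_causal=True)

out = attn.reshape(B, T, N * H) @ w_out                  # merge heads, project back to dim
print(out.shape)                                         # (2, 16, 768)
```

Given the inline script metadata at the top of the file, the example would typically be launched with `uv run <path-to-script>.py` (path hypothetical), letting uv resolve the editable `julax` dependency declared under `[tool.uv.sources]`.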
This file was deleted.
Updated file (`@@ -1,8 +1,41 @@`), after the change:

```python
from typing import TypeAlias, Any
from pydantic import ConfigDict, RootModel
from jax import Array
from jax.sharding import PartitionSpec
import plum

PRNG: TypeAlias = Array
PyTree: TypeAlias = Any
OutShardingType: TypeAlias = PartitionSpec | None

# TODO: isinstance(jnp.dtype, jnp.float32) fails
Dtype: TypeAlias = Any

dispatch = plum.Dispatcher(warn_redefinition=True)


class FrozenDict(RootModel[dict]):

    model_config = ConfigDict(frozen=True)

    def __getitem__(self, item):
        return self.root[item]

    def __iter__(self):
        return iter(self.root)

    def keys(self):
        return self.root.keys()

    def values(self):
        return self.root.values()

    def items(self):
        return self.root.items()

    def __hash__(self):
        return hash(frozenset(self.root.items()))
```
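The module-level `dispatch = plum.Dispatcher(warn_redefinition=True)` gives the package a single multiple-dispatch registry. A minimal, hypothetical sketch of how such a dispatcher is typically used (the function names here are illustrative, not from the julax codebase):

```python
import plum

dispatch = plum.Dispatcher(warn_redefinition=True)

@dispatch
def describe(x: int) -> str:
    # Selected when the argument is an int.
    return f"int: {x}"

@dispatch
def describe(x: str) -> str:
    # Selected when the argument is a str.
    return f"str: {x}"

print(describe(3))      # int: 3
print(describe("abc"))  # str: abc
```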
Comment on lines +35 to +36:

```python
    def __hash__(self):
        return hash(frozenset(self.root.items()))
```
But dict items may contain unhashable values (like lists or other dicts), which will trigger an unhashable-type error at runtime. Left it for users.
Copilot AI (Nov 28, 2025):
[nitpick] Missing `__len__` method: the `FrozenDict` class implements dict-like methods but is missing `__len__`, which is commonly expected for dict-like objects. Consider adding `def __len__(self): return len(self.root)` for completeness.
Copilot AI (Nov 28, 2025):
The `__hash__` implementation attempts to hash dictionary items directly using `frozenset(self.root.items())`. This will fail if any values in the dictionary are unhashable (e.g., lists, dicts, or mutable objects). Consider handling unhashable values or documenting this limitation. A safer approach might be to make values recursively frozen or to raise a more descriptive error.
Suggested change:

```diff
-        return hash(frozenset(self.root.items()))
+        # Recursively freeze the dictionary to ensure all values are hashable
+        return hash(self._recursive_freeze(self.root))
+
+    @staticmethod
+    def _recursive_freeze(obj):
+        """Recursively convert obj to a hashable structure."""
+        if isinstance(obj, dict):
+            # Sort items to ensure consistent ordering
+            return tuple(sorted((k, FrozenDict._recursive_freeze(v)) for k, v in obj.items()))
+        elif isinstance(obj, (list, tuple)):
+            return tuple(FrozenDict._recursive_freeze(v) for v in obj)
+        elif isinstance(obj, set):
+            return frozenset(FrozenDict._recursive_freeze(v) for v in obj)
+        # Add more types here if needed (e.g., numpy arrays)
+        # If the object is already hashable, return as is
+        try:
+            hash(obj)
+        except TypeError:
+            raise TypeError(f"Unhashable type encountered in FrozenDict: {type(obj)}")
+        return obj
```
[nitpick] Missing documentation: the `FakeSource` class lacks a docstring explaining its purpose as a simple synthetic data source for testing the transformer. Consider adding documentation to explain the repeating pattern in `_data` and why this particular sequence is used for training.
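For context on the `__hash__` thread above, a minimal, hypothetical repro of the limitation being discussed (the class body is copied from the diff; the surrounding script is illustrative):

```python
from pydantic import ConfigDict, RootModel


class FrozenDict(RootModel[dict]):
    model_config = ConfigDict(frozen=True)

    def __hash__(self):
        return hash(frozenset(self.root.items()))


# Hashable values work as expected.
print(hash(FrozenDict({"lr": 0.01, "steps": 1000})))

# A list value makes frozenset(self.root.items()) raise at runtime,
# which is the behavior the review comments point out.
try:
    hash(FrozenDict({"dims": [768, 3072]}))
except TypeError as err:
    print(f"TypeError: {err}")
```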