Conversation

@findmyway
Member

No description provided.

Copilot AI review requested due to automatic review settings November 28, 2025 02:02
Contributor

Copilot AI left a comment

Pull request overview

This PR adds a mini-transformer example to demonstrate the framework's capabilities for building and training transformer models. The changes refactor state management by moving the step counter from the trainer to the experiment level, add new layer primitives for building transformers, and enhance support for distributed training with JAX sharding.

Key changes:

  • Moved step tracking from s["trainer"]["step"] to s["step"] at the Experiment level for cleaner state architecture
  • Added SkipConnection, Repeated, and Unembedding layer classes to support transformer architectures
  • Enhanced layers with param_dtype, param_sharding, and out_sharding attributes for distributed training and mixed precision support

Reviewed changes

Copilot reviewed 9 out of 10 changed files in this pull request and generated 9 comments.

Summary per file:

  • src/julax/observers.py: Updated observers to use top-level s["step"] instead of s["trainer"]["step"]; switched LossLogger to use jax.debug.print for JIT compatibility
  • src/julax/layers.py: Added SkipConnection, Repeated, and Unembedding layers; updated existing layers with sharding/dtype support; refactored LayerNorm to use jax.nn.standardize
  • src/julax/experiment.py: Moved step counter to top-level state; made checkpoint_manager optional; added close() method; updated run() to return final params/state
  • src/julax/einops.py: Changed Rearrange.sizes type from dict to FrozenDict for immutability
  • src/julax/core.py: Added param_dtype, param_sharding, out_sharding to LayerBase; removed step tracking from Trainer state; added donate_argnames to JIT decorator
  • src/julax/base.py: Added FrozenDict class and OutShardingType type alias for sharding support
  • experiments/transformer.py: Removed old transformer stub (replaced by mini_transformer.py)
  • experiments/mnist.py: Removed hardcoded checkpoint manager; updated to use top-level step in evaluate function; added E.close() call
  • experiments/mini_transformer.py: New file implementing a complete mini-transformer example with attention, MLP blocks, and proper initialization
  • .gitignore: Added tmp/ directory to gitignore


Comment on lines +40 to +52
class Repeated(LayerBase):
    n: int
    layer: LayerLike

    def sublayers(self) -> dict:
        return {f"layer_{i}": self.layer for i in range(self.n)}

    def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]:
        S = State()
        o = x
        for i in range(self.n):
            o, S[f"layer_{i}"] = self.layer(o, p[f"layer_{i}"], s[f"layer_{i}"])
        return o, S
Copilot AI Nov 28, 2025

[nitpick] The Repeated class creates independent copies of parameters for each repetition (via sublayers() returning the same layer instance multiple times). This means each repetition will have different, independently trained parameters.

If weight sharing across repetitions is desired (common in some architectures), the current implementation won't support that. If independent parameters per layer is intended (as in typical transformer blocks), this is correct but should be documented.

Consider adding a docstring to clarify that each repetition has independent parameters.
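For contrast, a minimal sketch of what a weight-shared variant could look like (hypothetical: the name SharedRepeated and the exact LayerBase/Param/State calls are assumptions based on the excerpt above, not code from this PR):

class SharedRepeated(LayerBase):
    # Sketch only: one parameter/state subtree is reused for every repetition,
    # so all n applications share (and co-train) the same weights.
    n: int
    layer: LayerLike

    def sublayers(self) -> dict:
        # A single entry, unlike Repeated's layer_0 .. layer_{n-1} namespacing.
        return {"layer": self.layer}

    def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]:
        S = State()
        o = x
        for _ in range(self.n):
            # The same p["layer"] / s["layer"] subtree is threaded through each step.
            o, S["layer"] = self.layer(o, p["layer"], s["layer"])
        return o, S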

Member Author

@copilot open a new pull request to apply changes based on this feedback

Comment on lines +32 to +33
    def __hash__(self):
        return hash(frozenset(self.root.items()))
Copilot AI Nov 28, 2025

The FrozenDict.__hash__ implementation has a bug: it attempts to hash a dict's items directly via frozenset(self.root.items()), but dict items may contain unhashable values (like lists or other dicts). This will raise a TypeError when the dict contains unhashable values.

Since this is a Pydantic RootModel[dict] and dicts are generally unhashable, consider either:

  1. Removing the __hash__ method entirely (letting it raise TypeError when hashing is attempted)
  2. Implementing a more robust hash that handles unhashable values
  3. Using id(self) if object identity is sufficient
Suggested change
-    def __hash__(self):
-        return hash(frozenset(self.root.items()))

Member Author

> but dict items may contain unhashable values (like lists or other dicts)

It will trigger an unhashable error at runtime. Left it for users to handle.

Comment on lines +30 to +42
class SkipConnection(LayerBase):
    layer: LayerLike
    connection: Callable = jnp.add

    def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]:
        S = State()
        o, S["layer"] = self.layer(x, p["layer"], s["layer"])
        return self.connection(o, x), S


class Repeated(LayerBase):
    n: int
    layer: LayerLike
Copilot AI Nov 28, 2025

The new SkipConnection and Repeated layer classes lack docstrings, unlike other layer classes in the file (e.g., Branch has """1 -> N""" and Parallel has """N -> N"""). Adding docstrings would help users understand their purpose and usage.

Consider adding docstrings:

class SkipConnection(LayerBase):
    """Applies a layer and combines its output with the input using a connection function (default: addition)."""
    # ...

class Repeated(LayerBase):
    """Applies the same layer n times sequentially."""
    # ...

Comment on lines +69 to +100
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]:
        h = x
-       S = {}
+       S = State()
        for name, layer in zip(self.names, self.layers):
            h, S[name] = layer(h, p[name], s[name])
-       return h, State(**S)
+       return h, S


class Branch(NamedLayers):
    """1 -> N"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]:
        O = {}
-       S = {}
+       S = State()
        for name, layer in zip(self.names, self.layers):
            O[name], S[name] = layer(x, p[name], s[name])
        # ??? return dict?
-       return tuple(O.values()), State(**S)
+       return tuple(O.values()), S


class Parallel(NamedLayers):
    """N -> N"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

Copilot AI Nov 28, 2025

[nitpick] The __init__ methods in Chain, Branch, and Parallel classes appear to be redundant as they only call super().__init__(*args, **kwargs) with no additional logic. Since these classes inherit from NamedLayers which already defines a custom __init__, and Pydantic's BaseModel automatically handles initialization, these explicit __init__ methods are unnecessary.

Consider removing these redundant __init__ methods unless there's a specific reason for their presence.

Member Author

This is intended to avoid pylint complaints.

return p["w"].at[x].get(out_sharding=self.out_sharding), s


class Unembedding(Embedding):
Copilot AI Nov 28, 2025

[nitpick] The Unembedding class inherits from Embedding, which may cause confusion about the meaning of in_dim and out_dim. In the usage (line 140-141 in mini_transformer.py), Unembedding(in_dim=dim, out_dim=num_vocab) creates a weight matrix of shape (dim, num_vocab), but the parameter names suggest the opposite direction compared to how they're used in the einsum operation.

For Unembedding, the einsum "bld,dn->bln" expects:

  • Input: (batch, length, dim)
  • Weight: (dim, num_vocab)

This matches the current implementation, but the inherited parameter names in_dim=dim and out_dim=num_vocab are semantically backward from a typical linear layer perspective where in_dim would be the input feature dimension and out_dim would be the output dimension.

Consider either:

  1. Adding a docstring to clarify the parameter semantics
  2. Overriding __init__ to swap or rename parameters for clarity
  3. Not inheriting from Embedding and defining parameters explicitly
Suggested change
-class Unembedding(Embedding):
+class Unembedding(Embedding):
+    """
+    Unembedding layer: projects from hidden dimension to vocabulary logits.
+    Args:
+        dim (int): Input feature dimension (hidden size).
+        num_vocab (int): Output vocabulary size.
+        w_init (Initializer, optional): Initializer for the weight matrix.
+    The weight matrix has shape (dim, num_vocab), matching the einsum "bld,dn->bln".
+    """
+    def __init__(self, dim: int, num_vocab: int, w_init: Initializer = variance_scaling(1.0, "fan_in", "normal", out_axis=0)):
+        super().__init__(in_dim=dim, out_dim=num_vocab, w_init=w_init)

Member Author

@copilot create an issue to review it later.

dispatch = plum.Dispatcher(warn_redefinition=True)


class FrozenDict(RootModel[dict]):
Copilot AI Nov 28, 2025

This class implements __hash__, but does not implement __eq__.
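If value equality is wanted to go with the hash, a minimal sketch (assuming the FrozenDict in this PR keeps its dict payload in self.root, as the excerpt above suggests):

    def __eq__(self, other):
        # Sketch only: compare the underlying dicts, consistent with __hash__
        # hashing frozenset(self.root.items()).
        if isinstance(other, FrozenDict):
            return self.root == other.root
        if isinstance(other, dict):
            return self.root == other
        return NotImplemented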

Member Author

@copilot open a new pull request to apply changes based on this feedback

Contributor

Copilot AI commented Nov 28, 2025

@findmyway I've opened a new pull request, #2, to work on those changes. Once the pull request is ready, I'll request review from you.

Copilot AI review requested due to automatic review settings November 28, 2025 02:14
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 9 out of 10 changed files in this pull request and generated 13 comments.



Comment on lines +40 to +52
class Repeated(LayerBase):
    n: int
    layer: LayerLike

    def sublayers(self) -> dict:
        return {f"layer_{i}": self.layer for i in range(self.n)}

    def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]:
        S = State()
        o = x
        for i in range(self.n):
            o, S[f"layer_{i}"] = self.layer(o, p[f"layer_{i}"], s[f"layer_{i}"])
        return o, S
Copilot AI Nov 28, 2025

Missing test coverage: The new Repeated layer lacks test coverage. Consider adding tests to verify that it correctly applies the same layer n times in sequence, properly manages parameter and state namespacing (layer_0, layer_1, etc.), and that the sublayers() method returns the correct dictionary structure.
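A rough pytest sketch along these lines (the Scale layer is made up, and the Param/State constructor calls and import paths are guesses from the excerpts in this PR, not verified julax API):

import jax.numpy as jnp
from julax.core import LayerBase, Param, State  # assumed import path
from julax.layers import Repeated  # assumed import path


class Scale(LayerBase):
    """Toy layer for the test: multiplies its input by p["k"]."""

    def forward(self, x, p, s):
        return x * p["k"], s


def test_repeated_namespacing_and_chaining():
    block = Repeated(n=3, layer=Scale())
    assert set(block.sublayers()) == {"layer_0", "layer_1", "layer_2"}

    p = Param(**{f"layer_{i}": Param(k=2.0) for i in range(3)})
    s = State(**{f"layer_{i}": State() for i in range(3)})
    y, _ = block(jnp.array([1.0, -1.0]), p, s)
    # Each repetition doubles the input, so three repetitions give 8x.
    assert jnp.allclose(y, jnp.array([8.0, -8.0]))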

Comment on lines +207 to +209
class Unembedding(Embedding):
    def forward(self, x: Array, p: Param, s: State) -> tuple[Array, State]:
-       return p["w"][x], s
+       return jnp.einsum("bld,dn->bln", x, p["w"], out_sharding=self.out_sharding), s
Copilot AI Nov 28, 2025

Missing test coverage: The new Unembedding layer lacks test coverage. Consider adding tests to verify the einsum operation produces the correct output shape and values, and that it properly supports the out_sharding parameter for distributed training.

Comment on lines +14 to +33
class FrozenDict(RootModel[dict]):
    model_config = ConfigDict(frozen=True)

    def __getitem__(self, item):
        return self.root[item]

    def __iter__(self):
        return iter(self.root)

    def keys(self):
        return self.root.keys()

    def values(self):
        return self.root.values()

    def items(self):
        return self.root.items()

    def __hash__(self):
        return hash(frozenset(self.root.items()))
Copilot AI Nov 28, 2025

[nitpick] Missing __len__ method: The FrozenDict class implements dict-like methods but is missing __len__, which is commonly expected for dict-like objects. Consider adding def __len__(self): return len(self.root) for completeness.

    epsilon: float = 1e-5
    w_init: Initializer = ones
    b_init: Initializer = zeros
    compute_dtype: jnp.dtype | None = None
Copilot AI Nov 28, 2025

[nitpick] Missing documentation: The new compute_dtype parameter in LayerNorm lacks documentation. Consider adding a comment or docstring explaining its purpose (likely for controlling precision during normalization computation) and what happens when it's None.
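One possible wording for such a docstring (a sketch only; the behaviour described for None is the intended one, while the current handling of None is questioned in a later comment, and the field list simply mirrors the excerpt above):

class LayerNorm(LayerBase):
    """Layer normalization over the last axis.

    compute_dtype: dtype used for the jax.nn.standardize computation, e.g.
        jnp.float32 to normalize in full precision while keeping parameters
        in a lower-precision param_dtype. If None, normalization is meant to
        run in the input's own dtype.
    """

    epsilon: float = 1e-5
    w_init: Initializer = ones
    b_init: Initializer = zeros
    compute_dtype: jnp.dtype | None = None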

from julax.observers import default_observer


class FakeSource(grain.sources.RandomAccessDataSource):
Copilot AI Nov 28, 2025

[nitpick] Missing documentation: The FakeSource class lacks a docstring explaining its purpose as a simple synthetic data source for testing the transformer. Consider adding documentation to explain the repeating pattern in _data and why this particular sequence is used for training.

Suggested change
-class FakeSource(grain.sources.RandomAccessDataSource):
+class FakeSource(grain.sources.RandomAccessDataSource):
+    """
+    A simple synthetic data source for testing transformer models.
+    The data consists of a repeating pattern:
+        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 8, 7, 6, 5, 4, 3, 2, 1]
+    repeated 1024 times to form a long sequence.
+    This predictable, non-random sequence is useful for sequence modeling tasks,
+    allowing the model to learn to predict the next token in a known pattern.
+    """

Comment on lines 87 to 106
Rearrange(
    "B T (qkv N H) -> B T (qkv N) H",
    B=global_batch_size,
    T=seq_len,
    qkv=3,
    N=num_heads,
    H=head_dim,
),
partial(
    jnp.split, indices_or_sections=3, axis=2
),
lambda qkv: jax.nn.dot_product_attention(
    *qkv, is_causal=True
),
Rearrange(
    "B T N H -> B T (N H)",
    B=global_batch_size,
    T=seq_len,
    N=num_heads,
    H=head_dim,
Copilot AI Nov 28, 2025

Hardcoded batch size in Rearrange patterns: The transformer architecture has hardcoded B=global_batch_size in the Rearrange operations (lines 89, 103). This makes the model inflexible to different batch sizes at inference time. Consider using einops' ability to infer dimensions or making the model work with variable batch sizes by removing the explicit B constraint or using a more flexible approach.
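A quick plain-einops check that the batch axis does not need to be pinned; only axes split out of a packed axis need explicit sizes (this uses einops.rearrange directly, so whether julax's Rearrange wrapper accepts a partial sizes dict is an open question):

import jax.numpy as jnp
from einops import rearrange

num_heads, head_dim, seq_len = 2, 4, 5
x = jnp.zeros((3, seq_len, 3 * num_heads * head_dim))  # batch size 3, but any size works

# Split the packed qkv/head axes; B is inferred from the input shape.
qkv = rearrange(x, "B T (qkv N H) -> B T (qkv N) H", qkv=3, N=num_heads, H=head_dim)
assert qkv.shape == (3, seq_len, 3 * num_heads, head_dim)

# Merging axes needs no explicit sizes at all.
y = rearrange(jnp.zeros((3, seq_len, num_heads, head_dim)), "B T N H -> B T (N H)")
assert y.shape == (3, seq_len, num_heads * head_dim)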

Comment on lines +98 to +99
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
Copilot AI Nov 28, 2025

Redundant __init__ method: This __init__ method only calls super().__init__() with the same arguments and doesn't add any functionality. Since NamedLayers already defines the __init__ behavior, this override is unnecessary and can be removed.

Comment on lines +30 to +37
class SkipConnection(LayerBase):
    layer: LayerLike
    connection: Callable = jnp.add

    def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]:
        S = State()
        o, S["layer"] = self.layer(x, p["layer"], s["layer"])
        return self.connection(o, x), S
Copilot AI Nov 28, 2025

Missing test coverage: The new SkipConnection layer lacks test coverage. Consider adding tests to verify that it correctly combines the output of the wrapped layer with the input using the specified connection function, and that it properly handles state propagation.

Comment on lines +69 to +70
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
Copilot AI Nov 28, 2025

Redundant __init__ method: This __init__ method only calls super().__init__() with the same arguments and doesn't add any functionality. Since NamedLayers already defines the __init__ behavior, this override is unnecessary and can be removed.

            w_rng,
            (self.dim,),
            dtype=self.param_dtype,
            out_sharding=self.out_sharding,
Copilot AI Nov 28, 2025

Potential inconsistency: The w parameter uses out_sharding instead of param_sharding. This is inconsistent with the b parameter which correctly uses param_sharding. The weight parameter sharding should likely use self.param_sharding instead of self.out_sharding.

Suggested change
-    out_sharding=self.out_sharding,
+    out_sharding=(
+        None if self.param_sharding is None else P(self.param_sharding[-1])
+    ),

Contributor

Copilot AI commented Nov 28, 2025

@findmyway I've opened a new pull request, #3, to work on those changes. Once the pull request is ready, I'll request review from you.

Contributor

Copilot AI commented Nov 28, 2025

@findmyway I've opened a new pull request, #4, to work on those changes. Once the pull request is ready, I'll request review from you.

Copilot AI review requested due to automatic review settings November 28, 2025 13:07
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 10 out of 11 changed files in this pull request and generated 10 comments.



Comment on lines +36 to +37
        return hash(frozenset(self.root.items()))

Copilot AI Nov 28, 2025

The __hash__ implementation attempts to hash dictionary items directly using frozenset(self.root.items()). This will fail if any values in the dictionary are unhashable (e.g., lists, dicts, or mutable objects). Consider handling unhashable values or documenting this limitation. A safer approach might be to make values recursively frozen or to raise a more descriptive error.

Suggested change
-        return hash(frozenset(self.root.items()))
+        # Recursively freeze the dictionary to ensure all values are hashable
+        return hash(self._recursive_freeze(self.root))
+
+    @staticmethod
+    def _recursive_freeze(obj):
+        """Recursively convert obj to a hashable structure."""
+        if isinstance(obj, dict):
+            # Sort items to ensure consistent ordering
+            return tuple(sorted((k, FrozenDict._recursive_freeze(v)) for k, v in obj.items()))
+        elif isinstance(obj, (list, tuple)):
+            return tuple(FrozenDict._recursive_freeze(v) for v in obj)
+        elif isinstance(obj, set):
+            return frozenset(FrozenDict._recursive_freeze(v) for v in obj)
+        # Add more types here if needed (e.g., numpy arrays)
+        # If the object is already hashable, return as is
+        try:
+            hash(obj)
+        except TypeError:
+            raise TypeError(f"Unhashable type encountered in FrozenDict: {type(obj)}")
+        return obj

Comment on lines +229 to +244
    def forward(self, x: Array, p: Param, s: State) -> tuple[Array, State]:
        seq_length = x.shape[1]
        position = jnp.arange(seq_length, dtype=jnp.float32)[
            jnp.newaxis, :, jnp.newaxis, jnp.newaxis
        ]
        sinusoid_inp = position / s["timescale"]
        sin = jnp.sin(sinusoid_inp).astype(x.dtype)
        cos = jnp.cos(sinusoid_inp).astype(x.dtype)
        first_half, second_half = jnp.split(x, 2, axis=-1)
        first_part = first_half * cos - second_half * sin
        second_part = second_half * cos + first_half * sin
        if self.cast_as_fprop_dtype:
            first_part = first_part.astype(self.fprop_dtype)
            second_part = second_part.astype(self.fprop_dtype)
        x_out = jnp.concatenate((first_part, second_part), axis=-1)
        return x_out, s
Copilot AI Nov 28, 2025

[nitpick] The RoPE (Rotary Position Embedding) timescale computation in the state() method is performed once during initialization and stored. However, the forward() method recomputes the position array and sinusoids for every forward pass, even though these only depend on seq_length. For sequences of the same length (which is common in training), consider caching these computed values indexed by sequence length to avoid redundant computation.
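One way to go about it, sketched under the assumption that a maximum sequence length is acceptable as a new field (max_seq_len is hypothetical; everything else mirrors the excerpt, and note that under jit the per-call recomputation is traced away anyway, so this mainly helps eager execution):

    def state(self, rng: PRNG) -> State:
        half = self.embedding_dims // 2
        fraction = 2 * jnp.arange(0, half) / self.embedding_dims
        timescale = (
            self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction
        )
        if self.rope_linear_scaling_factor != 1.0:
            timescale = timescale * self.rope_linear_scaling_factor
        # Precompute sin/cos once for max_seq_len positions (hypothetical field).
        position = jnp.arange(self.max_seq_len, dtype=jnp.float32)[
            jnp.newaxis, :, jnp.newaxis, jnp.newaxis
        ]
        sinusoid_inp = position / timescale
        return State(sin=jnp.sin(sinusoid_inp), cos=jnp.cos(sinusoid_inp))

    def forward(self, x: Array, p: Param, s: State) -> tuple[Array, State]:
        seq_length = x.shape[1]
        sin = s["sin"][:, :seq_length].astype(x.dtype)
        cos = s["cos"][:, :seq_length].astype(x.dtype)
        first_half, second_half = jnp.split(x, 2, axis=-1)
        first_part = first_half * cos - second_half * sin
        second_part = second_half * cos + first_half * sin
        # ... casting and concatenation as in the current implementation ...
        return jnp.concatenate((first_part, second_part), axis=-1), s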

from jax.experimental import mesh_utils


def identity(x):
Copilot AI Nov 28, 2025

[nitpick] The identity function lacks a docstring explaining its purpose. While the implementation is trivial, documentation would clarify its intended use case (e.g., as a placeholder layer in the Parallel composition as seen in mini_transformer.py line 110).

Suggested change
-def identity(x):
+def identity(x):
+    """
+    Returns the input unchanged.
+    This function can be used as a placeholder layer, for example in model compositions
+    such as the `Parallel` composition, where a no-op function is required.
+    """

-        # TODO: cast dtype
-        return x * p["w"] + p["b"], s
+        x_std = jax.nn.standardize(
+            x.astype(self.compute_dtype), epsilon=self.epsilon
Copilot AI Nov 28, 2025

The x.astype(self.compute_dtype) call will fail when compute_dtype is None (the default value). This should check if compute_dtype is not None before casting, or use a conditional expression like x if self.compute_dtype is None else x.astype(self.compute_dtype).

Suggested change
-            x.astype(self.compute_dtype), epsilon=self.epsilon
+            x if self.compute_dtype is None else x.astype(self.compute_dtype), epsilon=self.epsilon

Comment on lines +85 to +86
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
Copilot AI Nov 28, 2025

The __init__ method added to Chain, Branch, and Parallel classes (lines 71-72, 85-86, 100-101) only calls super().__init__(*args, **kwargs) without any additional logic. Since these methods don't add any behavior, they are redundant and can be removed. The parent class NamedLayers already has an __init__ method that will be called automatically.

Comment on lines 40 to +49
    def save(self, p: Param, s: State):
-       self.checkpoint_manager.save(
-           s["trainer"]["step"],
-           args=ocp.args.Composite(
-               param=ocp.args.PyTreeSave(item=p),
-               state_trainer=ocp.args.PyTreeSave(item=s["trainer"]),
-               state_dataset_iter=grain.checkpoint.CheckpointSave(item=s["input"]),
-           ),
-       )
+       if self.checkpoint_manager:
+           self.checkpoint_manager.save(
+               s["step"],
+               args=ocp.args.Composite(
+                   param=ocp.args.PyTreeSave(item=p),
+                   state_trainer=ocp.args.PyTreeSave(item=s["trainer"]),
+                   state_dataset_iter=grain.checkpoint.CheckpointSave(item=s["input"]),
+               ),
+           )
Copilot AI Nov 28, 2025

The save() method doesn't include s["step"] in the checkpoint composite, and the restore() method on line 77 (not in changed region) doesn't include step in the returned State. This will cause a KeyError when observers try to access s["step"] after restoration. The step should be either: (1) saved in the composite as state_step=ocp.args.PyTreeSave(item=s["step"]) and restored, or (2) inferred from the checkpoint step number returned by the restore operation.
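A sketch of option (2), recovering the step from the checkpoint itself on restore; the Composite key names come from the save() excerpt above, while the restore() signature, latest_step() usage, and how the result is unpacked are assumptions rather than checked against this repo:

    def restore(self, p: Param, s: State) -> tuple[Param, State]:
        step = self.checkpoint_manager.latest_step()
        restored = self.checkpoint_manager.restore(
            step,
            args=ocp.args.Composite(
                param=ocp.args.PyTreeRestore(item=p),
                state_trainer=ocp.args.PyTreeRestore(item=s["trainer"]),
                state_dataset_iter=grain.checkpoint.CheckpointRestore(item=s["input"]),
            ),
        )
        # Re-insert the step so observers reading s["step"] keep working.
        return restored["param"], State(
            step=jnp.asarray(step),
            trainer=restored["state_trainer"],
            input=restored["state_dataset_iter"],
        )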

            w_rng,
            (self.dim,),
            dtype=self.param_dtype,
            out_sharding=self.out_sharding,
Copilot AI Nov 28, 2025

The w parameter initialization uses out_sharding=self.out_sharding (line 266), but should use out_sharding=self.param_sharding to be consistent with other layers like Linear (line 129) and with the b parameter initialization in this same method (line 272). The out_sharding field should be used for operation outputs, not parameter initialization.

Suggested change
-    out_sharding=self.out_sharding,
+    out_sharding=(
+        None if self.param_sharding is None else P(self.param_sharding[-1])
+    ),

Comment on lines +71 to +72
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
Copilot AI Nov 28, 2025

The __init__ method added to Chain, Branch, and Parallel classes (lines 71-72, 85-86, 100-101) only calls super().__init__(*args, **kwargs) without any additional logic. Since these methods don't add any behavior, they are redundant and can be removed. The parent class NamedLayers already has an __init__ method that will be called automatically.

Comment on lines +100 to +101
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
Copilot AI Nov 28, 2025

The __init__ method added to Chain, Branch, and Parallel classes (lines 71-72, 85-86, 100-101) only calls super().__init__(*args, **kwargs) without any additional logic. Since these methods don't add any behavior, they are redundant and can be removed. The parent class NamedLayers already has an __init__ method that will be called automatically.

@findmyway
Member Author

@copilot Merge it

Contributor

Copilot AI commented Nov 28, 2025

@findmyway I've opened a new pull request, #5, to work on those changes. Once the pull request is ready, I'll request review from you.

Copilot AI review requested due to automatic review settings November 28, 2025 14:19
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 11 out of 12 changed files in this pull request and generated 7 comments.



Comment on lines +280 to +283
        x_std = jax.nn.standardize(
            x.astype(self.compute_dtype), epsilon=self.epsilon
        ).astype(self.param_dtype)
        return x_std * p["w"] + p["b"], s
Copilot AI Nov 28, 2025

The dtype handling in LayerNorm has issues:

  1. When self.compute_dtype is None, .astype(None) will raise a TypeError
  2. The output is cast to self.param_dtype which is likely incorrect - the output should match the input dtype or remain in compute dtype

Consider this fix:

compute_dtype = self.compute_dtype if self.compute_dtype is not None else x.dtype
x_std = jax.nn.standardize(x.astype(compute_dtype), epsilon=self.epsilon)
return (x_std * p["w"] + p["b"]).astype(x.dtype), s
Suggested change
-        x_std = jax.nn.standardize(
-            x.astype(self.compute_dtype), epsilon=self.epsilon
-        ).astype(self.param_dtype)
-        return x_std * p["w"] + p["b"], s
+        compute_dtype = self.compute_dtype if self.compute_dtype is not None else x.dtype
+        x_std = jax.nn.standardize(
+            x.astype(compute_dtype), epsilon=self.epsilon
+        )
+        return (x_std * p["w"] + p["b"]).astype(x.dtype), s

            w_rng,
            (self.dim,),
            dtype=self.param_dtype,
            out_sharding=self.out_sharding,
Copilot AI Nov 28, 2025

Incorrect sharding parameter. The w_init call should use out_sharding=self.param_sharding instead of out_sharding=self.out_sharding. The pattern used in other layers (e.g., Linear at lines 125-130, Embedding at lines 197-202) shows that parameter initialization should use param_sharding, while out_sharding is used for computation outputs.

Suggested change
-    out_sharding=self.out_sharding,
+    out_sharding=self.param_sharding,

        return x_out, s


class Unembedding(Embedding):
Copilot AI Nov 28, 2025

[nitpick] The Unembedding class inherits from Embedding but uses a completely different forward operation (matrix multiplication vs. embedding lookup). This inheritance relationship is semantically incorrect - the two operations are fundamentally different.

Consider making Unembedding a standalone class that doesn't inherit from Embedding, or use composition instead of inheritance. The current design violates the Liskov Substitution Principle since an Unembedding cannot be used interchangeably with an Embedding.

Suggested change
-class Unembedding(Embedding):
+class Unembedding(LayerBase):
+    in_dim: int
+    out_dim: int
+    w_init: Initializer = variance_scaling(1.0, "fan_in", "normal", out_axis=0)
+
+    def param(self, rng: PRNG) -> Param:
+        return Param(
+            w=self.w_init(
+                rng,
+                (self.in_dim, self.out_dim),
+                dtype=self.param_dtype,
+                out_sharding=self.param_sharding,
+            )
+        )

Comment on lines +32 to +39
class SkipConnection(LayerBase):
    layer: LayerLike
    connection: Callable = jnp.add

    def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]:
        S = State()
        o, S["layer"] = self.layer(x, p["layer"], s["layer"])
        return self.connection(o, x), S
Copilot AI Nov 28, 2025

The new SkipConnection layer lacks test coverage. Consider adding tests to verify:

  • The skip connection correctly adds the input to the layer output
  • The state is properly propagated
  • Custom connection functions work as expected
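A rough pytest sketch covering the points above (Double is a made-up toy layer; the Param/State construction and import paths are assumptions based on the excerpts, not verified julax API):

import jax.numpy as jnp
from julax.core import LayerBase, Param, State  # assumed import path
from julax.layers import SkipConnection  # assumed import path


class Double(LayerBase):
    """Toy layer for the test: returns 2 * x."""

    def forward(self, x, p, s):
        return 2.0 * x, s


def test_skip_connection_default_add():
    skip = SkipConnection(layer=Double())
    x = jnp.array([1.0, 2.0])
    y, _ = skip(x, Param(layer=Param()), State(layer=State()))
    assert jnp.allclose(y, 3.0 * x)  # 2*x from the layer plus the skipped x


def test_skip_connection_custom_connection():
    skip = SkipConnection(layer=Double(), connection=jnp.multiply)
    x = jnp.array([1.0, 2.0])
    y, _ = skip(x, Param(layer=Param()), State(layer=State()))
    assert jnp.allclose(y, 2.0 * x * x)  # layer output multiplied by the input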

Comment on lines +42 to +54
class Repeated(LayerBase):
    n: int
    layer: LayerLike

    def sublayers(self) -> dict:
        return {f"layer_{i}": self.layer for i in range(self.n)}

    def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]:
        S = State()
        o = x
        for i in range(self.n):
            o, S[f"layer_{i}"] = self.layer(o, p[f"layer_{i}"], s[f"layer_{i}"])
        return o, S
Copilot AI Nov 28, 2025

The new Repeated layer lacks test coverage. Consider adding tests to verify:

  • The layer is correctly repeated n times
  • Parameters and states are properly namespaced (layer_0, layer_1, etc.)
  • The output of one iteration correctly feeds into the next

Comment on lines +209 to +245
class RotaryEmbedding(LayerBase):
    """Rotary Position Embedding."""

    # Adapted from https://github.com/AI-Hypercomputer/maxtext/blob/9204d6bbbf8bb19a05ebed72a55cfec687e0e044/src/MaxText/layers/embeddings.py#L271C11-L356C17
    embedding_dims: int
    min_timescale: int = 1
    max_timescale: int = 10000
    cast_as_fprop_dtype: bool = True
    fprop_dtype: Dtype = jnp.bfloat16
    rope_linear_scaling_factor: float = 1.0

    def state(self, rng: PRNG) -> State:
        half_embedding_dim = self.embedding_dims // 2
        fraction = 2 * jnp.arange(0, half_embedding_dim) / self.embedding_dims
        timescale = (
            self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction
        )
        if self.rope_linear_scaling_factor != 1.0:
            timescale = timescale * self.rope_linear_scaling_factor
        return State(timescale=timescale)

    def forward(self, x: Array, p: Param, s: State) -> tuple[Array, State]:
        seq_length = x.shape[1]
        position = jnp.arange(seq_length, dtype=jnp.float32)[
            jnp.newaxis, :, jnp.newaxis, jnp.newaxis
        ]
        sinusoid_inp = position / s["timescale"]
        sin = jnp.sin(sinusoid_inp).astype(x.dtype)
        cos = jnp.cos(sinusoid_inp).astype(x.dtype)
        first_half, second_half = jnp.split(x, 2, axis=-1)
        first_part = first_half * cos - second_half * sin
        second_part = second_half * cos + first_half * sin
        if self.cast_as_fprop_dtype:
            first_part = first_part.astype(self.fprop_dtype)
            second_part = second_part.astype(self.fprop_dtype)
        x_out = jnp.concatenate((first_part, second_part), axis=-1)
        return x_out, s
Copilot AI Nov 28, 2025

The new RotaryEmbedding layer lacks test coverage. Consider adding tests to verify:

  • The rotary position embeddings are computed correctly
  • The timescale calculation matches the expected behavior
  • The linear scaling factor works as intended
  • The output shape matches the input shape
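A rough pytest sketch for the timescale and output-shape points above (import paths, the empty Param, and the callable layer interface are assumptions from the excerpts, not verified julax API):

import jax
import jax.numpy as jnp
from julax.core import Param  # assumed import path
from julax.layers import RotaryEmbedding  # assumed import path


def test_rotary_embedding_timescale_and_shape():
    dim, batch, seq, heads = 8, 2, 5, 3
    rope = RotaryEmbedding(embedding_dims=dim, cast_as_fprop_dtype=False)
    s = rope.state(jax.random.PRNGKey(0))

    # timescale = min_timescale * (max_timescale / min_timescale) ** (2i / dim)
    expected = 10000.0 ** (2 * jnp.arange(dim // 2) / dim)
    assert jnp.allclose(s["timescale"], expected)

    x = jnp.ones((batch, seq, heads, dim), dtype=jnp.float32)
    y, _ = rope(x, Param(), s)
    assert y.shape == x.shape  # output shape matches input shape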

Comment on lines +248 to +250
class Unembedding(Embedding):
    def forward(self, x: Array, p: Param, s: State) -> tuple[Array, State]:
-       return p["w"][x], s
+       return jnp.einsum("bld,dn->bln", x, p["w"], out_sharding=self.out_sharding), s
Copilot AI Nov 28, 2025

The new Unembedding layer lacks test coverage. Consider adding tests to verify:

  • The matrix multiplication with the embedding weights is computed correctly
  • The output shape is correct (batch, length, vocab_size)
  • Sharding parameters work as expected

findmyway merged commit 2a9e5a7 into main Nov 28, 2025
findmyway deleted the dev branch November 28, 2025 14:48