feat: Matryoshka auxiliary loss #643

Merged
chanind merged 8 commits into main from matryoshka-aux-loss on Feb 13, 2026

Conversation


@chanind chanind commented Feb 10, 2026

This PR adds an optional new auxiliary loss for Matryoshka SAEs, where the auxiliary loss tries to revive dead features into the residual of the level they are at, rather than into the residual of the SAE as a whole.


claude bot commented Feb 10, 2026

PR Review: Matryoshka Auxiliary Loss

This PR adds an optional auxiliary loss for Matryoshka SAEs that encourages dead features to reconstruct the residual of their own level rather than the entire SAE. The implementation is well-structured with comprehensive test coverage. Here's my feedback:

Code Quality & Best Practices

Positive aspects:

  • Clean separation of concerns with the new _iterable_decode generator method
  • Good reuse of existing calculate_topk_aux_acts function (now properly exported from topk_sae.py)
  • Proper use of @override decorator
  • Configuration is backward-compatible with use_matryoshka_aux_loss defaulting to False
  • Documentation is clear about the tradeoff (better revival vs slower training)

Issues to address:

  1. Potential gradient flow issue (matryoshka_batchtopk_sae.py:175): The residual calculation uses .detach():

    residual = (sae_in - partial_sae_out).detach()

    This prevents gradients from flowing back through partial_sae_out to earlier levels. While this matches the base class behavior in topk_sae.py:458, it means the auxiliary loss won't push gradients back into the reconstructions produced by earlier levels. This may be intentional, but worth documenting why.

  2. Inefficient recomputation (matryoshka_batchtopk_sae.py:161): When use_matryoshka_aux_loss=True, the method recomputes acts = self.activation_fn(hidden_pre) even though these activations were already computed in the forward pass. Consider passing them as a parameter to avoid redundant computation.

  3. Memory inefficiency in aux loss (matryoshka_batchtopk_sae.py:187): The reconstruction adds the full b_dec bias for each level:

    recons = auxk_acts @ self.W_dec[prev_width:width] + self.b_dec

    However, b_dec is already included in partial_sae_out (since it's computed via _iterable_decode). This means the bias is being added multiple times unnecessarily. This should likely just be:

    recons = auxk_acts @ self.W_dec[prev_width:width]

    without the bias term, since we're computing the residual against partial_sae_out which already includes it (a small standalone check of this appears after this list).

  4. Type annotation cleanup (tests/helpers.py:578): The change from TrainingSAEConfigDict to dict[str, Any] is correct (since matryoshka config extends it), but the comment and variable names could be clearer about why this changed.
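
A small standalone check of point 3 above, using stand-in tensors rather than the actual sae_lens code: it only illustrates that once partial_sae_out carries b_dec, an aux reconstruction that matches the residual needs no additional bias term.

import torch

torch.manual_seed(0)
sae_in = torch.randn(4, 8)
b_dec = torch.randn(8)
level_out = torch.randn(4, 8)            # stand-in for this level's decode, bias excluded
partial_sae_out = level_out + b_dec      # the decode path adds b_dec once

residual = sae_in - partial_sae_out
perfect_recons = residual                # an ideal aux prediction; no extra bias needed
doubly_biased = perfect_recons + b_dec   # what adding b_dec again would produce

assert torch.allclose(partial_sae_out + perfect_recons, sae_in, atol=1e-6)
assert not torch.allclose(partial_sae_out + doubly_biased, sae_in)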

Test Coverage

Excellent test coverage:

  • Manual computation verification ✓
  • Tests that dead features target their level's residual (not full SAE residual) ✓
  • Gradient flow validation ✓
  • Flag toggle behavior ✓
  • Edge cases (no dead features, selective dead features, skipped levels) ✓

Test quality issues:

  1. test_matryoshka_aux_loss_manual_computation (line 421): This test is excellent for verifying correctness, but if the bias addition issue (point 3 above) is fixed, this test will need updating to match.

  2. test_matryoshka_aux_loss_gradients_flow_through_own_level_weights (line 149): The test verifies that gradients only flow to the appropriate level, which is good. However, it uses pytest.approx(0.0, abs=1e-7) which is overly strict and could cause flakiness. Consider using a slightly looser tolerance like abs=1e-6.

  3. All tests use rescale_acts_by_decoder_norm=False: Consider adding at least one test with rescale_acts_by_decoder_norm=True to ensure the new code path handles the normalization correctly in _iterable_decode.

Performance Considerations

  1. Training overhead: The _iterable_decode method is called multiple times per training step when both training_forward_pass and calculate_topk_aux_loss are invoked with use_matryoshka_aux_loss=True. The forward pass iterates once (line 110-120), and aux loss iterates again (line 164-190). Consider caching partial reconstructions if this becomes a bottleneck.

  2. Sparse tensor handling: Good use of sparse matmul when appropriate (line 96-100).

Security & Correctness

Critical issue:

The bias term addition in the aux loss reconstruction (point 3 above) appears incorrect and will lead to wrong gradient signals for dead feature revival. This needs verification and likely fixing.

Minor issues:

  1. Edge case: If matryoshka_widths has only one level (which triggers a warning), calculate_topk_aux_loss with use_matryoshka_aux_loss=True will still iterate over all levels. This is correct but could have a comment explaining the behavior.

Recommendations

Must fix:

  • Verify and fix the bias term addition in aux loss (item 3 in Code Quality section)

Should fix:

  • Pass acts to avoid recomputation in aux loss
  • Document why residual uses .detach()
  • Add test with rescale_acts_by_decoder_norm=True

Nice to have:

  • Consider caching partial reconstructions if performance becomes an issue
  • Loosen the gradient test tolerance slightly

Summary

This is a solid implementation with thorough testing. The main concern is the potential bias term error in the auxiliary loss reconstruction, which could affect the correctness of dead feature revival. Once that's verified/fixed, this should be ready to merge.


codecov bot commented Feb 10, 2026

Codecov Report

❌ Patch coverage is 96.29630% with 2 lines in your changes missing coverage. Please review.

Files with missing lines: sae_lens/saes/matryoshka_batchtopk_sae.py (patch 96.00%; 1 missing, 1 partial) ⚠️


Copilot AI review requested due to automatic review settings February 12, 2026 16:26
@chanind chanind force-pushed the matryoshka-aux-loss branch from 4e3eeaa to 424e190 on February 12, 2026 16:30
Copilot AI (Contributor) left a comment

Pull request overview

Adds an optional Matryoshka-specific auxiliary loss for BatchTopK SAEs, where dead features are trained to reconstruct the level-local residual rather than the full SAE residual.

Changes:

  • Introduces use_matryoshka_aux_loss config flag and Matryoshka-specific aux-loss implementation.
  • Exposes calculate_topk_aux_acts as a public helper and updates call sites.
  • Adds unit tests covering manual computation, residual targeting behavior, gradient flow, and flag-off parity.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 9 comments.

  • sae_lens/saes/matryoshka_batchtopk_sae.py: Adds iterable level decoding + new Matryoshka aux loss behind a config flag.
  • sae_lens/saes/topk_sae.py: Renames/exports aux-acts helper to reuse from Matryoshka implementation.
  • tests/saes/test_matryoshka_batchtopk_sae.py: Adds tests validating new aux-loss behavior + gradients + parity.
  • tests/helpers.py: Updates helper config typing and adds default for new flag.


)

# Reduce the scale of the loss if there are a small number of dead latents
scale = min(partial_num_dead / partial_k_aux, 1.0)
Copilot AI commented on Feb 12, 2026

The scaling factor is effectively always 1.0 because partial_k_aux = min(k_aux, partial_num_dead) makes partial_num_dead / partial_k_aux >= 1. In calculate_topk_aux_loss the scale is computed before clamping k_aux to num_dead. To match that intended behavior (and your tests' stated intent), compute scale using the unclamped k_aux (or compute scale before setting partial_k_aux).

Suggested change:
- scale = min(partial_num_dead / partial_k_aux, 1.0)
+ scale = min(partial_num_dead / k_aux, 1.0)

Comment on lines 208 to 209
# top k living latents
recons = auxk_acts @ self.W_dec[prev_width:width]
Copilot AI commented on Feb 12, 2026

The aux-loss reconstruction path is inconsistent with the Matryoshka decode path used to compute partial_sae_out (via _iterable_decode). _iterable_decode may rescale activations by decoder norm and applies run_time_activation_norm_fn_out (and starts from b_dec), but recons here does not. This makes the target residual and the predicted reconstruction live in different 'decode conventions', and also ignores rescale_acts_by_decoder_norm for the aux path. Consider constructing recons using the same conventions as _iterable_decode for a single level (including any act rescaling and output activation-norm steps, and whether b_dec should be included) so that (recons - residual) is meaningful and consistent across configs.

Suggested change:
- # top k living latents
- recons = auxk_acts @ self.W_dec[prev_width:width]
+ # top k living latents. Use the same decode conventions as the main
+ # Matryoshka path (activation rescaling, bias, and output norm).
+ scaled_auxk_acts = auxk_acts
+ if getattr(self.cfg, "rescale_acts_by_decoder_norm", False) and hasattr(
+     self, "decoder_norms"
+ ):
+     scaled_auxk_acts = scaled_auxk_acts * self.decoder_norms[prev_width:width]
+ recons = scaled_auxk_acts @ self.W_dec[prev_width:width]
+ if hasattr(self, "b_dec") and self.b_dec is not None:
+     recons = recons + self.b_dec
+ if hasattr(self, "run_time_activation_norm_fn_out") and (
+     self.run_time_activation_norm_fn_out is not None
+ ):
+     recons = self.run_time_activation_norm_fn_out(recons)

Comment on lines 85 to 86
inv_W_dec_norm = 1 / self.W_dec.norm(dim=-1)
if self.cfg.rescale_acts_by_decoder_norm:
Copilot AI commented on Feb 12, 2026

inv_W_dec_norm is computed unconditionally even when rescale_acts_by_decoder_norm is False. This adds an extra norm+division on every call, and can produce inf when rows of W_dec are zero (your tests set W_dec[0:3] = 0.0). Compute inv_W_dec_norm only inside the if, and consider clamping the norm with an epsilon to avoid inf when rescaling is enabled.

Suggested change:
- inv_W_dec_norm = 1 / self.W_dec.norm(dim=-1)
- if self.cfg.rescale_acts_by_decoder_norm:
+ if self.cfg.rescale_acts_by_decoder_norm:
+     # Clamp norm to avoid division by zero when some rows of W_dec are zero.
+     W_dec_norm = self.W_dec.norm(dim=-1)
+     inv_W_dec_norm = 1.0 / torch.clamp(W_dec_norm, min=1e-12)

Comment on lines +521 to 524
def calculate_topk_aux_acts(
k_aux: int,
hidden_pre: torch.Tensor,
dead_neuron_mask: torch.Tensor,
Copilot AI commented on Feb 12, 2026

Renaming _calculate_topk_aux_acts to calculate_topk_aux_acts is a breaking change for any internal/external callers that still import the underscored name. To preserve compatibility, consider keeping _calculate_topk_aux_acts as a thin alias to calculate_topk_aux_acts (optionally with a deprecation warning) for at least one release cycle.

partial_dead_mask_1 = dead_neuron_mask[0:3]
num_dead_1 = int(partial_dead_mask_1.sum()) # = 1
partial_k_aux_1 = min(k_aux, num_dead_1) # = 1
scale_1 = min(num_dead_1 / partial_k_aux_1, 1.0) # = 1.0
Copilot AI commented on Feb 12, 2026

This test claims it exercises scale < 1.0 (comment says 0.5), but the computed scale_1 is 1.0 due to using partial_k_aux_1 (already clamped). If the intended behavior is the same as calculate_topk_aux_loss (scale computed vs the original k_aux), then the test should compute scale_1 against k_aux (yielding 0.5 here) so it actually validates the scaling behavior.

Suggested change:
- scale_1 = min(num_dead_1 / partial_k_aux_1, 1.0)  # = 1.0
+ scale_1 = min(num_dead_1 / k_aux, 1.0)  # = 0.5

k_aux = d_in // 2 # = 2

# Level 1: features 0:3, 1 dead feature
partial_sae_out_1 = acts[:, 0:3] @ sae.W_dec.data[0:3] + sae.b_dec.data
Copilot AI commented on Feb 12, 2026

Using .data in tests bypasses autograd and can mask subtle issues if these tensors ever end up participating in graph construction. Prefer .detach() (or just use the parameter tensors directly when gradients are irrelevant) to make the intent explicit without relying on .data.

auxk_acts_1 = calculate_topk_aux_acts(
partial_k_aux_1, partial_hidden_pre_1, partial_dead_mask_1
)
recons_1 = auxk_acts_1 @ sae.W_dec.data[0:3]
Copilot AI commented on Feb 12, 2026

Using .data in tests bypasses autograd and can mask subtle issues if these tensors ever end up participating in graph construction. Prefer .detach() (or just use the parameter tensors directly when gradients are irrelevant) to make the intent explicit without relying on .data.

Comment on lines +471 to +473
acts[:, 0:3] @ sae.W_dec.data[0:3]
+ acts[:, 3:6] @ sae.W_dec.data[3:6]
+ sae.b_dec.data
Copilot AI commented on Feb 12, 2026

Using .data in tests bypasses autograd and can mask subtle issues if these tensors ever end up participating in graph construction. Prefer .detach() (or just use the parameter tensors directly when gradients are irrelevant) to make the intent explicit without relying on .data.

auxk_acts_2 = calculate_topk_aux_acts(
partial_k_aux_2, partial_hidden_pre_2, partial_dead_mask_2
)
recons_2 = auxk_acts_2 @ sae.W_dec.data[3:6]
Copilot AI commented on Feb 12, 2026

Using .data in tests bypasses autograd and can mask subtle issues if these tensors ever end up participating in graph construction. Prefer .detach() (or just use the parameter tensors directly when gradients are irrelevant) to make the intent explicit without relying on .data.


claude bot commented Feb 12, 2026

PR Review: Matryoshka Auxiliary Loss

This PR adds an optional auxiliary loss mechanism for Matryoshka SAEs where dead features are encouraged to reconstruct the residual of their own level rather than the entire SAE. Overall, the implementation looks solid with comprehensive tests. Here are my findings:

Strengths

  1. Excellent test coverage: The test suite is comprehensive and tests the actual behavior rather than superficial properties, following the repository guidelines perfectly.

  2. Well-structured tests: Tests verify core functionality including:

    • Manual computation verification
    • Level-specific residual usage (not full residual)
    • Gradient flow correctness
    • Proper handling of dead features at different levels
    • Compatibility with rescale_acts_by_decoder_norm
  3. Good documentation: The config parameter is well-documented explaining the tradeoff (better feature revival vs slower training).

  4. Backward compatibility: The feature is off by default and doesn't break existing behavior.

Code Quality Issues

1. Potential Performance Issue in calculate_matryoshka_aux_loss (matryoshka_batchtopk_sae.py:197-199)

The TODO comment is accurate - this recalculates SAE output for each level:

# TODO: find a way to implement this without needing to recalculate the SAE output for each level
# may need to wait for a refactor in the next release of sae_lens for a clean way to do this
for width, partial_sae_out in self._iterable_decode(
    feature_acts, include_outer_loss=True
):

However, _iterable_decode is already accumulative (line 107: decoded = inner_feature_acts @ self.W_dec[prev_width:width] + decoded), so the partial outputs are already being computed incrementally. The issue is that you're calling _iterable_decode twice: once in training_forward_pass (line 114) and again in calculate_matryoshka_aux_loss (line 197).

Suggestion: Consider caching the partial reconstructions from training_forward_pass to avoid redundant computation in the aux loss, or refactor to compute them once and pass them through.
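
As an illustrative sketch of this caching idea (stand-in tensors only, not the actual sae_lens class internals): the accumulative decode described above can collect the per-level partial reconstructions into a list once, so both the main output and the per-level aux loss can read from that list instead of iterating the decode twice.

import torch

def partial_reconstructions(
    feature_acts: torch.Tensor,  # (batch, d_sae)
    W_dec: torch.Tensor,         # (d_sae, d_in)
    b_dec: torch.Tensor,         # (d_in,)
    widths: list[int],           # cumulative matryoshka level widths, e.g. [3, 6]
) -> list[torch.Tensor]:
    decoded = b_dec.expand(feature_acts.shape[0], -1)
    outs: list[torch.Tensor] = []
    prev = 0
    for width in widths:
        # Same accumulative update as the decode loop quoted above.
        decoded = feature_acts[:, prev:width] @ W_dec[prev:width] + decoded
        outs.append(decoded)
        prev = width
    return outs  # outs[-1] is the full reconstruction; earlier entries feed the aux loss

partials = partial_reconstructions(
    torch.randn(4, 6), torch.randn(6, 8), torch.randn(8), widths=[3, 6]
)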

2. Inconsistent Scaling Logic Between Base and Matryoshka Aux Loss

In calculate_matryoshka_aux_loss (line 216):

scale = min(partial_num_dead / partial_k_aux, 1.0)

But the base class calculate_topk_aux_loss (topk_sae.py:488) does:

scale = min(num_dead / k_aux, 1.0)
k_aux = min(k_aux, num_dead)

The order is different. In the base class, k_aux is clamped AFTER computing scale. In the matryoshka version (line 206):

partial_k_aux = min(k_aux, partial_num_dead)

This happens BEFORE computing scale (line 216). This means when partial_num_dead < k_aux, the matryoshka version will always have scale = 1.0, while the base version would have scale < 1.0.

This is a bug. The scale should be computed before clamping k_aux. The test at line 444-448 tries to exercise scale < 1.0 but won't catch this because it sets up the scenario incorrectly (only 1 dead feature, but expects scale = 0.5 in the comment).

Fix: Swap lines 206 and 216:

scale = min(partial_num_dead / k_aux, 1.0)
partial_k_aux = min(k_aux, partial_num_dead)
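
A quick plain-Python check of the ordering issue (no sae_lens code involved), using the numbers from the test scenario mentioned above (k_aux = 2, one dead feature at the first level):

k_aux = 2
partial_num_dead = 1

# Reviewed (matryoshka) order: clamp first, so the ratio is always >= 1.
partial_k_aux = min(k_aux, partial_num_dead)                    # 1
scale_after_clamp = min(partial_num_dead / partial_k_aux, 1.0)  # 1.0

# Base-class order: scale against the original k_aux, then clamp.
scale_before_clamp = min(partial_num_dead / k_aux, 1.0)         # 0.5

assert scale_after_clamp == 1.0
assert scale_before_clamp == 0.5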

3. Test Assertion May Be Too Loose (test_matryoshka_batchtopk_sae.py:559)

assert matryoshka_loss.item() != pytest.approx(base_loss.item(), rel=0.1)

This test uses != with pytest.approx, which is checking if values are NOT approximately equal. The intent is good, but the implementation is checking if the difference is > 10%. This could pass even if the difference is small (e.g., 11%).

Suggestion: Be more explicit about the expected relationship:

assert matryoshka_loss.item() > base_loss.item() * 1.5  # or some meaningful threshold

Or add a comment explaining why 10% relative difference is the right threshold.

4. Missing Edge Case: What if aux_losses is Empty? (matryoshka_batchtopk_sae.py:228)

If all levels have 0 dead features (every iteration hits line 203 continue), then aux_losses will be empty and line 228 will fail:

stacked_losses = torch.stack(aux_losses)

This would raise RuntimeError: stack expects a non-empty TensorList. The early return at line 184 only handles the case where dead_neuron_mask is None or has sum 0, but not the case where dead features exist but are all in ranges that get skipped.

Fix: Check if aux_losses is empty before stacking:

if not aux_losses:
    return sae_out.new_tensor(0.0)
stacked_losses = torch.stack(aux_losses)

Minor Issues

  1. Unused import in tests (test_matryoshka_batchtopk_sae.py:1)
import copy

Only used once. This is fine, but worth noting.

  2. Type annotation improvement in helpers.py (line 578)
    Good fix changing from TrainingSAEConfigDict to dict[str, Any], but this appears to be a cleanup rather than a functional change related to the feature.

Testing Suggestions

  1. Missing test case: Test with normalize_activations enabled. The base calculate_topk_aux_loss raises an error if normalization is used (topk_sae.py:474-480). Does the matryoshka version handle this correctly? Should add a test.

  2. Statistical rigor: The gradient tests check for exact zero (line 578, 582, 585, 587) using pytest.approx(0.0, abs=1e-7). Consider if this is tight enough or if numerical precision could cause spurious failures. Given the CLAUDE.md guidance about tight bounds, this seems reasonable.

Summary

This is a well-thought-out feature with excellent tests. The main issues are:

  1. Critical: The scaling logic bug (item 2 above) - the scale is computed after clamping k_aux instead of before
  2. Important: Missing empty list check (item 4 above) - could cause runtime error
  3. Performance: Redundant computation noted in TODO (item 1 above) - acknowledged but worth addressing

The tests are exemplary in following the repository guidelines - they test actual correctness rather than superficial properties. Great work overall!

Recommendation: Fix items 2 and 4 before merging. Item 1 can be addressed in a future PR as noted in the TODO.

@decoderesearch decoderesearch deleted a comment from claude bot Feb 12, 2026

claude bot commented Feb 12, 2026

Code Review: Matryoshka Auxiliary Loss

Thanks for this PR! This adds an interesting new capability for Matryoshka SAEs. I've reviewed the changes and have the following feedback:

Strengths

  1. Well-tested: The test suite is comprehensive and tests the core logic thoroughly. Tests verify:

    • Manual computation matches the implementation
    • Level-specific residuals are used (not just full SAE residual)
    • Gradients flow through the correct level weights
    • The feature correctly falls back to base class behavior when disabled
    • Statistical properties with level skipping
  2. Clean implementation: The _iterable_decode generator pattern is elegant and makes the code reusable between training_forward_pass and calculate_matryoshka_aux_loss.

  3. Good documentation: The docstrings clearly explain the feature and tradeoffs.

  4. Backward compatibility: The use_matryoshka_aux_loss flag defaults to False, ensuring existing code continues to work.

Issues to Address

1. Performance concern with _iterable_decode (matryoshka_batchtopk_sae.py:86-108)

The _iterable_decode method modifies feature_acts in place when rescale_acts_by_decoder_norm=True:

if self.cfg.rescale_acts_by_decoder_norm:
    inv_W_dec_norm = 1 / self.W_dec.norm(dim=-1)
    feature_acts = feature_acts * inv_W_dec_norm  # <-- modifies input tensor

This is called from both training_forward_pass and calculate_matryoshka_aux_loss, both of which pass base_output.feature_acts. This means:

  • In training_forward_pass, the original feature_acts gets rescaled
  • Then in calculate_aux_loss, those already-rescaled activations get rescaled again

Fix: Either clone the tensor before modifying it, or use a local variable:

def _iterable_decode(
    self, feature_acts: torch.Tensor, include_outer_loss: bool = False
) -> Generator[tuple[int, torch.Tensor], None, None]:
    if self.cfg.rescale_acts_by_decoder_norm:
        inv_W_dec_norm = 1 / self.W_dec.norm(dim=-1)
        feature_acts = feature_acts * inv_W_dec_norm  # Creates new tensor due to broadcasting
    # ... rest of method

Actually, looking more carefully, the multiplication with broadcasting creates a new tensor, so this might be fine. But to be safe and clear, consider explicitly cloning or documenting this behavior.
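
A tiny standalone check of that claim (pure PyTorch, unrelated to the sae_lens code): out-of-place multiplication with broadcasting returns a new tensor and leaves the input untouched, whereas an in-place op would mutate the caller's tensor.

import torch

feature_acts = torch.randn(4, 8)
inv_W_dec_norm = torch.rand(8)

rescaled = feature_acts * inv_W_dec_norm   # out-of-place: new tensor, new storage
assert rescaled.data_ptr() != feature_acts.data_ptr()

mutated = feature_acts.clone()
mutated *= inv_W_dec_norm                  # in-place: this form would mutate the original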

2. Potential gradient issue in calculate_matryoshka_aux_loss (matryoshka_batchtopk_sae.py:210)

The residual computation uses .detach():

residual = (sae_in - partial_sae_out).detach()

This is correct for the standard aux loss (matching the base implementation). However, I notice that partial_sae_out is computed from _iterable_decode, which itself uses feature_acts that still has gradients. Since we .detach() the residual, gradients won't flow back through the reconstructed output to update the living features' weights—only through the aux reconstruction term. This seems intentional and matches the base class, so this is fine.

3. Empty losses list could cause error (matryoshka_batchtopk_sae.py:227)

If no levels have any dead features, aux_losses will be empty, and torch.stack(aux_losses) will fail:

aux_losses = []
for width, partial_sae_out in self._iterable_decode(...):
    # ...
    if partial_num_dead == 0:
        prev_width = width
        continue
    # ... append to aux_losses
stacked_losses = torch.stack(aux_losses)  # <-- fails if aux_losses is empty

Fix: Add a check before stacking:

if len(aux_losses) == 0:
    return sae_out.new_tensor(0.0)
stacked_losses = torch.stack(aux_losses)
return self.cfg.aux_loss_coefficient * stacked_losses.sum()

4. Type annotation inconsistency (tests/helpers.py:578)

The change from TrainingSAEConfigDict to dict[str, Any] is correct, but it's inconsistent with the pattern used in other builder functions. Looking at the code, this seems fine since you're adding a field (use_matryoshka_aux_loss) that's not in TrainingSAEConfigDict. This is actually a fix, not an issue.

Minor Suggestions

  1. Comment clarity (matryoshka_batchtopk_sae.py:189-194): The comment about rescaling could be more precise. Currently it says "just rescale the decoder weights once", but what it really does is normalize them once. Consider:

    # Normalize decoder weights once to avoid repeated computation across levels
  2. TODO comment (matryoshka_batchtopk_sae.py:195-196): The TODO mentions waiting for a refactor. Is this causing a significant performance issue? If so, could you quantify it in the comment? If not, consider if the TODO is necessary.

  3. Test naming: The test names are excellent and self-documenting. Good job following the CLAUDE.md guidelines.

  4. Unused method (matryoshka_batchtopk_sae.py:126-147): The _decode_matryoshka_level method appears to be unused after refactoring to use _iterable_decode. Consider removing it.

Performance Considerations

The new auxiliary loss requires calling _iterable_decode which recomputes the SAE output for each level. As noted in the TODO comment, this could be optimized by caching intermediate results. However, this only happens when use_matryoshka_aux_loss=True, which defaults to False, so it's an opt-in performance cost that's well-documented.

Test Coverage

The test coverage is excellent. The tests focus on verifying correct behavior rather than superficial checks, which aligns perfectly with the CLAUDE.md guidelines. In particular:

  • test_matryoshka_aux_loss_manual_computation: Validates the math by hand
  • test_matryoshka_aux_loss_uses_level_residual_not_full_residual: Tests the key behavioral difference
  • test_matryoshka_aux_loss_gradients_flow_through_own_level_weights: Verifies gradient flow correctness

Summary

This is a solid implementation with comprehensive tests. The main issue to fix is the potential empty list in torch.stack(). Consider also reviewing whether _decode_matryoshka_level should be removed and clarifying the tensor mutation behavior in _iterable_decode.

The feature design makes sense—encouraging dead features to revive at their own level rather than the full SAE output is a sensible architectural choice for Matryoshka SAEs.


claude bot commented Feb 12, 2026

Pull Request Review: Matryoshka Auxiliary Loss

Thank you for this contribution! I've reviewed the PR and have the following feedback:

Strengths

  1. Well-tested feature: The test coverage is excellent with 8 comprehensive tests covering edge cases, gradient flow, manual computation verification, and comparison with base class behavior.

  2. Clear documentation: The docstring additions explain the new use_matryoshka_aux_loss parameter well.

  3. Good code organization: The _iterable_decode helper method is a clean abstraction that's reused effectively.

  4. Thoughtful design: The per-level auxiliary loss approach is conceptually sound - encouraging dead features to reconstruct their own level's residual should indeed improve feature revival.

Issues & Suggestions

Critical Issues

  1. Potential bug in _iterable_decode with rescale_acts_by_decoder_norm (matryoshka_batchtopk_sae.py:89-92)

    The method modifies feature_acts in-place when rescaling:

    if self.cfg.rescale_acts_by_decoder_norm:
        inv_W_dec_norm = 1 / self.W_dec.norm(dim=-1)
        feature_acts = feature_acts * inv_W_dec_norm

    This is problematic because:

    • If feature_acts is reused by the caller after this method returns, it will have been mutated
    • This breaks functional programming principles and can lead to subtle bugs

    Fix: Clone before modifying:

    if self.cfg.rescale_acts_by_decoder_norm:
        inv_W_dec_norm = 1 / self.W_dec.norm(dim=-1)
        feature_acts = feature_acts * inv_W_dec_norm  # This creates a new tensor, but should be explicit

    Actually, multiplication creates a new tensor, so this may not be a bug. But to be safe and explicit, consider feature_acts = feature_acts.clone() * inv_W_dec_norm if you want to be certain.

  2. Empty aux_losses list causes error (matryoshka_batchtopk_sae.py:204)

    In calculate_matryoshka_aux_loss, if no level has dead features, aux_losses will be empty and torch.stack(aux_losses) will fail.

    Current code:

    if dead_neuron_mask is not None and int(dead_neuron_mask.sum()) > 0:
        # ... loop that may skip all levels if partial_num_dead == 0 for each
        stacked_losses = torch.stack(aux_losses)  # FAILS if aux_losses is empty
        return self.cfg.aux_loss_coefficient * stacked_losses.sum()
    return sae_out.new_tensor(0.0)

    Fix: Check if aux_losses is empty before stacking:

    if aux_losses:
        stacked_losses = torch.stack(aux_losses)
        return self.cfg.aux_loss_coefficient * stacked_losses.sum()
    return sae_out.new_tensor(0.0)

Performance Concerns

  1. Redundant computation in calculate_matryoshka_aux_loss (matryoshka_batchtopk_sae.py:174-176)

    The TODO comment on line 172-173 mentions this, but it's worth emphasizing: calling _iterable_decode with include_outer_loss=True recalculates all SAE outputs for each level, which is expensive. The partial SAE outputs are already computed in training_forward_pass.

    Suggestion: Consider caching these intermediate reconstructions during the forward pass, or restructure to avoid recomputation. This is a significant performance concern for training speed.

  2. Repeated norm computation (matryoshka_batchtopk_sae.py:91, 167-170)

    W_dec.norm(dim=-1) is computed in both _iterable_decode and calculate_matryoshka_aux_loss. If both are called in the same forward pass (which they are), this is wasteful.

    Suggestion: Compute once and reuse, or cache as a buffer if this norm doesn't change during training.
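
A minimal standalone sketch of "compute once and reuse", with stand-in tensors rather than the real module (whether caching as a buffer is safe depends on whether W_dec changes during training, as noted above):

import torch

W_dec = torch.randn(16, 8)
acts = torch.randn(4, 16)

# Compute the inverse decoder norms once per step...
inv_W_dec_norm = 1.0 / W_dec.norm(dim=-1).clamp_min(1e-12)

# ...then reuse the same tensor at both call sites instead of recomputing W_dec.norm().
rescaled_for_decode = acts * inv_W_dec_norm
rescaled_for_aux = acts * inv_W_dec_norm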

Code Quality

  1. Inconsistent return types in _iterable_decode (matryoshka_batchtopk_sae.py:88)

    The method yields (int, torch.Tensor) tuples, but the width is the cumulative width, not the level index. This is confusing. Consider yielding (level_idx, width, reconstruction) or making the semantics clearer in documentation.

  2. Type annotation could be more specific (matryoshka_batchtopk_sae.py:158)

    dead_neuron_mask: torch.Tensor | None - but the logic assumes it's a boolean tensor if not None. Consider using a more specific type annotation or adding a runtime check.

Test Quality

  1. Great tests overall, but one suggestion:

    In test_matryoshka_aux_loss_gradients_flow_through_own_level_weights (test_matryoshka_batchtopk_sae.py:142-191), the test uses pytest.approx(0.0, abs=1e-7) to verify zero gradients. This is correct per the guidelines. Good job following the conventions!

  2. Test test_matryoshka_aux_loss_manual_computation (test_matryoshka_batchtopk_sae.py:419-494)

    This is excellent - it manually computes the expected loss step-by-step and compares. This is exactly the kind of test that would catch bugs in the core logic.

Documentation

  1. Missing edge case documentation

    The docstring for use_matryoshka_aux_loss (matryoshka_batchtopk_sae.py:40-42) mentions it's "slower to train" but doesn't quantify or explain why. Adding "This requires computing partial reconstructions for each level during aux loss calculation" would help users understand the tradeoff.

Minor Issues

  1. Import organization (matryoshka_batchtopk_sae.py:13-17)

    The imports from topk_sae are now multiple. Consider whether these functions should be in a shared utilities module if they're being reused across SAE variants.

  2. Magic number (matryoshka_batchtopk_sae.py:162)

    k_aux = sae_in.shape[-1] // 2 - The division by 2 is unexplained. This matches the base class behavior, but a comment explaining why k_aux is half of d_in would improve readability.

Security/Safety

No security concerns identified. This is purely computational code with no external I/O or user input.

Summary

Recommendation: Request changes to fix the critical issue with empty aux_losses list (point 2). The performance concern (point 3) is acknowledged in the TODO but is significant for production use.

The test coverage is excellent and the feature is well-designed. Once the empty list bug is fixed, this would be ready to merge.

Severity Assessment:

  • Critical (must fix): Issue 2 (empty aux_losses)
  • High (should fix): Issue 3 (performance - but has TODO acknowledging it)
  • Medium (nice to have): Issues 1, 4, 9, 11
  • Low (optional): Issues 5, 6, 10

Great work overall! The test suite gives high confidence in the correctness of the implementation.

@chanind chanind merged commit f5c5110 into main Feb 13, 2026
5 checks passed
@chanind chanind deleted the matryoshka-aux-loss branch February 13, 2026 13:24