Conversation
PR Review: Matryoshka Auxiliary Loss

This PR adds an optional auxiliary loss for Matryoshka SAEs that encourages dead features to reconstruct the residual of their own level rather than the entire SAE. The implementation is well-structured with comprehensive test coverage. Here's my feedback:

Code Quality & Best Practices

Positive aspects:
Issues to address:
Test Coverage

Excellent test coverage:
Test quality issues:
Performance Considerations
Security & Correctness

Critical issue: The bias term addition in the aux loss reconstruction (point 3 above) appears incorrect and will lead to wrong gradient signals for dead feature revival. This needs verification and likely fixing.

Minor issues:
Recommendations

Must fix:
Should fix:
Nice to have:
Summary

This is a solid implementation with thorough testing. The main concern is the potential bias term error in the auxiliary loss reconstruction, which could affect the correctness of dead feature revival. Once that's verified/fixed, this should be ready to merge.
4e3eeaa to 424e190 (Compare)
Pull request overview
Adds an optional Matryoshka-specific auxiliary loss for BatchTopK SAEs, where dead features are trained to reconstruct the level-local residual rather than the full SAE residual.
Changes:
- Introduces `use_matryoshka_aux_loss` config flag and Matryoshka-specific aux-loss implementation.
- Exposes `calculate_topk_aux_acts` as a public helper and updates call sites.
- Adds unit tests covering manual computation, residual targeting behavior, gradient flow, and flag-off parity.
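For context, a minimal runnable sketch of the change in loss target; the tensor names here are placeholders, not the PR's actual variables:

```python
import torch

# Hedged sketch of the two auxiliary-loss targets for dead features; all
# tensors here are stand-ins for the real activations and reconstructions.
x = torch.randn(4, 16)                # input activations
sae_out = torch.randn(4, 16)          # full SAE reconstruction (all levels)
partial_sae_out = torch.randn(4, 16)  # reconstruction from this level's features only

standard_aux_target = x - sae_out            # full-SAE residual (existing TopK aux loss)
matryoshka_aux_target = x - partial_sae_out  # level-local residual (with the new flag)
```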
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| sae_lens/saes/matryoshka_batchtopk_sae.py | Adds iterable level decoding + new Matryoshka aux loss behind a config flag. |
| sae_lens/saes/topk_sae.py | Renames/exports aux-acts helper to reuse from Matryoshka implementation. |
| tests/saes/test_matryoshka_batchtopk_sae.py | Adds tests validating new aux-loss behavior + gradients + parity. |
| tests/helpers.py | Updates helper config typing and adds default for new flag. |
```python
)

# Reduce the scale of the loss if there are a small number of dead latents
scale = min(partial_num_dead / partial_k_aux, 1.0)
```
The scaling factor is effectively always 1.0 because partial_k_aux = min(k_aux, partial_num_dead) makes partial_num_dead / partial_k_aux >= 1. In calculate_topk_aux_loss the scale is computed before clamping k_aux to num_dead. To match that intended behavior (and your tests' stated intent), compute scale using the unclamped k_aux (or compute scale before setting partial_k_aux).
Suggested change:

```diff
-scale = min(partial_num_dead / partial_k_aux, 1.0)
+scale = min(partial_num_dead / k_aux, 1.0)
```
```python
# top k living latents
recons = auxk_acts @ self.W_dec[prev_width:width]
```
The aux-loss reconstruction path is inconsistent with the Matryoshka decode path used to compute partial_sae_out (via _iterable_decode). _iterable_decode may rescale activations by decoder norm and applies run_time_activation_norm_fn_out (and starts from b_dec), but recons here does not. This makes the target residual and the predicted reconstruction live in different 'decode conventions', and also ignores rescale_acts_by_decoder_norm for the aux path. Consider constructing recons using the same conventions as _iterable_decode for a single level (including any act rescaling and output activation-norm steps, and whether b_dec should be included) so that (recons - residual) is meaningful and consistent across configs.
Suggested change:

```diff
-# top k living latents
-recons = auxk_acts @ self.W_dec[prev_width:width]
+# top k living latents. Use the same decode conventions as the main
+# Matryoshka path (activation rescaling, bias, and output norm).
+scaled_auxk_acts = auxk_acts
+if getattr(self.cfg, "rescale_acts_by_decoder_norm", False) and hasattr(
+    self, "decoder_norms"
+):
+    scaled_auxk_acts = scaled_auxk_acts * self.decoder_norms[
+        prev_width:width
+    ]
+recons = scaled_auxk_acts @ self.W_dec[prev_width:width]
+if hasattr(self, "b_dec") and self.b_dec is not None:
+    recons = recons + self.b_dec
+if hasattr(self, "run_time_activation_norm_fn_out") and (
+    self.run_time_activation_norm_fn_out is not None
+):
+    recons = self.run_time_activation_norm_fn_out(recons)
```
```python
inv_W_dec_norm = 1 / self.W_dec.norm(dim=-1)
if self.cfg.rescale_acts_by_decoder_norm:
```
inv_W_dec_norm is computed unconditionally even when rescale_acts_by_decoder_norm is False. This adds an extra norm+division on every call, and can produce inf when rows of W_dec are zero (your tests set W_dec[0:3] = 0.0). Compute inv_W_dec_norm only inside the if, and consider clamping the norm with an epsilon to avoid inf when rescaling is enabled.
Suggested change:

```diff
-inv_W_dec_norm = 1 / self.W_dec.norm(dim=-1)
-if self.cfg.rescale_acts_by_decoder_norm:
+if self.cfg.rescale_acts_by_decoder_norm:
+    # Clamp norm to avoid division by zero when some rows of W_dec are zero.
+    W_dec_norm = self.W_dec.norm(dim=-1)
+    inv_W_dec_norm = 1.0 / torch.clamp(W_dec_norm, min=1e-12)
```
```python
def calculate_topk_aux_acts(
    k_aux: int,
    hidden_pre: torch.Tensor,
    dead_neuron_mask: torch.Tensor,
```
Renaming _calculate_topk_aux_acts to calculate_topk_aux_acts is a breaking change for any internal/external callers that still import the underscored name. To preserve compatibility, consider keeping _calculate_topk_aux_acts as a thin alias to calculate_topk_aux_acts (optionally with a deprecation warning) for at least one release cycle.
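A minimal sketch of what such an alias could look like; the exact deprecation message and placement are left to the author's judgment:

```python
import warnings


def _calculate_topk_aux_acts(*args, **kwargs):
    """Deprecated alias for calculate_topk_aux_acts; kept for backwards compatibility."""
    warnings.warn(
        "_calculate_topk_aux_acts is deprecated; use calculate_topk_aux_acts instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return calculate_topk_aux_acts(*args, **kwargs)
```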
```python
partial_dead_mask_1 = dead_neuron_mask[0:3]
num_dead_1 = int(partial_dead_mask_1.sum())  # = 1
partial_k_aux_1 = min(k_aux, num_dead_1)  # = 1
scale_1 = min(num_dead_1 / partial_k_aux_1, 1.0)  # = 1.0
```
This test claims it exercises scale < 1.0 (comment says 0.5), but the computed scale_1 is 1.0 due to using partial_k_aux_1 (already clamped). If the intended behavior is the same as calculate_topk_aux_loss (scale computed vs the original k_aux), then the test should compute scale_1 against k_aux (yielding 0.5 here) so it actually validates the scaling behavior.
Suggested change:

```diff
-scale_1 = min(num_dead_1 / partial_k_aux_1, 1.0)  # = 1.0
+scale_1 = min(num_dead_1 / k_aux, 1.0)  # = 0.5
```
```python
k_aux = d_in // 2  # = 2

# Level 1: features 0:3, 1 dead feature
partial_sae_out_1 = acts[:, 0:3] @ sae.W_dec.data[0:3] + sae.b_dec.data
```
Using .data in tests bypasses autograd and can mask subtle issues if these tensors ever end up participating in graph construction. Prefer .detach() (or just use the parameter tensors directly when gradients are irrelevant) to make the intent explicit without relying on .data.
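For example, a sketch of the same line using `.detach()` instead, keeping the variable names from the test above:

```python
# .detach() makes the "no gradient flow through these parameters" intent
# explicit and avoids the autograd pitfalls of .data.
partial_sae_out_1 = acts[:, 0:3] @ sae.W_dec.detach()[0:3] + sae.b_dec.detach()
```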
```python
auxk_acts_1 = calculate_topk_aux_acts(
    partial_k_aux_1, partial_hidden_pre_1, partial_dead_mask_1
)
recons_1 = auxk_acts_1 @ sae.W_dec.data[0:3]
```
Using .data in tests bypasses autograd and can mask subtle issues if these tensors ever end up participating in graph construction. Prefer .detach() (or just use the parameter tensors directly when gradients are irrelevant) to make the intent explicit without relying on .data.
```python
acts[:, 0:3] @ sae.W_dec.data[0:3]
+ acts[:, 3:6] @ sae.W_dec.data[3:6]
+ sae.b_dec.data
```
Using .data in tests bypasses autograd and can mask subtle issues if these tensors ever end up participating in graph construction. Prefer .detach() (or just use the parameter tensors directly when gradients are irrelevant) to make the intent explicit without relying on .data.
```python
auxk_acts_2 = calculate_topk_aux_acts(
    partial_k_aux_2, partial_hidden_pre_2, partial_dead_mask_2
)
recons_2 = auxk_acts_2 @ sae.W_dec.data[3:6]
```
Using .data in tests bypasses autograd and can mask subtle issues if these tensors ever end up participating in graph construction. Prefer .detach() (or just use the parameter tensors directly when gradients are irrelevant) to make the intent explicit without relying on .data.
PR Review: Matryoshka Auxiliary Loss

This PR adds an optional auxiliary loss mechanism for Matryoshka SAEs where dead features are encouraged to reconstruct the residual of their own level rather than the entire SAE. Overall, the implementation looks solid with comprehensive tests. Here are my findings:

Strengths
Code Quality Issues

1. Potential Performance Issue in
Code Review: Matryoshka Auxiliary Loss

Thanks for this PR! This adds an interesting new capability for Matryoshka SAEs. I've reviewed the changes and have the following feedback:

Strengths
Issues to Address

1. Performance concern with
Pull Request Review: Matryoshka Auxiliary Loss

Thank you for this contribution! I've reviewed the PR and have the following feedback:

Strengths
Issues & Suggestions

Critical Issues
Performance Concerns
Code Quality
Test Quality
Documentation
Minor Issues
Security/Safety

No security concerns identified. This is purely computational code with no external I/O or user input.

Summary

Recommendation: Request changes to fix the critical issue with the empty list handling. The test coverage is excellent and the feature is well-designed. Once the empty list bug is fixed, this would be ready to merge.

Severity Assessment:
Great work overall! The test suite gives high confidence in the correctness of the implementation.
This PR adds an optional new auxiliary loss for Matryoshka SAEs, where the auxiliary loss tries to revive dead features by training them to reconstruct the residual of their own level, rather than the residual of the SAE as a whole.
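As a rough usage sketch: only the `use_matryoshka_aux_loss` flag comes from this PR; the config class name and remaining fields below are placeholders, not taken from the diff:

```python
# Hypothetical config snippet: enable the new Matryoshka-specific aux loss.
cfg = MatryoshkaBatchTopKTrainingSAEConfig(  # assumed class name, based on the file path
    use_matryoshka_aux_loss=True,  # dead features target their own level's residual
    # ... remaining Matryoshka / BatchTopK config fields as usual ...
)
```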