Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 114 additions & 24 deletions sae_lens/saes/matryoshka_batchtopk_sae.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import warnings
from dataclasses import dataclass, field
from typing import Generator

import torch
from typing_extensions import override
Expand All @@ -9,7 +10,11 @@
BatchTopKTrainingSAEConfig,
)
from sae_lens.saes.sae import TrainStepInput, TrainStepOutput
from sae_lens.saes.topk_sae import _sparse_matmul_nd
from sae_lens.saes.topk_sae import (
_sparse_matmul_nd,
act_times_W_dec,
calculate_topk_aux_acts,
)


@dataclass
Expand All @@ -32,6 +37,9 @@ class MatryoshkaBatchTopKTrainingSAEConfig(BatchTopKTrainingSAEConfig):
topk_threshold_lr (float): Learning rate for updating the global topk threshold.
The threshold is updated using an exponential moving average of the minimum
positive activation value. Defaults to 0.01.
use_matryoshka_aux_loss (bool): Whether to encourage dead latents to reconstruct the error
of just their own level rather than the error of the entire SAE. This should result in
better feature revival, but is slower to train. Defaults to False.
aux_loss_coefficient (float): Coefficient for the auxiliary loss that encourages
dead neurons to learn useful features. Inherited from TopKTrainingSAEConfig.
Defaults to 1.0.
Expand All @@ -50,6 +58,7 @@ class MatryoshkaBatchTopKTrainingSAEConfig(BatchTopKTrainingSAEConfig):
"""

matryoshka_widths: list[int] = field(default_factory=list)
use_matryoshka_aux_loss: bool = False

@override
@classmethod
Expand All @@ -74,15 +83,37 @@ def __init__(
super().__init__(cfg, use_error_term)
_validate_matryoshka_config(cfg)

def _iterable_decode(
    self, feature_acts: torch.Tensor, include_outer_loss: bool = False
) -> Generator[tuple[int, torch.Tensor], None, None]:
    """Yield ``(width, reconstruction)`` pairs for successive matryoshka levels.

    The decode is incremental: each level's reconstruction reuses the running
    partial sum from the previous level, so level ``i`` only pays for the
    decoder rows added since level ``i - 1``. When ``include_outer_loss`` is
    False the final (full-width) level is dropped, since the base SAE's own
    loss already covers the full reconstruction.
    """
    if self.cfg.rescale_acts_by_decoder_norm:
        # Multiply by the reciprocal of the decoder norm: true division is
        # not supported for sparse tensors, but multiplication is.
        feature_acts = feature_acts * (1 / self.W_dec.norm(dim=-1))

    level_widths = self.cfg.matryoshka_widths
    if not include_outer_loss:
        level_widths = level_widths[:-1]

    running = self.b_dec
    start = 0
    for end in level_widths:
        level_acts = feature_acts[:, start:end]
        level_W_dec = self.W_dec[start:end]
        if level_acts.is_sparse:
            # Sparse activations need the dedicated sparse matmul helper.
            running = _sparse_matmul_nd(level_acts, level_W_dec) + running
        else:
            running = level_acts @ level_W_dec + running
        start = end
        yield end, self.run_time_activation_norm_fn_out(running)

@override
def training_forward_pass(self, step_input: TrainStepInput) -> TrainStepOutput:
base_output = super().training_forward_pass(step_input)
inv_W_dec_norm = 1 / self.W_dec.norm(dim=-1)
# the outer matryoshka level is the base SAE, so we don't need to add an extra loss for it
for width in self.cfg.matryoshka_widths[:-1]:
inner_reconstruction = self._decode_matryoshka_level(
base_output.feature_acts, width, inv_W_dec_norm
)
for width, inner_reconstruction in self._iterable_decode(
base_output.feature_acts, include_outer_loss=False
):
inner_mse_loss = (
self.mse_loss_fn(inner_reconstruction, step_input.sae_in)
.sum(dim=-1)
Expand All @@ -92,28 +123,87 @@ def training_forward_pass(self, step_input: TrainStepInput) -> TrainStepOutput:
base_output.loss = base_output.loss + inner_mse_loss
return base_output

def _decode_matryoshka_level(
@override
def calculate_aux_loss(
self,
step_input: TrainStepInput,
feature_acts: torch.Tensor,
width: int,
inv_W_dec_norm: torch.Tensor,
) -> torch.Tensor:
"""
Decodes feature activations back into input space for a matryoshka level
"""
inner_feature_acts = feature_acts[:, :width]
# Handle sparse tensors using efficient sparse matrix multiplication
if self.cfg.rescale_acts_by_decoder_norm:
# need to multiply by the inverse of the norm because division is illegal with sparse tensors
inner_feature_acts = inner_feature_acts * inv_W_dec_norm[:width]
if inner_feature_acts.is_sparse:
sae_out_pre = (
_sparse_matmul_nd(inner_feature_acts, self.W_dec[:width]) + self.b_dec
hidden_pre: torch.Tensor,
sae_out: torch.Tensor,
) -> dict[str, torch.Tensor]:
# Calculate the auxiliary loss for dead neurons
if self.cfg.use_matryoshka_aux_loss:
aux_loss = self.calculate_matryoshka_aux_loss(
sae_in=step_input.sae_in,
sae_out=sae_out,
feature_acts=feature_acts,
hidden_pre=hidden_pre,
dead_neuron_mask=step_input.dead_neuron_mask,
)
else:
sae_out_pre = inner_feature_acts @ self.W_dec[:width] + self.b_dec
sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
return self.reshape_fn_out(sae_out_pre, self.d_head)
aux_loss = self.calculate_topk_aux_loss(
sae_in=step_input.sae_in,
sae_out=sae_out,
hidden_pre=hidden_pre,
dead_neuron_mask=step_input.dead_neuron_mask,
)
return {"auxiliary_reconstruction_loss": aux_loss}

def calculate_matryoshka_aux_loss(
    self,
    sae_in: torch.Tensor,
    sae_out: torch.Tensor,
    feature_acts: torch.Tensor,
    hidden_pre: torch.Tensor,
    dead_neuron_mask: torch.Tensor | None,
) -> torch.Tensor:
    """Auxiliary dead-latent loss computed per matryoshka level.

    For each matryoshka slice of the SAE, the dead latents in that slice are
    encouraged to reconstruct the residual error of that level's own partial
    reconstruction (rather than the whole SAE's error, as in the plain
    top-k aux loss).

    Args:
        sae_in: Input activations being reconstructed.
        sae_out: Full SAE output; only used here to build the zero tensor
            returned when there are no dead neurons.
        feature_acts: Post-activation feature values (may be sparse).
        hidden_pre: Pre-activation values, sliced per level for the aux top-k.
        dead_neuron_mask: Boolean mask over latents, or None. If None or
            all-False, the loss is zero.

    Returns:
        Scalar loss tensor: ``aux_loss_coefficient`` times the sum of the
        per-level auxiliary losses (0.0 if nothing is dead).
    """
    # calculate a separate aux loss for each new matryoshka portion of the SAE
    if dead_neuron_mask is not None and int(dead_neuron_mask.sum()) > 0:
        # Aux top-k budget: roughly half the input dimension (matches the
        # base top-k aux-loss convention).
        k_aux = sae_in.shape[-1] // 2
        prev_width = 0
        aux_losses = []

        # Normalize decoder weights once to avoid repeated computation across levels
        scaled_W_dec = (
            self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
            if self.cfg.rescale_acts_by_decoder_norm
            else self.W_dec
        )
        # TODO: find a way to implement this without needing to recalculate the SAE output for each level
        # may need to wait for a refactor in the next release of sae_lens for a clean way to do this
        for width, partial_sae_out in self._iterable_decode(
            feature_acts, include_outer_loss=True
        ):
            # Only the latents newly added at this level participate.
            partial_dead_neuron_mask = dead_neuron_mask[prev_width:width]
            partial_num_dead = int(partial_dead_neuron_mask.sum())
            if partial_num_dead == 0:
                # No dead latents in this slice: skip, but keep the cursor moving.
                prev_width = width
                continue

            # Reduce the scale of the loss if there are a small number of dead latents
            scale = min(partial_num_dead / k_aux, 1.0)
            partial_k_aux = min(k_aux, partial_num_dead)
            partial_hidden_pre = hidden_pre[:, prev_width:width]
            # Detach so the aux loss only trains the dead latents, not the
            # level's reconstruction path.
            residual = (sae_in - partial_sae_out).detach()
            auxk_acts = calculate_topk_aux_acts(
                k_aux=partial_k_aux,
                hidden_pre=partial_hidden_pre,
                dead_neuron_mask=partial_dead_neuron_mask,
            )

            # Encourage the top ~50% of dead latents to predict the residual of the
            # top k living latents
            recons = act_times_W_dec(
                auxk_acts,
                scaled_W_dec[prev_width:width],
                # Already rescaled above (scaled_W_dec), so don't rescale again.
                rescale_acts_by_decoder_norm=False,
            )
            auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean()
            aux_losses.append(scale * auxk_loss)
            prev_width = width
        stacked_losses = torch.stack(aux_losses)
        return self.cfg.aux_loss_coefficient * stacked_losses.sum()
    # No dead neurons: zero loss on the same device/dtype as the SAE output.
    return sae_out.new_tensor(0.0)


def _validate_matryoshka_config(cfg: MatryoshkaBatchTopKTrainingSAEConfig) -> None:
Expand Down
12 changes: 6 additions & 6 deletions sae_lens/saes/topk_sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def _sparse_matmul_nd(
return result_2d.view(result_shape)


def _act_times_W_dec(
def act_times_W_dec(
feature_acts: torch.Tensor,
W_dec: torch.Tensor,
rescale_acts_by_decoder_norm: bool,
Expand Down Expand Up @@ -273,7 +273,7 @@ def decode(
and optional head reshaping.
"""
sae_out_pre = (
_act_times_W_dec(
act_times_W_dec(
feature_acts, self.W_dec, self.cfg.rescale_acts_by_decoder_norm
)
+ self.b_dec
Expand Down Expand Up @@ -392,7 +392,7 @@ def decode(
applying optional finetuning scale, hooking, out normalization, etc.
"""
sae_out_pre = (
_act_times_W_dec(
act_times_W_dec(
feature_acts, self.W_dec, self.cfg.rescale_acts_by_decoder_norm
)
+ self.b_dec
Expand Down Expand Up @@ -488,7 +488,7 @@ def calculate_topk_aux_loss(
scale = min(num_dead / k_aux, 1.0)
k_aux = min(k_aux, num_dead)

auxk_acts = _calculate_topk_aux_acts(
auxk_acts = calculate_topk_aux_acts(
k_aux=k_aux,
hidden_pre=hidden_pre,
dead_neuron_mask=dead_neuron_mask,
Expand All @@ -497,7 +497,7 @@ def calculate_topk_aux_loss(
# Encourage the top ~50% of dead latents to predict the residual of the
# top k living latents. Per the paper (Appendix A.2), the reconstruction
# is ê = W_dec @ z (no bias), since b_dec is already in the residual.
recons = _act_times_W_dec(
recons = act_times_W_dec(
auxk_acts, self.W_dec, self.cfg.rescale_acts_by_decoder_norm
)
# Apply the same reshaping as decode() so recons matches the residual's shape
Expand All @@ -518,7 +518,7 @@ def process_state_dict_for_saving_inference(
)


def _calculate_topk_aux_acts(
def calculate_topk_aux_acts(
k_aux: int,
hidden_pre: torch.Tensor,
dead_neuron_mask: torch.Tensor,
Comment on lines +521 to 524
Copy link

Copilot AI Feb 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renaming _calculate_topk_aux_acts to calculate_topk_aux_acts is a breaking change for any internal/external callers that still import the underscored name. To preserve compatibility, consider keeping _calculate_topk_aux_acts as a thin alias to calculate_topk_aux_acts (optionally with a deprecation warning) for at least one release cycle.

Copilot uses AI. Check for mistakes.
Expand Down
7 changes: 3 additions & 4 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ def build_matryoshka_batchtopk_runner_cfg(
**kwargs: Any,
) -> LanguageModelSAERunnerConfig[MatryoshkaBatchTopKTrainingSAEConfig]:
"""Helper to create a mock instance for Matryoshka BatchTopK SAE."""
default_sae_config: TrainingSAEConfigDict = {
default_sae_config: dict[str, Any] = {
"matryoshka_widths": [10, kwargs.get("d_sae", 20)],
"d_in": 64,
"d_sae": 256,
Expand All @@ -586,18 +586,17 @@ def build_matryoshka_batchtopk_runner_cfg(
"apply_b_dec_to_input": False,
"k": 10,
"topk_threshold_lr": 0.02,
"use_matryoshka_aux_loss": False,
}
# Ensure activation_fn_kwargs has k if k is overridden
temp_sae_overrides = {
k: v for k, v in kwargs.items() if k in TrainingSAEConfigDict.__annotations__
}
temp_sae_config = {**default_sae_config, **temp_sae_overrides}
# Update the default config *before* passing it to _build_runner_config
final_default_sae_config = cast(dict[str, Any], temp_sae_config)

runner_cfg = _build_runner_config(
MatryoshkaBatchTopKTrainingSAEConfig,
final_default_sae_config,
temp_sae_config,
**kwargs,
)
_update_sae_metadata(runner_cfg)
Expand Down
Loading
Loading