Questions
T-SAE paper: https://openreview.net/pdf?id=bojVI4l9Kn (confusingly, this is different from Temporal SAE)
I am working on developing an architecture similar to, but slightly different from, the one in the above paper. In the end, I would like to implement my architecture as well as T-SAE as a baseline. The T-SAE architecture itself doesn't seem too hard to implement; it's essentially just an InfoNCE loss between tokens.
My big issue is figuring out how to properly modify the activation store class, in a way that is both general enough to accommodate other new architectures and also functional for both my architecture and T-SAE.
My main idea for implementing the T-SAE activation store would be to return pairs of activations instead of single activations, so that during the training loop each pair can serve as a positive example, with the rest of the batch as negatives.
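To make the pairing concrete, here is a minimal sketch of what the training-loop side could look like: an InfoNCE loss over a batch of (anchor, positive) activation pairs, where every other row of the batch acts as a negative. This is my own illustration, not code from the paper or from SAELens; the function name and temperature value are placeholders.

```python
import torch
import torch.nn.functional as F

def info_nce_loss(anchor: torch.Tensor, positive: torch.Tensor,
                  temperature: float = 0.1) -> torch.Tensor:
    """InfoNCE over a batch of activation pairs (illustrative sketch).

    anchor, positive: [batch, d] tensors. Row i of `positive` is the
    positive example for row i of `anchor`; every other row in the
    batch serves as an in-batch negative.
    """
    anchor = F.normalize(anchor, dim=-1)
    positive = F.normalize(positive, dim=-1)
    # Similarity matrix: entry (i, j) compares anchor i with positive j.
    logits = anchor @ positive.T / temperature  # [batch, batch]
    # The correct "class" for anchor i is positive i (the diagonal).
    labels = torch.arange(anchor.shape[0], device=anchor.device)
    return F.cross_entropy(logits, labels)
```

If the activation store yielded a `[batch, 2, d]` tensor of pairs, the loop would just split it into `anchor = pairs[:, 0]` and `positive = pairs[:, 1]` before calling this.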
I have already done a preliminary implementation in my fork of SAELens here: https://github.com/xXCoolinXx/SAELens/blob/context-sae/sae_lens/training/activations_store.py
The main change is that I added a `context_processor` function argument to the activation store, which is called before the full-context activations are returned. For example, the `near_pairs` context processor can find pairs between different activations within a context. The issues I've had with this approach are that it doesn't work with concatenated batches, and that I've additionally had to pass in a mask to avoid padding tokens. It also gets a bit messy working with the paired activations in the downstream SAE.
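For reference, this is roughly the shape of interface I have in mind; it is a simplified sketch rather than the actual code in my fork, and the signature and `offset` parameter are hypothetical:

```python
import torch

def near_pairs(activations: torch.Tensor, mask: torch.Tensor,
               offset: int = 1) -> torch.Tensor:
    """Hypothetical context processor: pair each token's activation with
    the activation `offset` positions later in the same context.

    activations: [batch, seq_len, d_in] full-context activations
    mask: [batch, seq_len] bool, False at padding positions
    Returns a [n_pairs, 2, d_in] tensor of (anchor, positive) pairs.
    """
    anchors = activations[:, :-offset]
    positives = activations[:, offset:]
    # Keep a pair only if both positions are real (non-padding) tokens.
    valid = mask[:, :-offset] & mask[:, offset:]
    return torch.stack([anchors[valid], positives[valid]], dim=1)
```

The awkward part is exactly what I describe above: the mask has to be threaded through, and the `[n_pairs, 2, d_in]` output no longer matches the flat `[batch, d_in]` shape the rest of the store and the SAE expect.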
I would like to find a method that is general enough to cover this use case and also suitable for contributing back to the library. My question, then, is: what is the preferred approach for this?