Note: This is an initial release of OpenSAE. We are actively working on improving the documentation and adding more features. Please stay tuned for updates.
This project is maintained with setuptools, so you can install it directly via pip. It requires Python 3.12 or higher.
git clone [email protected]:THU-KEG/OpenSAE.git
cd OpenSAE
pip install -e .
Additionally, we provide a Docker image for the project. You can pull the image by running the following command:
docker pull transirius/sae:latest
docker run --gpus all \
-it --rm -d \
--name sae \
-v {SAE CHECKPOINTS DIR}:/CHECKPOINTS \
-v {TRAINING DATA DIR}:/DATA \
-v {LLM CHECKPOINTS DIR}:/MODELS \
transirius/sae:latest
The Docker image is built on top of the nvcr.io/nvidia/pytorch:24.02-py3 image.
We inject a miniconda environment with the required dependencies into the image.
We also release OpenSAE, a large-scale pre-trained SAE for LLaMA-3.1-8B.
In particular, our released SAEs are pre-trained on the residual stream of LLaMA-3.1-8B.
They are pre-trained on 22B tokens, with the context window extended to 4,096 tokens.
The released OpenSAE projects the hidden states of LLaMA-3.1-8B into a high-dimensional space with 262,144 features, which is 64x the model's hidden size of 4,096 (4,096 $\times$ 64 = 262,144).
To the best of our knowledge, OpenSAE is the largest-scale pre-trained SAE released to the public in terms of context length, expansion ratio, and training-corpus size.
OpenSAE lets you load an SAE with a single line of code:
from opensae import OpenSae
OpenSae.from_pretrained("/dir/to/sae")
An SAE model comprises two key components: an encoder, which maps the input hidden states to a high-dimensional space with sparse activations, and a decoder, which decodes the sparse activations to reconstruct the hidden states.
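Schematically, for a token's hidden state $h$ with mean $\mu$ and standard deviation $\sigma$ (the `input_mean` and `input_std` fields described below), a TopK SAE of this kind can be written as follows. This is a generic formulation for illustration only; the exact normalization, bias terms, and activation used in OpenSAE may differ.

$$
\hat{h} = \frac{h - \mu}{\sigma}, \qquad
z = \mathrm{TopK}\!\left(W_{\mathrm{enc}}\,\hat{h} + b_{\mathrm{enc}}\right), \qquad
\tilde{h} = W_{\mathrm{dec}}\,z + b_{\mathrm{dec}},
$$

where $z$ is the sparse feature vector (262,144-dimensional for the released SAE) and $\tilde{h}$ is the reconstruction of the hidden state.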
In OpenSAE, we implement the following interfaces:
The `encode()` method implements the encoder forward pass.

Input:
- `hidden`: `torch.Tensor`, required. Shape = (tokens, hidden_size). To process multiple sentences in a batch, flatten the tokens in the batch before calling this method.
- `return_all_features`: `bool`, optional, defaults to `False`. When set to `True`, `encode()` also returns all the features before sparse activation in the output class.

Output:
- `SaeEncoderOutput`: `OrderedDict`. Fields include:
  - `sparse_feature_activations`: the activation values of the sparse features after sparse activation.
  - `sparse_feature_indices`: the indices of the activated features.
  - `all_features`: all the features before sparse activation, i.e. hidden_size $\times$ expansion_ratio features per token.
  - `input_mean`: the mean of `hidden`, used for LayerNorm.
  - `input_std`: the standard deviation of `hidden`, used for LayerNorm.
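For illustration, a minimal encoder call might look like the sketch below. The checkpoint path is a placeholder, the hidden size of 4,096 matches LLaMA-3.1-8B, and attribute-style access to the output fields is an assumption (the output is described above as an `OrderedDict`); treat this as a sketch rather than the exact API.

```python
import torch

from opensae import OpenSae

# Placeholder checkpoint directory; replace with a real SAE checkpoint.
sae = OpenSae.from_pretrained("/SAE/OpenSAE-LLaMA-3.1-Layer_12")

# Hidden states flattened to (tokens, hidden_size); 4096 is LLaMA-3.1-8B's hidden size.
# In practice the dtype and device should match the loaded SAE.
hidden = torch.randn(16, 4096)

encoder_output = sae.encode(hidden, return_all_features=True)
print(encoder_output.sparse_feature_indices.shape)      # (tokens, K) with TopK activation
print(encoder_output.sparse_feature_activations.shape)  # (tokens, K)
print(encoder_output.all_features.shape)                # (tokens, hidden_size * expansion_ratio)
```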
This method implements the decoder forward pass.

Input:
- `feature_indices`: `torch.LongTensor`, required. Shape = (tokens, num of sparse features). The indices of the activated sparse features, usually taken from `SaeEncoderOutput.sparse_feature_indices`. When TopK activation is used, the number of sparse features is K.
- `feature_activation`: `torch.FloatTensor`, required. Shape = (tokens, num of sparse features). The activation values of the sparse features, usually taken from `SaeEncoderOutput.sparse_feature_activations`.
- `input_mean`: `torch.FloatTensor`, optional. Shape = (tokens,). The mean of `hidden`, used for LayerNorm. Required when the SAE model performs shift_back.
- `input_std`: `torch.FloatTensor`, optional. Shape = (tokens,). The standard deviation of `hidden`, used for LayerNorm. Required when the SAE model performs shift_back.

Output:
- `SaeDecoderOutput`: `OrderedDict`. Fields include:
  - `sae_output`: the reconstruction of the input hidden states.
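Continuing the sketch above, the decoder pass could then reconstruct the hidden states from the encoder output. The method name `decode` is an assumption (the text above only describes "the decoder forward pass"); the parameter names follow the list above.

```python
decoder_output = sae.decode(
    feature_indices=encoder_output.sparse_feature_indices,
    feature_activation=encoder_output.sparse_feature_activations,
    input_mean=encoder_output.input_mean,  # required if the SAE performs shift_back
    input_std=encoder_output.input_std,
)
reconstruction = decoder_output.sae_output  # shape: (tokens, hidden_size)
```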
This method combines the encoding and decoding operations. It also calculates all the losses necessary for training.

Input:
- `hidden`: `torch.Tensor`, required. Shape = (tokens, hidden_size). To process multiple sentences in a batch, flatten the tokens in the batch before calling this method.
- `dead_mask`: `torch.Tensor`, required. Shape = (num of sparse features,). Used to calculate the Auxiliary-K loss.

Output:
- `SaeForwardOutput`: `OrderedDict`. Fields include:
  - `sparse_feature_activations`: the activation values of the sparse features after sparse activation.
  - `sparse_feature_indices`: the indices of the activated features.
  - `all_features`: all the features before sparse activation, i.e. hidden_size $\times$ expansion_ratio features per token.
  - `input_mean`: the mean of `hidden`, used for LayerNorm.
  - `input_std`: the standard deviation of `hidden`, used for LayerNorm.
  - `sae_output`: the reconstruction of the input hidden states.
  - `reconstruction_loss`: the reconstruction loss, i.e. the L2 loss between the input hidden states and the reconstruction.
  - `auxk_loss`: the Auxiliary-K loss.
  - `multi_topk_loss`: the Multi-TopK loss.
  - `l1_loss`: the L1 loss.
  - `loss`: the total loss, a weighted sum of the reconstruction loss, Auxiliary-K loss, Multi-TopK loss, and L1 loss.
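As a sketch of the combined pass, the SAE module might be called directly, assuming the standard `nn.Module` convention that calling the module dispatches to its forward method. The all-False `dead_mask` below is a hypothetical placeholder, and its length is assumed to be the full feature-dictionary size.

```python
# Hypothetical dead-feature mask: no feature is marked dead here.
# Assumed to span the full feature dictionary (262,144 features for the released SAE).
dead_mask = torch.zeros(262_144, dtype=torch.bool)

forward_output = sae(hidden, dead_mask=dead_mask)
print(forward_output.loss)                 # weighted sum of the individual losses
print(forward_output.reconstruction_loss)  # L2 loss between hidden and sae_output
```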
To bind an SAE to an LLM, we provide the TransformerWithSae
class.
To initialize the class, pass the LLM, the SAE, and the target device to the constructor:
import torch

from opensae import TransformerWithSae, InterventionConfig

layer_num = 12
device = torch.device("cuda")

model = TransformerWithSae(
    "/MODELS/Meta-Llama-3.1-8B",
    f"/SAE/OpenSAE-LLaMA-3.1-Layer_{layer_num:02d}",
    device
)
The TransformerWithSae
class automatically binds the SAE to the LLM by registering the encoding and decoding operations in the LLM's forward pass.
The intervention operation is controlled by the InterventionConfig class.
The intervention config can be passed to TransformerWithSae
when initializing the class.
It can also be altered by calling the update_intervention_config
method after the TransformerWithSae class has been instantiated.
We introduce the intervention config fields below; a short usage sketch follows the list:
- `prompt_only`: `bool`, optional, defaults to `False`. When set to `True`, the SAE is applied only to the prompt during the prefilling stage; it is not applied to the tokens produced during the generation stage.
- `intervention`: `bool`, optional, defaults to `False`. When set to `True`, the sparse activation values are altered according to `intervention_mode`, `intervention_indices`, and `intervention_value`. Otherwise, the hidden states are replaced by the reconstruction directly, without altering the sparse activations.
- `intervention_mode`: `str`, optional, defaults to `set`. One of `set`, `add`, and `multiply`. `set` sets the sparse activation values at `intervention_indices` to `intervention_value`; `add` adds `intervention_value` to those activation values; `multiply` multiplies the activation values at `intervention_indices` by the factor `intervention_value`.
- `intervention_indices`: `List[int] | None`, optional, defaults to `None`. Specifies which features are intervened on.
- `intervention_value`: `float`, optional, defaults to `0.0`. The intervention value.
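As an illustrative sketch only (the exact `InterventionConfig` constructor signature and the argument accepted by `update_intervention_config` are assumptions; the field names mirror the list above), amplifying a single hypothetical feature on the prompt only might look like:

```python
config = InterventionConfig(
    prompt_only=True,              # apply the SAE only during prefilling
    intervention=True,             # alter sparse activations instead of plain reconstruction
    intervention_mode="multiply",
    intervention_indices=[12345],  # hypothetical feature index
    intervention_value=5.0,        # multiply the selected activation by 5
)
model.update_intervention_config(config)
```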
Stay Tuned.
Stay Tuned.
Stay Tuned.
We provide our training pipeline in the train
module.
The training pipeline is implemented in train.py
; you can launch the trainer with the following command:
torchrun --nproc_per_node $((mp_size * dp_size)) \
--master-port $((10000 + $RANDOM % 100)) \
-m opensae.trainer \
${MODEL_CONFIG} ${DATA_CONFIG} \
${TRAIN_CONFIG} ${SAE_CONFIG} \
--run_name $exp_name
Please see examples/train.sh for more details.
Note: on a single GPU with 80 GiB of memory, the OpenSAE training infrastructure can train an SAE with 262,144 features at a context length of 4,096 tokens for an 8B LLM. On an H100, training takes around 30 days to converge on 20B training tokens.
- To support more open-sourced SAEs, including LLaMA-Scope and Gemma-Scope
- To further optimize our training infra.
This project draws inspiration from various third-party SAE (Sparse Autoencoder) tools. We would like to express our heartfelt gratitude to the following:
- transformers by HuggingFace: The architecture design of OpenSAE is significantly influenced by their implementation.
- sparse_autoencoder by OpenAI: We have adapted their kernel implementation and some of their training tricks.
- sae by EleutherAI: Our training pipeline is largely inspired by their work.

We deeply appreciate the contributions of these projects to the open-source community, which have been invaluable to the development of this project.
If you find this project helpful, please consider citing our paper:
@article{opensae,
title={OpenSAE: Open-sourced Sparse Auto-Encoder towards Interpreting Large Language Models},
author={THU-KEG},
year={2025}
}