neverix/saex

saex

Sparse autoencoders in Jax.

Running

# Train a small SAE on the GPT-2 residual stream. Requires at most 32GB of RAM.
python -m scripts.train_gpt2_sae --is_xl=False --save_steps=0 --sparsity_coefficient=1e-4
# Download GPT-2 residual stream SAEs for finetuning
scripts/download_jb_saes.sh
# Download Gemma 2B and Phi-3 mini
mkdir -p weights
wget -c 'https://huggingface.co/mlabonne/gemma-2b-GGUF/resolve/main/gemma-2b.Q8_0.gguf?download=true' -O weights/gemma-2b.gguf
wget -c 'https://huggingface.co/mlabonne/gemma-2b-it-GGUF/resolve/main/gemma-2b-it.Q8_0.gguf?download=true' -O weights/gemma-2b-it.gguf
wget 'https://huggingface.co/SanctumAI/Phi-3-mini-4k-instruct-GGUF/resolve/main/phi-3-mini-4k-instruct.fp16.gguf?download=true' -O weights/phi-3-16.gguf
# Generate data for a toy model
JAX_PLATFORMS=cpu python -m saex.toy_models
# Train Phi and Gemma SAEs
nohup python train_phis.py &
nohup python train_gemmas.py &
# Feature visualization
python -m scripts.cache_features
gradio scripts/feature_visualizer.py
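The `--sparsity_coefficient` flag above presumably weights a sparsity penalty in the SAE training loss. As a rough illustration, here is a minimal sketch of a standard ReLU sparse autoencoder with an L1 penalty in JAX. The parameter names, initialization, and plain SGD update are assumptions for illustration, not saex's actual implementation.

```python
import jax
import jax.numpy as jnp


def init_sae(key, d_model, d_sae):
    # Hypothetical initialization: scaled Gaussian weights, zero biases.
    k_enc, k_dec = jax.random.split(key)
    return {
        "w_enc": jax.random.normal(k_enc, (d_model, d_sae)) / jnp.sqrt(d_model),
        "b_enc": jnp.zeros(d_sae),
        "w_dec": jax.random.normal(k_dec, (d_sae, d_model)) / jnp.sqrt(d_sae),
        "b_dec": jnp.zeros(d_model),
    }


def sae_forward(params, x):
    # Encode: center by the decoder bias, project up, ReLU gives non-negative codes.
    codes = jax.nn.relu((x - params["b_dec"]) @ params["w_enc"] + params["b_enc"])
    # Decode: map the codes back to the activation space.
    recon = codes @ params["w_dec"] + params["b_dec"]
    return codes, recon


def sae_loss(params, x, sparsity_coefficient):
    codes, recon = sae_forward(params, x)
    # Squared reconstruction error, summed over dimensions, averaged over the batch.
    mse = jnp.mean(jnp.sum((recon - x) ** 2, axis=-1))
    # L1 norm of the codes encourages most of them to be zero.
    l1 = jnp.mean(jnp.sum(jnp.abs(codes), axis=-1))
    return mse + sparsity_coefficient * l1


@jax.jit
def train_step(params, x, sparsity_coefficient=1e-4, lr=1e-3):
    # One plain SGD step; a real trainer would likely use an optimizer such as Adam.
    loss, grads = jax.value_and_grad(sae_loss)(params, x, sparsity_coefficient)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss
```

Raising `sparsity_coefficient` trades reconstruction quality for sparser codes. For GPT-2 Small, `d_model` would be 768; the SAE width `d_sae` is a free choice.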

Tests (there aren't any yet):

poetry run pytest

How to install

sudo apt install -y make build-essential libssl-dev zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev wget curl llvm libncurses5-dev libncursesw5-dev xz-utils tk-dev libffi-dev liblzma-dev python3-openssl git
git clone https://github.com/pyenv/pyenv.git ~/.pyenv
echo 'export PYENV_ROOT="$HOME/.pyenv"' >> ~/.bashrc
echo 'export PATH="$PYENV_ROOT/bin:$PATH"' >> ~/.bashrc

echo -e 'if command -v pyenv 1>/dev/null 2>&1; then\n eval "$(pyenv init -)"\nfi' >> ~/.bashrc
source ~/.bashrc
pyenv install 3.12.3
pyenv global 3.12.3
python3 -m pip install poetry
echo 'export PATH="$PYENV_ROOT/versions/3.12.3/bin:$PYENV_ROOT/bin:$PATH"' >> ~/.bashrc

poetry env use 3.12
poetry lock
poetry install
poetry shell

I think it should be possible to set up a development environment on tpu-ubuntu2204 without installing pyenv.

FAQ

No one actually asked these questions, but here are the answers anyway.

How is this code parallelized?

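Data and tensor parallelism. In theory, because the SAE's weights are sharded across devices, its size is limited only by the combined memory of all devices. In practice, it is initialized on one device, so it must fit in that device's memory.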
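Here is a minimal sketch of what data plus tensor parallelism can look like for an SAE using `jax.sharding`: the batch is split along a "data" mesh axis and the SAE's feature dimension along a "model" axis. The mesh layout, axis names, and parameter names are assumptions for illustration; saex's actual partitioning may differ. Parameters are created on one device and then moved onto the mesh, mirroring the initialization limitation described above.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 2D device mesh: "data" splits the batch, "model" splits SAE features.
# This puts every device on "data"; with an even device count, reshape to
# (-1, 2) to move devices onto the "model" axis for tensor parallelism.
mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), axis_names=("data", "model"))

# Hypothetical partitioning: the SAE's feature dimension d_sae is sharded.
PARAM_SPECS = {
    "w_enc": P(None, "model"),  # (d_model, d_sae)
    "b_enc": P("model"),        # (d_sae,)
    "w_dec": P("model", None),  # (d_sae, d_model)
    "b_dec": P(),               # (d_model,), replicated on every device
}
BATCH_SPEC = P("data", None)    # (batch, d_model)


def shard_params(params):
    # Move already-initialized parameters onto the mesh.
    return {k: jax.device_put(v, NamedSharding(mesh, PARAM_SPECS[k]))
            for k, v in params.items()}


@jax.jit
def encode_decode(params, x):
    # XLA inserts the collectives the sharded matmuls need, e.g. a reduction
    # over "model" when the decoder contracts the sharded d_sae axis.
    codes = jax.nn.relu((x - params["b_dec"]) @ params["w_enc"] + params["b_enc"])
    return codes @ params["w_dec"] + params["b_dec"]


# Toy sizes; batch must divide by the "data" axis size, d_sae by "model".
d_model, d_sae, batch = 16, 64, 8
k_enc, k_dec, k_x = jax.random.split(jax.random.PRNGKey(0), 3)
params = shard_params({
    "w_enc": jax.random.normal(k_enc, (d_model, d_sae)) / jnp.sqrt(d_model),
    "b_enc": jnp.zeros(d_sae),
    "w_dec": jax.random.normal(k_dec, (d_sae, d_model)) / jnp.sqrt(d_sae),
    "b_dec": jnp.zeros(d_model),
})
x = jax.device_put(jax.random.normal(k_x, (batch, d_model)),
                   NamedSharding(mesh, BATCH_SPEC))
recon = encode_decode(params, x)
```

Initializing the parameters inside `jax.jit` with `out_shardings` set to these `NamedSharding`s would instead create each shard directly on its own device, which is one way around the single-device initialization.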

Are results comparable to SAELens?

Yes. I haven't tested with smaller batch sizes, but for GPT-2 Small layer 9 you can get comparable results with ~25% fewer tokens and ~3x less training time.

What techniques does saex use?
