Sparse autoencoders in JAX.
# Train a small SAE on the GPT-2 residual stream. Requires up to 32 GB of RAM.
python -m scripts.train_gpt2_sae --is_xl=False --save_steps=0 --sparsity_coefficient=1e-4
# Download GPT-2 residual stream SAEs for finetuning
scripts/download_jb_saes.sh
# Download Gemma 2B and Phi-3 mini
mkdir -p weights
wget -c 'https://huggingface.co/mlabonne/gemma-2b-GGUF/resolve/main/gemma-2b.Q8_0.gguf?download=true' -O weights/gemma-2b.gguf
wget -c 'https://huggingface.co/mlabonne/gemma-2b-it-GGUF/resolve/main/gemma-2b-it.Q8_0.gguf?download=true' -O weights/gemma-2b-it.gguf
wget 'https://huggingface.co/SanctumAI/Phi-3-mini-4k-instruct-GGUF/resolve/main/phi-3-mini-4k-instruct.fp16.gguf?download=true' -O weights/phi-3-16.gguf
# Generate data for a toy model
JAX_PLATFORMS=cpu python -m saex.toy_models
# Train Phi and Gemma SAEs
nohup python train_phis.py &
nohup python train_gemmas.py &
# Feature visualization
python -m scripts.cache_features
gradio scripts/feature_visualizer.py
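The `--sparsity_coefficient` flag in the GPT-2 training command presumably weights a sparsity penalty against reconstruction error. Below is a minimal sketch that assumes the standard ReLU SAE with an L1 penalty on feature activations. saex's actual architecture, normalization, and penalty may differ. `init_sae`, `sae_loss`, and `sgd_step` are hypothetical names, not saex's API, and plain SGD stands in for whatever optimizer the training scripts really use.

```python
import jax
import jax.numpy as jnp


def init_sae(key, d_model, d_sae):
    # Hypothetical initializer: small random encoder, decoder initialized
    # as the encoder's transpose, zero biases.
    w_enc = jax.random.normal(key, (d_model, d_sae)) * 0.01
    return {
        "w_enc": w_enc,
        "b_enc": jnp.zeros(d_sae),
        "w_dec": w_enc.T,
        "b_dec": jnp.zeros(d_model),
    }


def sae_loss(params, x, sparsity_coefficient):
    # Assumed standard ReLU SAE; saex's exact formulation may differ.
    # Encode: subtract the decoder bias, project up, apply ReLU.
    f = jax.nn.relu((x - params["b_dec"]) @ params["w_enc"] + params["b_enc"])
    # Decode: project the sparse features back to the residual stream.
    x_hat = f @ params["w_dec"] + params["b_dec"]
    # Reconstruction error: squared L2 per example, averaged over the batch.
    mse = jnp.mean(jnp.sum((x - x_hat) ** 2, axis=-1))
    # Sparsity penalty: L1 norm of the feature activations.
    l1 = jnp.mean(jnp.sum(jnp.abs(f), axis=-1))
    return mse + sparsity_coefficient * l1, (mse, l1)


@jax.jit
def sgd_step(params, x, sparsity_coefficient, lr=1e-3):
    # Plain SGD as a stand-in for the real optimizer (assumption).
    (loss, _aux), grads = jax.value_and_grad(sae_loss, has_aux=True)(
        params, x, sparsity_coefficient
    )
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss  # loss is measured before this update


# Usage: a few SGD steps on random stand-in "activations".
params = init_sae(jax.random.PRNGKey(0), d_model=16, d_sae=64)
x = jax.random.normal(jax.random.PRNGKey(1), (32, 16))
losses = []
for _ in range(20):
    params, loss = sgd_step(params, x, 1e-4)
    losses.append(float(loss))
```

Raising the coefficient trades reconstruction quality for sparser features. The `1e-4` above only mirrors the value in the GPT-2 command and is not a tuned setting for this toy data.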
Tests (there aren't any yet):
poetry run pytest
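# Install the build dependencies pyenv needs to compile Python
sudo apt install -y make build-essential libssl-dev zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev wget curl llvm libncurses5-dev libncursesw5-dev xz-utils tk-dev libffi-dev liblzma-dev python3-openssl git
# Put pyenv on the PATH and initialize it in every new shell
echo 'export PYENV_ROOT="$HOME/.pyenv"' >> ~/.bashrc
echo 'export PATH="$PYENV_ROOT/bin:$PATH"' >> ~/.bashrc
echo -e 'if command -v pyenv 1>/dev/null 2>&1; then\n eval "$(pyenv init -)"\nfi' >> ~/.bashrc
source ~/.bashrc
# Install Python 3.12.3 and make it the default
pyenv install 3.12.3
pyenv global 3.12.3
# Install Poetry and put the pyenv Python's bin directory on the PATH
python3 -m pip install poetry
echo 'export PATH="$PYENV_ROOT/versions/3.12.3/bin:$PYENV_ROOT/bin:$PATH"' >> ~/.bashrc
source ~/.bashrc
# Create the virtualenv, lock and install dependencies, then activate it
poetry env use 3.12
poetry lock
poetry install
# With Poetry 2.x, `poetry shell` needs the poetry-plugin-shell plugin; `eval $(poetry env activate)` is the built-in alternative
poetry shell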
I think it should be possible to set up a development environment on tpu-ubuntu2204 without installing pyenv.
No one actually asked these questions, but here are the answers anyway.
How is this code parallelized?
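Data and tensor parallelism. In theory, the SAE can be arbitrarily large, because its weights are sharded across devices. In practice, it is initialized on a single device, so its size is limited by that device's memory.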
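To make the answer above concrete, here is a minimal sketch of how data and tensor parallelism can be combined with JAX's `jax.sharding` API (`Mesh`, `NamedSharding`, `PartitionSpec`). The mesh shape, the axis names `"dp"`/`"mp"`, and the `reconstruct` function are hypothetical; saex's actual layout and sharding rules may differ. The batch of activations is split across a data-parallel axis and the SAE's feature dimension across a tensor-parallel axis, and `jax.jit` lets XLA insert the required collectives. The weights are created on one device and then `device_put`, which mirrors the single-device initialization limit described above.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 2D mesh: "dp" (data parallel) x "mp" (tensor parallel).
# With one device this is a 1x1 mesh, so the sketch still runs on CPU.
devices = np.array(jax.devices())
n = len(devices)
mp = 2 if n % 2 == 0 else 1
mesh = Mesh(devices.reshape(n // mp, mp), axis_names=("dp", "mp"))

d_model, d_sae, batch = 16, 64, 8 * n
# Weights are created unsharded on the default device (the init limitation).
w_enc = jax.random.normal(jax.random.PRNGKey(0), (d_model, d_sae)) * 0.01
w_dec = w_enc.T
x = jax.random.normal(jax.random.PRNGKey(1), (batch, d_model))

# Tensor parallelism: split the SAE's feature dimension across "mp".
w_enc_s = jax.device_put(w_enc, NamedSharding(mesh, P(None, "mp")))
w_dec_s = jax.device_put(w_dec, NamedSharding(mesh, P("mp", None)))
# Data parallelism: split the batch of activations across "dp".
x_s = jax.device_put(x, NamedSharding(mesh, P("dp", None)))


@jax.jit
def reconstruct(x, w_enc, w_dec):
    # XLA's sharding propagation picks the layout of f
    # (typically batch on "dp", features on "mp").
    f = jax.nn.relu(x @ w_enc)
    # Contracting over the sharded feature dimension needs a cross-device
    # sum, which XLA inserts automatically.
    return f @ w_dec


x_hat = reconstruct(x_s, w_enc_s, w_dec_s)
```

To avoid materializing the full SAE on one device, the initializer itself can be wrapped in `jax.jit(..., out_shardings=NamedSharding(mesh, P(None, "mp")))`, so each shard is created directly on its own device.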
Are results comparable to SAELens?
Yes. I haven't tested with smaller batch sizes, but you can get comparable results for GPT-2 Small layer 9 with ~25% fewer tokens and ~3x shorter training time.
What techniques does saex use?