bobby-he/simplified_transformers

Simplified Transformers

This is the author's implementation for Simplifying Transformer Blocks (ICLR 2024) and Understanding and Minimising Outlier Features in Transformer Training (NeurIPS 2024).

Getting started

The main dependencies for this repo are:

  • hydra
  • wandb
  • torch
  • transformers (hf)
  • datasets (hf)
  • evaluate (hf)
  • accelerate (hf)

To install these dependencies, run:

pip install -r requirements.txt

Usage

This codebase runs the autoregressive experiments in our paper. The main training script is run_clm.py, which trains GPT-2 (small, ~120M params) on next-token prediction using code data, largely inspired by this HF notebook. It may take a few minutes to download the data on the first run.
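The next-token prediction objective can be sketched in a few lines of pure Python. This is a minimal illustration, not the code in run_clm.py (which works on batched torch tensors through the HF model). The helper names are hypothetical. It shows the one-position shift between predictions and targets that Hugging Face causal-LM heads apply when labels are passed, and it returns the mean cross-entropy in nats.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    # Numerically stable log-softmax over the vocabulary.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_z for z in logits]


def next_token_loss(logits: Sequence[Sequence[float]], token_ids: Sequence[int]) -> float:
    # logits[t] scores the token at position t + 1, so pair predictions 0..T-2
    # with targets 1..T-1 (the last prediction has no target inside the sequence).
    preds, targets = logits[:-1], token_ids[1:]
    nll = [-log_softmax(row)[tgt] for row, tgt in zip(preds, targets)]
    return sum(nll) / len(nll)  # mean cross-entropy in nats
```

A model that spreads its probability uniformly over a vocabulary of size V gets a loss of log V. Confident correct predictions drive the loss towards 0.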

We use hydra to organise our configs, so all arguments can be set from the command line. We assume training takes place on a single GPU.
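Command-line overrides such as model.n_layer=18 write into a nested config, with dots walking into sub-sections. The toy stdlib sketch below illustrates that mapping. It is not hydra's implementation, and the helper names and default config values are hypothetical. Hydra's real override grammar handles much more, including typed values, lists, and config-group selection. A command such as model=skipless-parallel presumably selects a whole model config rather than setting a string. In practice, pass overrides to run_clm.py and let hydra do the parsing.

```python
from typing import Any, Dict, List


def parse_value(text: str) -> Any:
    # Rough coercion for this sketch only: bools, ints, floats, else keep the string.
    if text in ("True", "False"):
        return text == "True"
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            pass
    return text


def apply_overrides(cfg: Dict[str, Any], overrides: List[str]) -> Dict[str, Any]:
    # Each override is "dotted.key=value"; dots walk into nested config sections,
    # creating them if missing, and the last component is the key that gets set.
    for item in overrides:
        key, value = item.split("=", 1)
        *parents, leaf = key.split(".")
        node = cfg
        for name in parents:
            node = node.setdefault(name, {})
        node[leaf] = parse_value(value)
    return cfg


# Hypothetical defaults, for illustration only.
cfg = {"num_token_mult": 1, "use_wandb": True, "model": {"n_layer": 12}}
apply_overrides(cfg, ["num_token_mult=2", "model.n_layer=18", "use_wandb=False"])
```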

The default config uses Pre-LN GPT-2, i.e. running:

python run_clm.py num_token_mult=2 model.n_layer=18

reproduces the Pre-LN run in Figure 2 of the paper and should obtain an eval loss of ~1.155 after 40K training steps. This takes ~10 hours on an A5000.
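The quick arithmetic below helps interpret these numbers. It assumes the eval loss is the mean per-token cross-entropy in nats (PyTorch's default). Under that assumption, the target loss corresponds to a perplexity of roughly 3.17, and 40K steps in ~10 hours works out to about 0.9 seconds per training step.

```python
import math

eval_loss = 1.155          # target eval loss from above (assumed nats per token)
train_steps = 40_000
wall_clock_hours = 10      # approximate, single A5000

perplexity = math.exp(eval_loss)                            # about 3.17
seconds_per_step = wall_clock_hours * 3600 / train_steps    # 0.9
print(f"perplexity ~ {perplexity:.2f}, {seconds_per_step:.2f} s/step")
```

Under the same assumption, an eval loss of ~1.245 corresponds to a perplexity of about 3.47.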

To change the model, start from one of our 3 non-default model configs and modify it as needed:

  1. default-parallel (parallel block from GPT-J)
  2. skipless (without the attention sub-block skip; Figure 9 of paper)
  3. skipless-parallel (parallel and skipless; Figure 10 of paper)
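These configs differ mainly in how the attention and MLP sub-blocks are wired to the residual stream. The minimal sketch below covers three of the wirings: the Pre-LN sequential block, its variant without the attention skip, and the GPT-J-style parallel block. It uses plain Python lists, toy sub-blocks, a parameter-free layer norm, and hypothetical function names, so it is not the repo's modules. The paper's skipless blocks also change the attention sub-block itself (e.g. shaped attention), which this sketch does not model.

```python
import math
from typing import Callable, List

Vec = List[float]
SubBlock = Callable[[Vec], Vec]


def layer_norm(x: Vec, eps: float = 1e-5) -> Vec:
    # Parameter-free LayerNorm: zero mean, unit variance (no gain/bias in this sketch).
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]


def add(*vecs: Vec) -> Vec:
    # Elementwise sum of equal-length vectors.
    return [sum(vals) for vals in zip(*vecs)]


def sequential_block(x: Vec, attn: SubBlock, mlp: SubBlock, attn_skip: bool = True) -> Vec:
    # Pre-LN (GPT-2 style): the MLP reads the stream after the attention update.
    a = attn(layer_norm(x))
    h = add(x, a) if attn_skip else a  # skipless: no residual around attention
    return add(h, mlp(layer_norm(h)))


def parallel_block(x: Vec, attn: SubBlock, mlp: SubBlock) -> Vec:
    # Parallel (GPT-J style): both sub-blocks read the same normalised input,
    # and their outputs are summed into a single residual update.
    n = layer_norm(x)
    return add(x, attn(n), mlp(n))
```

In the sequential block, the MLP sees the output of attention. In the parallel block, both sub-blocks can be computed at the same time from one normalised input.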

Other model settings can be customised from the command line. For example, the following command reproduces the parallel, skipless block without normalisation (i.e. top right of the header figure) in Figure 5:

python run_clm.py num_token_mult=2 model.n_layer=18 model=skipless-parallel model.norm_type=none

which should obtain an eval loss of ~1.245 after 40K steps. More training scripts can be found in exp_scripts/.

We use wandb for logging by default. To turn this off, simply add use_wandb=False on the command line.

Outlier Feature computation

The kurtosis computation for outlier features can be found here.

Note that we take the variance (not the second moment) of the normalised neuron-wise mean squared activations here, which means we compute $kurtosis-1$. This differs from the kurtosis by an additive constant of 1, which doesn't change our findings regarding preventing outlier features.
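The offset of 1 follows from a one-line identity. Let s_j be the mean squared activation of neuron j over tokens, and μ the mean of the s_j. Then Var(s/μ) = E[s²]/μ² - 1 = kurtosis - 1. The pure-Python sketch below checks this under stated assumptions. It assumes kurtosis is defined as mean(s²)/mean(s)², and it uses hypothetical function names rather than the repo's implementation. It also uses the population variance, for which the identity is exact. An unbiased (n-1) variance estimator would scale the result by n/(n-1).

```python
from statistics import fmean, pvariance
from typing import List


def neuron_mean_squares(acts: List[List[float]]) -> List[float]:
    # acts[t][j] = activation of neuron j at token t.
    # Returns s_j = mean over tokens of acts[t][j] ** 2.
    n_tokens = len(acts)
    return [sum(row[j] ** 2 for row in acts) / n_tokens for j in range(len(acts[0]))]


def kurtosis(acts: List[List[float]]) -> float:
    # mean_j(s_j^2) / mean_j(s_j)^2; equals 1 when every neuron has the same scale.
    s = neuron_mean_squares(acts)
    return fmean(v * v for v in s) / fmean(s) ** 2


def kurtosis_minus_one(acts: List[List[float]]) -> float:
    # Variance (not second moment) of s normalised to mean 1, i.e. kurtosis - 1.
    s = neuron_mean_squares(acts)
    mu = fmean(s)
    return pvariance([v / mu for v in s])
```

A single neuron with 3x the activation scale of the others (s = [1, 1, 1, 9]) gives kurtosis 7/3 and kurtosis - 1 = 4/3. Neurons with equal scale give kurtosis 1 and kurtosis - 1 = 0.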

The config for the Outlier Protected (OP) block is here.

Citation

If you found this codebase useful, please consider citing:

@inproceedings{
he2024simplifying,
title={Simplifying Transformer Blocks},
author={Bobby He and Thomas Hofmann},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
url={https://openreview.net/forum?id=RtDok9eS3s}
}
@inproceedings{
he2024understanding,
title={Understanding and Minimising Outlier Features in Transformer Training},
author={Bobby He and Lorenzo Noci and Daniele Paliotta and Imanol Schlag and Thomas Hofmann},
booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
year={2024},
url={https://openreview.net/forum?id=npJQ6qS4bg}
}
