MLX port of xjdr's entropix sampler for LLMs. This port mirrors xjdr's JAX implementation as closely as possible using only MLX operations. The current implementation may be unstable and unoptimized (PRs welcome).
This repository is for research purposes and is not optimized for production applications. MLX handles all of the main operations; PyTorch is used only to load the weights correctly.
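For illustration, here is a minimal sketch of that PyTorch-to-MLX handoff. The checkpoint file name and dtype handling are assumptions, not the repository's actual loader:

```python
import torch
import mlx.core as mx

# Load the checkpoint on CPU with PyTorch (the file name here is an assumption).
state_dict = torch.load("weights/1B-Instruct/consolidated.00.pth", map_location="cpu")

# Convert every tensor to an MLX array; everything downstream stays in MLX.
# bfloat16 tensors are upcast first, since numpy has no bfloat16 dtype.
weights = {k: mx.array(v.to(torch.float32).numpy()) for k, v in state_dict.items()}
```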
Install the dependencies:

```sh
pip install -r requirements.txt
```
Download the weights for the Llama 3.2 1B model (skip this step if you have already downloaded them):

```sh
python download_weights.py --model-id meta-llama/Llama-3.2-1B-Instruct --out-dir weights/1B-Instruct --hf_token <your-huggingface-token-here>
```
To run the model with the entropix sampler on your input prompt (for whatever research purposes):

```sh
python main.py --input "Which is greater, 9.9 or 9.11?"
```
The tokens the LLM generates may appear color-coded in the output, reflecting the sampler's entropy and varentropy state (see the sketch after this list):
- Green: Low entropy, Low varentropy
- Red: High entropy, High varentropy
- Magenta: High entropy, Low varentropy
- Yellow: Low entropy, High varentropy
- No color: Adaptive sampling (general case)
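As a rough guide to how these regimes can be detected, here is a minimal MLX sketch that computes entropy and varentropy from the model's next-token logits and branches on them. The function names and thresholds are illustrative assumptions, not the repository's exact code:

```python
import mlx.core as mx

def entropy_varentropy(logits: mx.array):
    """Entropy and varentropy (variance of surprisal) of softmax(logits)."""
    # Log-softmax for numerical stability.
    log_probs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
    probs = mx.exp(log_probs)
    entropy = -mx.sum(probs * log_probs, axis=-1)
    varentropy = mx.sum(probs * (-log_probs - entropy[..., None]) ** 2, axis=-1)
    return entropy, varentropy

def sample_next(logits: mx.array, low: float = 0.5, high: float = 2.5) -> mx.array:
    """Pick a token from 1-D logits; thresholds here are illustrative, not tuned."""
    ent, vent = entropy_varentropy(logits)
    e, v = ent.item(), vent.item()
    if e < low and v < low:    # green: model is confident, take the argmax
        return mx.argmax(logits)
    if e > high and v > high:  # red: model is uncertain, sample at a higher temperature
        return mx.random.categorical(logits / 1.5)
    # magenta, yellow, and the adaptive general case branch similarly
    return mx.random.categorical(logits)
```

The actual sampler's decision logic is more involved; this sketch only maps the color legend above onto code.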