
Add some fast Metal MLX SDPA kernels #2584

Merged
10 commits merged into huggingface:main from EricLBuehler:candle_mlx_sdpa on Nov 5, 2024

Conversation

EricLBuehler (Member) commented on Oct 29, 2024:

This PR adds some MLX SDPA kernels on Metal.

I can observe about a 26% performance improvement with Llama 3.1 8b @ q4k and @ q8_0 when testing through mistral.rs on my Candle fork. I updated the quantized_llama.rs file here to use the new function.

This PR adds a function candle_nn::ops::sdpa. The MLX attention kernels don't support masking yet, so the performance gains apply only to decoding on Metal. Once/if they add masking support, I'll update the kernels; otherwise, we can explore using the Flash Attention kernels for Metal from llama.cpp.
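For context, here is a minimal sketch of how the new op might be invoked for the decoding case. This is a sketch only: the exact signature lives in candle-nn/src/ops.rs, and the (batch, heads, seq, head_dim) layout, the scale/softcapping parameters, the treatment of softcapping 1.0 as "disabled", and the in-workspace `candle` crate alias are assumptions, not confirmed by this thread.

```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::new_metal(0)?;

    // Decoding case: a single query token attending over 512 cached KV positions.
    // Shapes are assumed to be (batch, n_heads, seq_len, head_dim).
    let q = Tensor::randn(0f32, 1f32, (1, 32, 1, 128), &device)?.to_dtype(DType::F16)?;
    let k = Tensor::randn(0f32, 1f32, (1, 32, 512, 128), &device)?.to_dtype(DType::F16)?;
    let v = Tensor::randn(0f32, 1f32, (1, 32, 512, 128), &device)?.to_dtype(DType::F16)?;

    let scale = 1f32 / (128f32).sqrt(); // 1 / sqrt(head_dim)
    let softcapping = 1f32; // assumption: 1.0 means "no softcapping"

    // No mask is passed: the MLX kernels currently only cover the unmasked/decoding path.
    let out = candle_nn::ops::sdpa(&q, &k, &v, scale, softcapping)?;
    println!("{:?}", out.dims()); // expected: [1, 32, 1, 128]
    Ok(())
}
```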

EricLBuehler and others added 3 commits October 29, 2024 06:34
* Sketch the sdpa kernel

* Add full sdpa kernel,

* Add test

* Add vectorized kernel for decoding

* Update tests

* Add some docs

* Fix sdpa_vector names

* Add softcapping for vectorized sdpa

* Add softcapping for full sdpa

* Add support for head dim 32, 96, 256

* Add support for head dim 32, 96, 256

* Update docs

* Add update notice

* Clippy and format
Resolved review threads (outdated): .vscode/settings.json, candle-metal-kernels/src/lib.rs
candle::bail!("query `n_heads` must be a multiple of `n_kv_heads`");
}

let k_head = k_l.dims()[k_l.dims().len() - 1];
LaurentMazare (Collaborator):
k_l.dim(D::Minus1)? would be simpler and make for a better error message than a panic (the same applies to a bunch of places in this function)

EricLBuehler (Member, Author):

I'm not sure Layout::dim exists; that's why I used Layout::dims. Perhaps we could add it?

LaurentMazare (Collaborator):

You can get the shape, which should have the dim methods.

EricLBuehler (Member, Author):

I think the Shape also doesn't have the dim methods?

https://github.com/search?q=repo%3Ahuggingface%2Fcandle+%22fn+dim%22&type=code

LaurentMazare (Collaborator):

Ah right, I've added it quickly as part of this PR; at least it should give better error messages than a bare index-out-of-bounds panic.
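For readers following along, the suggested change amounts to something like the sketch below (assuming the newly added Layout::dim mirrors Tensor::dim and that D and Layout are exported from the candle crate; the helper function is hypothetical, purely to show the two styles side by side):

```rust
use candle::{Layout, Result, D};

// Hypothetical helper just to contrast the two approaches.
fn k_head_dim(k_l: &Layout) -> Result<usize> {
    // Before: indexing into dims() panics with a bare "index out of bounds"
    // message if the layout has fewer dimensions than expected.
    // let k_head = k_l.dims()[k_l.dims().len() - 1];

    // After: dim() returns a Result, so a missing dimension surfaces as a
    // descriptive candle error instead of a panic.
    k_l.dim(D::Minus1)
}
```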

Resolved review thread (outdated): candle-nn/src/ops.rs
EricLBuehler (Member, Author):
The latest commit fixes a bug reported in MLX; see ml-explore/mlx#1558.

@LaurentMazare merged commit e2b6b36 into huggingface:main on Nov 5, 2024 (9 of 10 checks passed).
EricLBuehler (Member, Author):
Thank you!

@EricLBuehler deleted the candle_mlx_sdpa branch on November 6, 2024, 21:35.
EricLBuehler added a commit to EricLBuehler/candle that referenced this pull request on Nov 26, 2024:
* Add some fast Metal MLX SDPA kernels (#32)

* Sketch the sdpa kernel

* Add full sdpa kernel,

* Add test

* Add vectorized kernel for decoding

* Update tests

* Add some docs

* Fix sdpa_vector names

* Add softcapping for vectorized sdpa

* Add softcapping for full sdpa

* Add support for head dim 32, 96, 256

* Add support for head dim 32, 96, 256

* Update docs

* Add update notice

* Clippy and format

* Conditional compilation for bf16

* Use it in quantized llama

* Some review comments

* Use set_params!

* Remove unused

* Remove feature

* Fix metal sdpa for v stride

* Remove comma

* Add the dim method to layout and shape.

---------

Co-authored-by: Laurent <[email protected]>