Add some fast Metal MLX SDPA kernels #2584
Conversation
* Sketch the sdpa kernel
* Add full sdpa kernel
* Add test
* Add vectorized kernel for decoding
* Update tests
* Add some docs
* Fix sdpa_vector names
* Add softcapping for vectorized sdpa
* Add softcapping for full sdpa
* Add support for head dim 32, 96, 256
* Add support for head dim 32, 96, 256
* Update docs
* Add update notice
* Clippy and format
candle-nn/src/ops.rs (Outdated)

        candle::bail!("query `n_heads` must be a multiple of `n_kv_heads`");
    }

    let k_head = k_l.dims()[k_l.dims().len() - 1];
k_l.dim(D::Minus1)? would be simpler and make for a better error message than a panic (the same applies to a bunch of places in this function).
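For illustration, a minimal sketch of the suggested change, assuming a Layout::dim accessor like the one discussed below that returns a Result (the helper name k_head_dim is hypothetical):

```rust
use candle::{D, Layout, Result};

// Illustration of the suggestion: resolve the last dimension through a
// fallible accessor instead of indexing into the dims slice, so a malformed
// layout produces a descriptive error rather than an out-of-bounds panic.
fn k_head_dim(k_l: &Layout) -> Result<usize> {
    // Original form: k_l.dims()[k_l.dims().len() - 1]
    k_l.dim(D::Minus1)
}
```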
I'm not sure Layout::dim exists; that is why I used Layout::dims. Perhaps we could add it?
You can get the shape, which should have the dim methods.
I think the Shape also doesn't have the dim methods?
https://github.com/search?q=repo%3Ahuggingface%2Fcandle+%22fn+dim%22&type=code
Ah right, I've added it quickly as part of this PR; at least it should give better error messages rather than just an out-of-bounds panic.
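A rough sketch of what such a dim accessor could look like (the free function shape_dim and the exact error message are assumptions; the actual method added in this PR lives on Shape/Layout and also accepts D::Minus1-style indices):

```rust
use candle::{Result, Shape};

// Sketch only: resolve a dimension index against a Shape, returning an error
// instead of panicking when the index is out of range.
fn shape_dim(shape: &Shape, index: usize) -> Result<usize> {
    match shape.dims().get(index) {
        Some(&d) => Ok(d),
        None => candle::bail!("dim index {index} out of range for shape {shape:?}"),
    }
}
```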
Latest commit fixes a bug reported in mlx, see ml-explore/mlx#1558.
Thank you!
* Add some fast Metal MLX SDPA kernels (#32)
* Sketch the sdpa kernel
* Add full sdpa kernel
* Add test
* Add vectorized kernel for decoding
* Update tests
* Add some docs
* Fix sdpa_vector names
* Add softcapping for vectorized sdpa
* Add softcapping for full sdpa
* Add support for head dim 32, 96, 256
* Add support for head dim 32, 96, 256
* Update docs
* Add update notice
* Clippy and format
* Conditional compilation for bf16
* Use it in quantized llama
* Some review comments
* Use set_params!
* Remove unused
* Remove feature
* Fix metal sdpa for v stride
* Remove comma
* Add the dim method to layout and shape.

Co-authored-by: Laurent <[email protected]>
This PR adds some MLX SDPA kernels on Metal.
I can observe about a 26% performance improvement with Llama 3.1 8b @ q4k and @ q8_0 when testing through mistral.rs on my Candle fork. I updated the quantized_llama.rs file here to use the new function.
This PR adds a function, candle_nn::ops::sdpa. The MLX attention kernels don't support masking yet, so the performance gains are only for decoding on Metal. Once/if they do, I'll update them; otherwise we can explore using the Flash Attention kernels for Metal from llama.cpp.
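For illustration, a minimal usage sketch of the new op. The sdpa(q, k, v, scale, softcapping) signature, the shapes, and the "softcapping = 1.0 disables it" convention are assumptions based on this PR's description, not a definitive reference:

```rust
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    // Decoding-style shapes: (batch, n_heads, seq_len, head_dim) with a
    // single query token attending over a longer key/value cache.
    let device = Device::new_metal(0)?;
    let q = Tensor::randn(0f32, 1f32, (1, 32, 1, 128), &device)?;
    let k = Tensor::randn(0f32, 1f32, (1, 32, 256, 128), &device)?;
    let v = Tensor::randn(0f32, 1f32, (1, 32, 256, 128), &device)?;

    // Assumed signature: sdpa(q, k, v, scale, softcapping). The scale is
    // typically 1/sqrt(head_dim); a softcapping of 1.0 is assumed to mean
    // "no softcapping". No mask is supported yet, so this targets
    // single-token decoding on Metal.
    let scale = 1f32 / (128f32).sqrt();
    let out = candle_nn::ops::sdpa(&q, &k, &v, scale, 1.0)?;
    println!("{:?}", out.dims());
    Ok(())
}
```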