-
Notifications
You must be signed in to change notification settings - Fork 998
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Sdpa stride fix #2596
Closed
Closed
Sdpa stride fix #2596
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
* Offset it * Freeze * Offset it * Offset it * Try out vllm impl again * Try out vllm impl again * Try out vllm impl again * Try out vllm impl again * Try out vllm impl again * Try out vllm impl again * Try out vllm impl again * Try out vllm impl again * Try out vllm impl again * Try out vllm impl again * Try out vllm impl again * Try out vllm impl again * Try out vllm impl again * Remove debugs * Polish it up * Polish it up * Clippy * Remove test file * Add config for if neox * Fix bug * Fix bug * Cast cache type on rust side * Cast types * To dtype * Drop temp * Update casting * Update casting * Update casting * Create dtype in bf16 * Check type * Debug * Check dtype * Check dtype * Check dtype * Check dtype * Check dtype * Check dtype * Check dtype * Check dtype * Check dtype * Debug * Debug * Debug * Check old method * Check old method * Check old method * Check old method * Check old method * Check old method * Check old method * Check old method * Check old method * Check old method * Check old method * Check old method * Check old method * Check old method * Check old method * Check old method * Check old method * Check old method * Check old method * Check old method * Check old method * Use mistral slow rope impl * Use mistral slow rope impl * Use mistral slow rope impl * Use mistral slow rope impl * Use mistral slow rope impl * Use mistral slow rope impl * Use mistral slow rope impl * Use mistral slow rope impl * Use mistral slow rope impl * Use mistral slow rope impl * Use mistral slow rope impl * Use mistral slow rope impl * Reseting * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Remove debug * Debug * Debug * Remove debug * Remove debug * Debug * Remove debug * Debug * Remove debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Debug * Try to use 3dim rotemb fused * Try to use 3dim rotemb fused * Remove contig and debug * Check handling * Cleanup * Fix * Remove prints * Lower block dim * Use fused layernorm * Pass batch size * Simplify internal API * Simplify internal API * Try slow * Try candle layer norm * Try candle layer norm * Fix dep of candle layer norm * Reshape input for rank 2 * Reshape input for rank 2 * Fix ref * Code style * Make dep optional * Ensure contig * Ensure contig * Ensure contig * Debug contig dmmv error * Debug contig dmmv error * Debug contig dmmv error * Debug contig dmmv error * Try other method * Try other method * Try other method * Try other method * Try other method * Use typestate to optimize * Use typestate to optimize * Fixes * Fixes * Fixes * Fixes * Fixes * Debug via using slow rmsnorm * Debug via using slow rope * Remove debug * More debugging * Remove debug * Remove debug * Remove debug * Add better error enum * Fix diff marker * Fix some things * Fix some things * Fix some things * Fix dummy backends * Re add from storage noop * Fix removed kvconcat custom op * Fix erroneous feature gate * Complete metal backend refactoring * Check if calling * Check if calling * Update default for force dmmv * Load atomic * Debug * Use mmvq * Update * Add the empty functions * Add rope new_partial function * Make variant of qmatmul pub * Make variant of qmatmul pub * Add the varbuilder set_device function * Only link stdc++ if target has msvc * Only link stdc++ if target has msvc * Only link stdc++ if target has msvc * Only link stdc++ if target has msvc * Handle case of device mapping * Handle case of device mapping * Add getter * Fix * Fix * Support nvcc flags in flash attn * Support nvcc flags in flash attn * Support nvcc flags in flash attn * Support nvcc flags in flash attn * Support nvcc flags in flash attn * Fixes * Fixes * Fix the tests * Fix the tests
* Support flash-attn in quantized phi3. (huggingface#2194) * Use flash-attn in gemma. (huggingface#2195) * Use flash-attn in gemma. * Fix flash-attn for head dim 256. * Remove candle-layer-norm --------- Co-authored-by: Laurent Mazare <[email protected]>
* Add unfold * Format
* Add the quantize_onto api * Take ref * Clippy * Format * Add error checking
* Use flash-attn in gemma. * Fix for the fast bf16 cublas gemm. * Fix some clippy lints. * Fix another lint. * Proper clippy fix.
* define structs * construct ResidualConvUnit * forward() for ResidualConvUnit * implement FeatureFusionBlock * implement Scratch * implement DPTHead * add identity module * implement forward for DTPHead * add get_intermediate_layers to DinoVisionTransformer * implement DepthAnythingV2 * some minor tweaks * fix compile errors * fix var builder prefixes * setup initial example * use fixed patch size of 37 (518 / 14) * debugged until output * print min and max values * add some dynamism to the output location * scale input image * extract prep function * extract output path function * normalize image with magic mean and std * add spectral coloring * squeeze in the right place * make enterpolation optional * use bail instead of panic * omit unnecessary Shape call * remove empty curly braces * use bail instead of assert * use vb and pp * remove closures * extract config object * Apply rustfmt. * Fix some clippy lints. * More lints. * Use the array methods. --------- Co-authored-by: laurent <[email protected]>
* feat(gemm): implement Gemm operator in candle-onnx * feat(onnx): Add support for ArgMax operator in candle-onnx * Apply rustfmt. * Remove argmax as it was already present. --------- Co-authored-by: Laurent <[email protected]>
* Add: DINOv2Reg4 with PlantCLEF2024 weights and example ( See https://arxiv.org/abs/2309.16588 and https://zenodo.org/records/10848263 ) * Remove extra files + update README to download them + remove extra lines * minor fix (README remove extra spaces) * minor fix (README: Fix image url) * Modif: Add back interpolate_pos_encoding() + fix when no interpolation + remove extra comments + Update README ( source image changed and so the predictions ) * Fix: Improve code lisibility with '$ cargo clippy' and '$ cargo fmt' * Another clippy fix. --------- Co-authored-by: x-VEspit <[email protected]> Co-authored-by: laurent <[email protected]>
Co-authored-by: Yi Xu <[email protected]>
* Update cudarc to 0.12. * Some cudnn tweaks.
* correct optional SE layer dimensions. * head_dim instead of num_heads is 32. * update test example output.
* Allow loading images with given std and mean * OpenCLIP text encoder component * Two MobileCLIP models * Clippy fixes. --------- Co-authored-by: Laurent <[email protected]>
* fix FLUX.1 weights * added flux1-dev.safetensors
* Clippy fixes for 1.81.0. * Another fix.
* Bump the version to 0.6.1. (huggingface#2438) * onnx: workaround pow with negative base (huggingface#2439) * onnx: workaround pow with negative base rather than fully defining pow in the cpu backend (as in huggingface#2318), this implements a much smaller change which is sufficient to evaluate silero-vad onnx models. Specifically, checking if pow is run with 2.0 exponent, and if so evaluate as simply `x*x` instead of the cpu backend of `e^(2.0 * ln(x))`. * PR: use Tensor::powf insead powf correctly handles a negative base. * onnx: support negative index in Gather (huggingface#2440) index_select does not support negative indexing, but this change adds just enough workarounds in onnx to allow evaluating silero-vad models (which make use of negative indices). * silero-vad v5 example (huggingface#2321) * silero-vad v5 example This change adds an example of how to run silero-vad v5 * PR: rename 'vad' to 'silero-vad' * Update README.md --------- Co-authored-by: Laurent Mazare <[email protected]> * Fix for parler-tts, do not add the last slice of padding tokens. (huggingface#2442) * Fix for parler-tts, do not add the last slice of padding tokens. * Support for the mini model. * Add FastViT model. (huggingface#2444) * fix: qwen2 lm_head loading huggingface#2443 (huggingface#2445) Co-authored-by: Yi Xu <[email protected]> * Update cudarc to 0.12. (huggingface#2451) * Update cudarc to 0.12. * Some cudnn tweaks. * FastViT fixes. (huggingface#2452) * correct optional SE layer dimensions. * head_dim instead of num_heads is 32. * update test example output. * MobileCLIP models S1 and S2 (huggingface#2454) * Allow loading images with given std and mean * OpenCLIP text encoder component * Two MobileCLIP models * Clippy fixes. --------- Co-authored-by: Laurent <[email protected]> * Fix FLUX.1 weights (huggingface#2457) * fix FLUX.1 weights * added flux1-dev.safetensors * Clippy fixes for 1.81.0. (huggingface#2461) * Clippy fixes for 1.81.0. * Another fix. * Make Error::msg more in line with anyhow::Error::msg * Add context trait * Even more flexible * Format --------- Co-authored-by: Laurent Mazare <[email protected]> Co-authored-by: shua <[email protected]> Co-authored-by: Jani Monoses <[email protected]> Co-authored-by: ilookee <[email protected]> Co-authored-by: Yi Xu <[email protected]> Co-authored-by: Eugene Hauptmann <[email protected]>
* Add api to get current seed * Remove cell for rwlock
* Add initial f8 e4m3 type * Fixes * Update deps * Implement CudaDType * Add some cast kernels * Add copy2d, other impls * Fix zeros impl * Error checking for metal * Use isnanf
* Sketch the sdpa kernel * Add full sdpa kernel, * Add test * Add vectorized kernel for decoding * Update tests * Add some docs * Fix sdpa_vector names * Add softcapping for vectorized sdpa * Add softcapping for full sdpa * Add support for head dim 32, 96, 256 * Add support for head dim 32, 96, 256 * Update docs * Add update notice * Clippy and format
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Refs ml-explore/mlx#1558