Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sdpa stride fix #2596

Closed
wants to merge 103 commits into from
Closed

Conversation

EricLBuehler
Copy link
Member

EricLBuehler and others added 30 commits May 15, 2024 15:10
* Offset it

* Freeze

* Offset it

* Offset it

* Try out vllm impl again

* Try out vllm impl again

* Try out vllm impl again

* Try out vllm impl again

* Try out vllm impl again

* Try out vllm impl again

* Try out vllm impl again

* Try out vllm impl again

* Try out vllm impl again

* Try out vllm impl again

* Try out vllm impl again

* Try out vllm impl again

* Try out vllm impl again

* Remove debugs

* Polish it up

* Polish it up

* Clippy

* Remove test file

* Add config for if neox

* Fix bug

* Fix bug

* Cast cache type on rust side

* Cast types

* To dtype

* Drop temp

* Update casting

* Update casting

* Update casting

* Create dtype in bf16

* Check type

* Debug

* Check dtype

* Check dtype

* Check dtype

* Check dtype

* Check dtype

* Check dtype

* Check dtype

* Check dtype

* Check dtype

* Debug

* Debug

* Debug

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Use mistral slow rope impl

* Use mistral slow rope impl

* Use mistral slow rope impl

* Use mistral slow rope impl

* Use mistral slow rope impl

* Use mistral slow rope impl

* Use mistral slow rope impl

* Use mistral slow rope impl

* Use mistral slow rope impl

* Use mistral slow rope impl

* Use mistral slow rope impl

* Use mistral slow rope impl

* Reseting

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Remove debug

* Debug

* Debug

* Remove debug

* Remove debug

* Debug

* Remove debug

* Debug

* Remove debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Try to use 3dim rotemb fused

* Try to use 3dim rotemb fused

* Remove contig and debug

* Check handling

* Cleanup

* Fix

* Remove prints

* Lower block dim

* Use fused layernorm

* Pass batch size

* Simplify internal API

* Simplify internal API

* Try slow

* Try candle layer norm

* Try candle layer norm

* Fix dep of candle layer norm

* Reshape input for rank 2

* Reshape input for rank 2

* Fix ref

* Code style

* Make dep optional

* Ensure contig

* Ensure contig

* Ensure contig

* Debug contig dmmv error

* Debug contig dmmv error

* Debug contig dmmv error

* Debug contig dmmv error

* Try other method

* Try other method

* Try other method

* Try other method

* Try other method

* Use typestate to optimize

* Use typestate to optimize

* Fixes

* Fixes

* Fixes

* Fixes

* Fixes

* Debug via using slow rmsnorm

* Debug via using slow rope

* Remove debug

* More debugging

* Remove debug

* Remove debug

* Remove debug

* Add better error enum

* Fix diff marker

* Fix some things

* Fix some things

* Fix some things

* Fix dummy backends

* Re add from storage noop

* Fix removed kvconcat custom op

* Fix erroneous feature gate

* Complete metal backend refactoring

* Check if calling

* Check if calling

* Update default for force dmmv

* Load atomic

* Debug

* Use mmvq

* Update

* Add the empty functions

* Add rope new_partial function

* Make variant of qmatmul pub

* Make variant of qmatmul pub

* Add the varbuilder set_device function

* Only link stdc++ if target has msvc

* Only link stdc++ if target has msvc

* Only link stdc++ if target has msvc

* Only link stdc++ if target has msvc

* Handle case of device mapping

* Handle case of device mapping

* Add getter

* Fix

* Fix

* Support nvcc flags in flash attn

* Support nvcc flags in flash attn

* Support nvcc flags in flash attn

* Support nvcc flags in flash attn

* Support nvcc flags in flash attn

* Fixes

* Fixes

* Fix the tests

* Fix the tests
* Support flash-attn in quantized phi3. (huggingface#2194)

* Use flash-attn in gemma. (huggingface#2195)

* Use flash-attn in gemma.

* Fix flash-attn for head dim 256.

* Remove candle-layer-norm

---------

Co-authored-by: Laurent Mazare <[email protected]>
* Add unfold

* Format
* Add the quantize_onto api

* Take ref

* Clippy

* Format

* Add error checking
* Use flash-attn in gemma.

* Fix for the fast bf16 cublas gemm.

* Fix some clippy lints.

* Fix another lint.

* Proper clippy fix.
* define structs

* construct ResidualConvUnit

* forward() for ResidualConvUnit

* implement FeatureFusionBlock

* implement Scratch

* implement DPTHead

* add identity module

* implement forward for DTPHead

* add get_intermediate_layers to DinoVisionTransformer

* implement DepthAnythingV2

* some minor tweaks

* fix compile errors

* fix var builder prefixes

* setup initial example

* use fixed patch size of 37 (518 / 14)

* debugged until output

* print min and max values

* add some dynamism to the output location

* scale input image

* extract prep function

* extract output path function

* normalize image with magic mean and std

* add spectral coloring

* squeeze in the right place

* make enterpolation optional

* use bail instead of panic

* omit unnecessary Shape call

* remove empty curly braces

* use bail instead of assert

* use vb and pp

* remove closures

* extract config object

* Apply rustfmt.

* Fix some clippy lints.

* More lints.

* Use the array methods.

---------

Co-authored-by: laurent <[email protected]>
* feat(gemm): implement Gemm operator in candle-onnx

* feat(onnx): Add support for ArgMax operator in candle-onnx

* Apply rustfmt.

* Remove argmax as it was already present.

---------

Co-authored-by: Laurent <[email protected]>
* Add: DINOv2Reg4 with PlantCLEF2024 weights and example ( See https://arxiv.org/abs/2309.16588 and https://zenodo.org/records/10848263 )

* Remove extra files + update README to download them + remove extra lines

* minor fix (README remove extra spaces)

* minor fix (README: Fix image url)

* Modif: Add back interpolate_pos_encoding() + fix when no interpolation + remove extra comments + Update README ( source image changed and so the predictions )

* Fix: Improve code lisibility with '$ cargo clippy' and '$ cargo fmt'

* Another clippy fix.

---------

Co-authored-by: x-VEspit <[email protected]>
Co-authored-by: laurent <[email protected]>
janimo and others added 26 commits September 6, 2024 14:37
* Update cudarc to 0.12.

* Some cudnn tweaks.
* correct optional SE layer dimensions.
 * head_dim instead of num_heads is 32.
 * update test example output.
* Allow loading images with given std and mean

* OpenCLIP text encoder component

* Two MobileCLIP models

* Clippy fixes.

---------

Co-authored-by: Laurent <[email protected]>
* fix FLUX.1 weights

* added flux1-dev.safetensors
* Clippy fixes for 1.81.0.

* Another fix.
* Bump the version to 0.6.1. (huggingface#2438)

* onnx: workaround pow with negative base (huggingface#2439)

* onnx: workaround pow with negative base

rather than fully defining pow in the cpu backend (as in huggingface#2318),
this implements a much smaller change which is sufficient to evaluate silero-vad
onnx models. Specifically, checking if pow is run with 2.0 exponent, and if so
evaluate as simply `x*x` instead of the cpu backend of `e^(2.0 * ln(x))`.

* PR: use Tensor::powf insead

powf correctly handles a negative base.

* onnx: support negative index in Gather (huggingface#2440)

index_select does not support negative indexing, but
this change adds just enough workarounds in onnx to
allow evaluating silero-vad models (which make use of
negative indices).

* silero-vad v5 example (huggingface#2321)

* silero-vad v5 example

This change adds an example of how to run silero-vad v5

* PR: rename 'vad' to 'silero-vad'

* Update README.md

---------

Co-authored-by: Laurent Mazare <[email protected]>

* Fix for parler-tts, do not add the last slice of padding tokens. (huggingface#2442)

* Fix for parler-tts, do not add the last slice of padding tokens.

* Support for the mini model.

* Add FastViT model. (huggingface#2444)

* fix: qwen2 lm_head loading huggingface#2443 (huggingface#2445)

Co-authored-by: Yi Xu <[email protected]>

* Update cudarc to 0.12. (huggingface#2451)

* Update cudarc to 0.12.

* Some cudnn tweaks.

* FastViT fixes. (huggingface#2452)

* correct optional SE layer dimensions.
 * head_dim instead of num_heads is 32.
 * update test example output.

* MobileCLIP models S1 and S2 (huggingface#2454)

* Allow loading images with given std and mean

* OpenCLIP text encoder component

* Two MobileCLIP models

* Clippy fixes.

---------

Co-authored-by: Laurent <[email protected]>

* Fix FLUX.1 weights (huggingface#2457)

* fix FLUX.1 weights

* added flux1-dev.safetensors

* Clippy fixes for 1.81.0. (huggingface#2461)

* Clippy fixes for 1.81.0.

* Another fix.

* Make Error::msg more in line with anyhow::Error::msg

* Add context trait

* Even more flexible

* Format

---------

Co-authored-by: Laurent Mazare <[email protected]>
Co-authored-by: shua <[email protected]>
Co-authored-by: Jani Monoses <[email protected]>
Co-authored-by: ilookee <[email protected]>
Co-authored-by: Yi Xu <[email protected]>
Co-authored-by: Eugene Hauptmann <[email protected]>
* Add api to get current seed

* Remove cell for rwlock
* Add the i16 dtype

* Added I16 and I32 to fix the missing arms issue (candle-onnx/eval)

* Update rust-ci.yml

* Update ci_cuda.yaml

* fmt adjustment

* Revert "Update rust-ci.yml"

This reverts commit f659d36.

* Revert "Update ci_cuda.yaml"

This reverts commit 62a4b39.
* Add initial f8 e4m3 type

* Fixes

* Update deps

* Implement CudaDType

* Add some cast kernels

* Add copy2d, other impls

* Fix zeros impl

* Error checking for metal

* Use isnanf
* Sketch the sdpa kernel

* Add full sdpa kernel,

* Add test

* Add vectorized kernel for decoding

* Update tests

* Add some docs

* Fix sdpa_vector names

* Add softcapping for vectorized sdpa

* Add softcapping for full sdpa

* Add support for head dim 32, 96, 256

* Add support for head dim 32, 96, 256

* Update docs

* Add update notice

* Clippy and format
@EricLBuehler EricLBuehler deleted the sdpa_v_stride_fix branch November 4, 2024 16:36
@EricLBuehler EricLBuehler restored the sdpa_v_stride_fix branch November 4, 2024 16:36
@EricLBuehler EricLBuehler deleted the sdpa_v_stride_fix branch November 10, 2024 03:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.