Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add softcapping support to flash attention #2437

Closed
wants to merge 72 commits into from

Conversation

EricLBuehler
Copy link
Member

No description provided.

EricLBuehler and others added 30 commits May 15, 2024 15:10
* Offset it

* Freeze

* Offset it

* Offset it

* Try out vllm impl again

* Try out vllm impl again

* Try out vllm impl again

* Try out vllm impl again

* Try out vllm impl again

* Try out vllm impl again

* Try out vllm impl again

* Try out vllm impl again

* Try out vllm impl again

* Try out vllm impl again

* Try out vllm impl again

* Try out vllm impl again

* Try out vllm impl again

* Remove debugs

* Polish it up

* Polish it up

* Clippy

* Remove test file

* Add config for if neox

* Fix bug

* Fix bug

* Cast cache type on rust side

* Cast types

* To dtype

* Drop temp

* Update casting

* Update casting

* Update casting

* Create dtype in bf16

* Check type

* Debug

* Check dtype

* Check dtype

* Check dtype

* Check dtype

* Check dtype

* Check dtype

* Check dtype

* Check dtype

* Check dtype

* Debug

* Debug

* Debug

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Check old method

* Use mistral slow rope impl

* Use mistral slow rope impl

* Use mistral slow rope impl

* Use mistral slow rope impl

* Use mistral slow rope impl

* Use mistral slow rope impl

* Use mistral slow rope impl

* Use mistral slow rope impl

* Use mistral slow rope impl

* Use mistral slow rope impl

* Use mistral slow rope impl

* Use mistral slow rope impl

* Reseting

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Remove debug

* Debug

* Debug

* Remove debug

* Remove debug

* Debug

* Remove debug

* Debug

* Remove debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Debug

* Try to use 3dim rotemb fused

* Try to use 3dim rotemb fused

* Remove contig and debug

* Check handling

* Cleanup

* Fix

* Remove prints

* Lower block dim

* Use fused layernorm

* Pass batch size

* Simplify internal API

* Simplify internal API

* Try slow

* Try candle layer norm

* Try candle layer norm

* Fix dep of candle layer norm

* Reshape input for rank 2

* Reshape input for rank 2

* Fix ref

* Code style

* Make dep optional

* Ensure contig

* Ensure contig

* Ensure contig

* Debug contig dmmv error

* Debug contig dmmv error

* Debug contig dmmv error

* Debug contig dmmv error

* Try other method

* Try other method

* Try other method

* Try other method

* Try other method

* Use typestate to optimize

* Use typestate to optimize

* Fixes

* Fixes

* Fixes

* Fixes

* Fixes

* Debug via using slow rmsnorm

* Debug via using slow rope

* Remove debug

* More debugging

* Remove debug

* Remove debug

* Remove debug

* Add better error enum

* Fix diff marker

* Fix some things

* Fix some things

* Fix some things

* Fix dummy backends

* Re add from storage noop

* Fix removed kvconcat custom op

* Fix erroneous feature gate

* Complete metal backend refactoring

* Check if calling

* Check if calling

* Update default for force dmmv

* Load atomic

* Debug

* Use mmvq

* Update

* Add the empty functions

* Add rope new_partial function

* Make variant of qmatmul pub

* Make variant of qmatmul pub

* Add the varbuilder set_device function

* Only link stdc++ if target has msvc

* Only link stdc++ if target has msvc

* Only link stdc++ if target has msvc

* Only link stdc++ if target has msvc

* Handle case of device mapping

* Handle case of device mapping

* Add getter

* Fix

* Fix

* Support nvcc flags in flash attn

* Support nvcc flags in flash attn

* Support nvcc flags in flash attn

* Support nvcc flags in flash attn

* Support nvcc flags in flash attn

* Fixes

* Fixes

* Fix the tests

* Fix the tests
* Support flash-attn in quantized phi3. (huggingface#2194)

* Use flash-attn in gemma. (huggingface#2195)

* Use flash-attn in gemma.

* Fix flash-attn for head dim 256.

* Remove candle-layer-norm

---------

Co-authored-by: Laurent Mazare <[email protected]>
* Add unfold

* Format
* Add the quantize_onto api

* Take ref

* Clippy

* Format

* Add error checking
* Use flash-attn in gemma.

* Fix for the fast bf16 cublas gemm.

* Fix some clippy lints.

* Fix another lint.

* Proper clippy fix.
* define structs

* construct ResidualConvUnit

* forward() for ResidualConvUnit

* implement FeatureFusionBlock

* implement Scratch

* implement DPTHead

* add identity module

* implement forward for DTPHead

* add get_intermediate_layers to DinoVisionTransformer

* implement DepthAnythingV2

* some minor tweaks

* fix compile errors

* fix var builder prefixes

* setup initial example

* use fixed patch size of 37 (518 / 14)

* debugged until output

* print min and max values

* add some dynamism to the output location

* scale input image

* extract prep function

* extract output path function

* normalize image with magic mean and std

* add spectral coloring

* squeeze in the right place

* make enterpolation optional

* use bail instead of panic

* omit unnecessary Shape call

* remove empty curly braces

* use bail instead of assert

* use vb and pp

* remove closures

* extract config object

* Apply rustfmt.

* Fix some clippy lints.

* More lints.

* Use the array methods.

---------

Co-authored-by: laurent <[email protected]>
* feat(gemm): implement Gemm operator in candle-onnx

* feat(onnx): Add support for ArgMax operator in candle-onnx

* Apply rustfmt.

* Remove argmax as it was already present.

---------

Co-authored-by: Laurent <[email protected]>
* Add: DINOv2Reg4 with PlantCLEF2024 weights and example ( See https://arxiv.org/abs/2309.16588 and https://zenodo.org/records/10848263 )

* Remove extra files + update README to download them + remove extra lines

* minor fix (README remove extra spaces)

* minor fix (README: Fix image url)

* Modif: Add back interpolate_pos_encoding() + fix when no interpolation + remove extra comments + Update README ( source image changed and so the predictions )

* Fix: Improve code lisibility with '$ cargo clippy' and '$ cargo fmt'

* Another clippy fix.

---------

Co-authored-by: x-VEspit <[email protected]>
Co-authored-by: laurent <[email protected]>
EricLBuehler and others added 27 commits August 7, 2024 17:06
* Add i32 dtype for cpu and cuda, with kernels

* Fix cuda i32

* Fix cpu i32

* Add cuda map impls for i32

* Start to add to metal

* Add the kernels

* Oops

* Fix dtype cast in safetensors

* Oops

* Oops

* Add bf16 to i32 and vice versa casts
* Add the flux autoencoder.

* Add the encoder down-blocks.

* Upsampling in the decoder.

* Sketch the flow matching model.

* More flux model.

* Add some of the positional embeddings.

* Add the rope embeddings.

* Add the sampling functions.

* Add the flux example.

* Fix the T5 bits.

* Proper T5 tokenizer.

* Clip encoder path fix.

* Get the clip embeddings.

* No configurable weights in layer norm.

* More weights related fixes.

* Yet another shape fix.

* DType fix.

* Fix a couple more shape issues.

* DType fixes.

* Fix the latent dims.

* Fix more shape issues.

* Autoencoder fixes.

* Get some generations out.

* Bugfix.

* T5 padding.

* Clippy fix.

* Add the decode only mode.

* Fix.

* More fixes.

* Finally get some generations to work.

* Add readme.
* add models support and example for THUDM/glm-4

* fix the ci report

* fmt

* fix

* Update README.org

* Update README.org

* fmt

* Update README.org

* README.md add codegeex4

* README.md add glm4

* Typo.

* change expect into ?

---------

Co-authored-by: Laurent Mazare <[email protected]>
* add mmdit of stable diffusion 3

lint

add comments

* correct a misplaced comment

* fix cargo fmt

* fix clippy error

* use bail! instead of assert!

* use get_on_dim in splitting qkv
* chore: changes from formatting on save

* fix: usage of `actions/checkout@v2`
Also squeeze the first dimension of the codes tensor in the example file to get the expected three dimensions.
* Soft NMS with thresholds

* NMS Test

* Soft nms w/ boxes removed below threshold

* Soft nms test

* No longer removing bounding boxes to fit Soft-NMS focus

* Initialize confidence

* Added comments

* Refactored out updating based on IOU/sigma

* Score_threshold -> confidence_threshold for clarity

* Remove bboxes below confidence threshold

* Softnms basic functionality test

* Softnms confidence decay test

* Softnms confidence threshold test

* Softnms no overlapping bbox test

* Testing confidence after no overlap test

* Single bbox and no bbox tests

* Signify test completion

* Handling result of test functions

* Checking all pairs of bboxes instead of a forward pass

* Equal confidence overlap test

* Clarified tests for implementation

* No longer dropping boxes, just setting to 0.0

* Formatted w/ cargo
…ds (huggingface#2308)

* Add documentation examples for `Tensor` methods

* Apply fmt.

* Cosmetic tweaks.

---------

Co-authored-by: Laurent <[email protected]>
* Clippy fixes.

* Bump the web_sys required version.
* Add GGUF bf16 type support

* Add non avx impl for vec_dot_bf16

* Fix from_u32

* Fix loading

* Fix dequant of bf16
@EricLBuehler EricLBuehler deleted the flash_attn_softcap branch August 22, 2024 01:43
@EricLBuehler EricLBuehler restored the flash_attn_softcap branch August 22, 2024 01:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.