Flash Attention v2 MHA #17227
Merged
Conversation
onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h (fixed)

tianleiwu reviewed Aug 18, 2023, leaving comments (now outdated and resolved) on:
- onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu
- onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc
…nxruntime into flash_v2_packed_mha
…' into flash_v2_packed_mha merge previous commit
tianleiwu
previously approved these changes
Aug 31, 2023
@faxu @pranavsharma, this PR changes ThirdPartyNotices.txt, so it requires admin approval. Please take a look.
yufenglee
approved these changes
Aug 31, 2023
tianleiwu
approved these changes
Aug 31, 2023
centwang
pushed a commit
that referenced
this pull request
Sep 1, 2023
Integrate Flash Attention V2 to PackedMultiHeadAttention, MultiHeadAttention and Attention operators (see the full description below).
Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
Co-authored-by: tlwu@microsoft.com <tlwu@a100.crj0ad2y1kku1j4yxl4sj10o4e.gx.internal.cloudapp.net>
kleiti
pushed a commit
to kleiti/onnxruntime
that referenced
this pull request
Mar 22, 2024
Integrate Flash Attention V2 to PackedMultiHeadAttention, MultiHeadAttention and Attention operators (see the full description below).
Description
Integrate Flash Attention v2 into the PackedMultiHeadAttention, MultiHeadAttention and Attention operators.
The Flash Attention v2 source code is from https://github.com/Dao-AILab/flash-attention/tree/main/csrc/flash_attn/src. We made some changes to remove the dependency on Torch, then removed the backward and bfloat16 related code.
A benchmark script (see benchmark_mha.sh) is added to compare different attention kernels for the MultiHeadAttention operator.
Current limitations for Flash Attention in the PackedMultiHeadAttention, MultiHeadAttention and Attention operators:
* Relative position bias is not supported
* Different hidden sizes for Q and V are not supported
* Only float16 is supported
* Padding/attention mask is not supported
* For MultiHeadAttention, when there is a past or present input, bias must be provided to activate flash attention
* For Attention, past or present inputs will deactivate flash attention
* Causal is not supported
Some limitations (like attention mask and causal) might be removed later.
Currently, Flash Attention v2 only works on Linux. For Windows, we will enable it later with Cutlass 3.2.
Two environment variables can be used for testing purposes:
1. `ORT_DISABLE_FLASH_ATTENTION` disables flash attention. The default value is "0" (enabled); set it to "1" to disable it.
2. `ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV` controls when flash attention is used for the packed QKV format. The default value is "513", which means flash attention is only enabled when the sequence length is larger than 512 for packed QKV. Set it to "0" to use flash attention v2 whenever possible.
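For illustration, both variables only need to be set in the process environment before the session creates its CUDA kernels. A minimal sketch in Python (the model path is a placeholder, not part of this PR):

```python
import os
import onnxruntime as ort

# Keep flash attention enabled (the default) and remove the packed-QKV
# sequence-length threshold so flash attention v2 is used whenever possible.
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
os.environ["ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV"] = "0"

# Placeholder model path; any float16 model that uses the attention operators would do.
session = ort.InferenceSession("model_fp16.onnx", providers=["CUDAExecutionProvider"])
```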
Speedup
The following results are from a Standard_ND96amsr_A100_v4 VM (A100-SXM4-80GB GPU) using benchmark_mha.sh. The metric is TFLOPs per second for the MultiHeadAttention operator.
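For context, throughput numbers like these are usually derived from the two S x S matrix multiplications in scaled dot-product attention. A rough sketch of that convention (an illustrative formula, not necessarily the exact one used by benchmark_mha.sh):

```python
def attention_tflops_per_second(batch, seq_len, num_heads, head_dim, latency_sec):
    # Forward attention performs two matmuls, Q @ K^T and P @ V,
    # each costing about 2 * batch * num_heads * seq_len^2 * head_dim FLOPs.
    flops = 4.0 * batch * num_heads * (seq_len ** 2) * head_dim
    return flops / latency_sec / 1e12
```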
There are 3 input formats (shapes sketched below):
* `Q,K,V` means separated inputs query, key and value of BxSxNH
* `Q,KV` means packed KV, where key is 5D: BxSxNx2xH
* `QKV` means packed QKV, where query is 5D: BxSxNx3xH

Note that flash attention cannot use the packed QKV format directly, so an extra Transpose is needed. We found that the TensorRT kernel is faster for sequence lengths <= 512 with packed QKV. The reason might be that no transpose is needed for the TensorRT kernel in this format.
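To make the three layouts concrete, here is an illustrative sketch of the tensor shapes (B = batch, S = sequence length, N = heads, H = head size; the sizes are arbitrary examples, not benchmark settings):

```python
import numpy as np

B, S, N, H = 2, 512, 12, 64  # arbitrary example sizes

# "Q,K,V": separate query, key and value, each of shape (B, S, N*H)
q = np.zeros((B, S, N * H), dtype=np.float16)
k = np.zeros((B, S, N * H), dtype=np.float16)
v = np.zeros((B, S, N * H), dtype=np.float16)

# "Q,KV": query stays (B, S, N*H); key carries K and V packed as a 5D tensor (B, S, N, 2, H)
packed_kv = np.zeros((B, S, N, 2, H), dtype=np.float16)

# "QKV": query carries Q, K and V packed as a single 5D tensor (B, S, N, 3, H)
packed_qkv = np.zeros((B, S, N, 3, H), dtype=np.float16)
```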
For the packed QKV format, the TensorRT kernel is faster for a stable diffusion 512x512 image (see seq_len=4096, heads=8, head_dim=40 below), while flash attention v2 is faster for a 1024x1024 image (see seq_len=16384, heads=8, head_dim=40 below).

input format | batch size | sequence length | heads | head dim | flash_v2 (TFLOPs/s) | TensorRT (TFLOPs/s) | Memory Efficient Attention (TFLOPs/s)
-- | -- | -- | -- | -- | -- | -- | --
Q,K,V | 32 | 512 | 64 | 32 | 78.1 | 60.0 | 39.3
Q,K,V | 32 | 512 | 128 | 16 | 46.8 | 44.1 | 21.7
Q,K,V | 16 | 1024 | 64 | 32 | 99.0 | 72.8 | 44.3
Q,K,V | 16 | 1024 | 128 | 16 | 54.7 | 49.2 | 23.4
Q,K,V | 8 | 2048 | 64 | 32 | 113.8 | 81.2 | 47.8
Q,K,V | 8 | 2048 | 128 | 16 | 59.7 | 51.9 | 24.7
Q,K,V | 4 | 4096 | 64 | 32 | 122.5 | 85.6 | 49.7
Q,K,V | 4 | 4096 | 128 | 16 | 62.5 | 53.3 | 25.3
Q,K,V | 2 | 8192 | 64 | 32 | 127.4 | 87.5 | 50.7
Q,K,V | 2 | 8192 | 128 | 16 | 64.0 | 54.2 | 25.6
Q,K,V | 1 | 16384 | 64 | 32 | 129.5 | 91.0 | 51.2
Q,K,V | 1 | 16384 | 128 | 16 | 64.7 | 54.5 | 25.8
Q,K,V | 1 | 4096 | 8 | 40 | 51.0 | 43.6 | 36.8
Q,K,V | 1 | 4096 | 8 | 80 | 97.7 | 77.0 | 55.5
Q,K,V | 1 | 4096 | 8 | 160 | 120.0 | 39.7 | 57.8
Q,K,V | 4 | 4096 | 8 | 40 | 89.0 | 84.4 | 49.2
Q,K,V | 4 | 4096 | 8 | 80 | 133.0 | 92.2 | 63.2
Q,K,V | 4 | 4096 | 8 | 160 | 164.8 | 42.7 | 63.8
Q,K,V | 1 | 16384 | 8 | 40 | 96.9 | 91.3 | 52.1
Q,K,V | 1 | 16384 | 8 | 80 | 142.9 | 101.5 | 65.6
Q,K,V | 1 | 16384 | 8 | 160 | 177.4 | 44.2 | 65.7
Q,K,V | 128 | 128 | 12 | 64 | 29.0 | 26.9 | 25.7
Q,K,V | 64 | 128 | 12 | 64 | 23.1 | 10.8 | 21.3
Q,K,V | 128 | 384 | 12 | 64 | 83.5 | 60.8 | 55.7
Q,K,V | 64 | 384 | 12 | 64 | 72.6 | 40.5 | 52.8
Q,K,V | 128 | 512 | 12 | 64 | 98.9 | 77.9 | 62.1
Q,K,V | 64 | 512 | 12 | 64 | 94.7 | 75.6 | 60.4
Q,KV | 32 | 512 | 64 | 32 | 85.9 | 41.1 | 41.1
Q,KV | 32 | 512 | 128 | 16 | 47.1 | 21.6 | 21.6
Q,KV | 16 | 1024 | 64 | 32 | 104.4 | 45.8 | 45.8
Q,KV | 16 | 1024 | 128 | 16 | 54.7 | 23.6 | 23.6
Q,KV | 8 | 2048 | 64 | 32 | 116.8 | 48.5 | 48.5
Q,KV | 8 | 2048 | 128 | 16 | 59.8 | 24.7 | 24.7
Q,KV | 4 | 4096 | 64 | 32 | 124.2 | 50.1 | 50.1
Q,KV | 4 | 4096 | 128 | 16 | 62.6 | 25.3 | 25.3
Q,KV | 2 | 8192 | 64 | 32 | 128.5 | 50.8 | 50.9
Q,KV | 2 | 8192 | 128 | 16 | 64.1 | 25.6 | 25.6
Q,KV | 1 | 16384 | 64 | 32 | 129.4 | 51.2 | 51.2
Q,KV | 1 | 16384 | 128 | 16 | 64.8 | 25.8 | 25.8
Q,KV | 1 | 4096 | 8 | 40 | 67.5 | 37.7 | 37.5
Q,KV | 1 | 4096 | 8 | 80 | 101.3 | 56.7 | 56.6
Q,KV | 1 | 4096 | 8 | 160 | 124.0 | 58.6 | 58.6
Q,KV | 4 | 4096 | 8 | 40 | 90.8 | 49.8 | 49.8
Q,KV | 4 | 4096 | 8 | 80 | 135.6 | 63.8 | 63.8
Q,KV | 4 | 4096 | 8 | 160 | 166.3 | 64.5 | 64.5
Q,KV | 1 | 16384 | 8 | 40 | 97.5 | 52.3 | 52.3
Q,KV | 1 | 16384 | 8 | 80 | 143.5 | 65.9 | 65.8
Q,KV | 1 | 16384 | 8 | 160 | 178.4 | 65.9 | 65.8
Q,KV | 128 | 128 | 12 | 64 | 26.8 | 48.1 | 30.9
Q,KV | 64 | 128 | 12 | 64 | 28.0 | 38.9 | 25.0
Q,KV | 128 | 384 | 12 | 64 | 97.7 | 61.1 | 61.0
Q,KV | 64 | 384 | 12 | 64 | 89.5 | 57.8 | 57.9
Q,KV | 128 | 512 | 12 | 64 | 111.9 | 66.7 | 66.9
Q,KV | 64 | 512 | 12 | 64 | 107.2 | 64.9 | 64.8
QKV | 32 | 512 | 64 | 32 | 77.2 | 84.7 | 39.3
QKV | 32 | 512 | 128 | 16 | 43.4 | 53.1 | 20.9
QKV | 16 | 1024 | 64 | 32 | 98.8 | 87.4 | 44.6
QKV | 16 | 1024 | 128 | 16 | 52.0 | 54.1 | 23.2
QKV | 8 | 2048 | 64 | 32 | 113.1 | 89.0 | 47.9
QKV | 8 | 2048 | 128 | 16 | 58.2 | 54.6 | 24.5
QKV | 4 | 4096 | 64 | 32 | 120.6 | 89.7 | 49.7
QKV | 4 | 4096 | 128 | 16 | 61.7 | 54.6 | 25.2
QKV | 2 | 8192 | 64 | 32 | 125.9 | 89.5 | 50.7
QKV | 2 | 8192 | 128 | 16 | 63.6 | 54.8 | 25.5
QKV | 1 | 16384 | 64 | 32 | 128.5 | 92.0 | 51.2
QKV | 1 | 16384 | 128 | 16 | 64.6 | 54.8 | 25.7
QKV | 1 | 4096 | 8 | 40 | 60.2 | **69.8** | 38.1
QKV | 1 | 4096 | 8 | 80 | 101.6 | 75.2 | 56.7
QKV | 1 | 4096 | 8 | 160 | 130.2 | 41.2 | 58.4
QKV | 4 | 4096 | 8 | 40 | 90.6 | **91.0** | 49.5
QKV | 4 | 4096 | 8 | 80 | 133.6 | 98.1 | 62.8
QKV | 4 | 4096 | 8 | 160 | 165.3 | 43.7 | 63.9
QKV | 1 | 16384 | 8 | 40 | 97.2 | 92.8 | 52.1
QKV | 1 | 16384 | 8 | 80 | 143.0 | 103.1 | 65.6
QKV | 1 | 16384 | 8 | 160 | 177.6 | 44.5 | 65.7
QKV | 128 | 128 | 12 | 64 | 31.1 | 65.9 | 27.6
QKV | 64 | 128 | 12 | 64 | 26.1 | 49.8 | 23.5
QKV | 128 | 384 | 12 | 64 | 84.6 | 88.5 | 56.1
QKV | 64 | 384 | 12 | 64 | 79.1 | 80.3 | 53.5
QKV | 128 | 512 | 12 | 64 | 97.3 | 114.2 | 62.2
QKV | 64 | 512 | 12 | 64 | 95.9 | 110.7 | 60.6
QKV | 4 | 2048 | 32 | 128 | 125.26 | 44.72 | 78.15
QKV | 4 | 4096 | 32 | 128 | 141.62 | 46.29 | 85.84
QKV | 8 | 2048 | 32 | 128 | 127.40 | 45.49 | 78.75
QKV | 8 | 4096 | 32 | 128 | 144.24 | 46.60 | 86.95
Known Issues
NVCC uses a huge amount of memory while compiling the flash attention CUDA kernels. A Linux build with CUDA might fail when the machine has limited memory and a large number of CPUs. The workaround is to use a build machine with more memory (64 GB or above recommended), or to pass an argument like `--nvcc_threads 1` to limit the number of nvcc threads during the build.
Motivation and Context
Increases speed and efficiency of MHA or Packed MHA.