Flash Attention v2 MHA #17227

Merged: 70 commits into main on Aug 31, 2023


Conversation

aciddelgado (Contributor) commented on Aug 18, 2023

Description

Integrate Flash Attention v2 into the PackedMultiHeadAttention, MultiHeadAttention, and Attention operators.

The Flash Attention v2 source code comes from https://github.com/Dao-AILab/flash-attention/tree/main/csrc/flash_attn/src. We made some changes to remove the dependency on Torch, and removed the backward-pass and bfloat16-related code.

A benchmark script (see benchmark_mha.sh) is added to compare different attention kernels for the MultiHeadAttention operator.

Current limitations for Flash Attention in PackedMultiHeadAttention, MultiHeadAttention and Attention operators:

  • Relative Position Bias is not supported
  • Different hidden size for Q and V is not supported
  • Only float16 is supported
  • Padding/attention mask is not supported
  • For MultiHeadAttention, when there is a past or present input, bias must be provided to activate flash attention
  • For Attention, past or present inputs will deactivate flash attention
  • Causal is not supported

Some of these limitations (such as attention mask and causal support) might be removed later. The sketch below summarizes the current eligibility conditions.
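
As a quick mental model of when the flash attention path can be taken, here is a minimal Python sketch that restates the limitations above as a single predicate. It is illustrative only: the function and parameter names are hypothetical, and the real kernel-selection logic lives inside the CUDA attention operators.

```python
def can_use_flash_attention_v2(dtype: str,
                               qk_hidden_size: int,
                               v_hidden_size: int,
                               has_relative_position_bias: bool,
                               has_attention_mask: bool,
                               is_causal: bool) -> bool:
    """Hypothetical restatement of the limitations listed above."""
    if dtype != "float16":               # only float16 is supported
        return False
    if has_relative_position_bias:       # relative position bias is not supported
        return False
    if qk_hidden_size != v_hidden_size:  # different hidden size for Q and V is not supported
        return False
    if has_attention_mask:               # padding/attention mask is not supported
        return False
    if is_causal:                        # causal attention is not supported yet
        return False
    return True
```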

Currently, Flash Attention v2 only works on Linux. For Windows, we will enable it later with Cutlass 3.2.

Two environment variables can be used for testing purposes:
(1) ORT_DISABLE_FLASH_ATTENTION disables flash attention. The default value is "0" (enabled); set it to "1" to disable flash attention.
(2) ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV sets the minimum sequence length for using flash attention with the packed QKV format. The default value is "513", which means flash attention is enabled only when the sequence length is larger than 512 for the packed QKV format; set it to "0" to use flash attention v2 whenever possible.
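
For example, a minimal Python sketch of using these variables (assuming a placeholder model.onnx that contains fp16 attention nodes) could look like this; set the variables before the session is created:

```python
import os

# Keep flash attention enabled (the default) and remove the packed-QKV
# sequence-length threshold so flash attention v2 is used whenever possible.
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
os.environ["ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV"] = "0"

import onnxruntime as ort

# "model.onnx" is a placeholder for a model with fp16 attention operators.
session = ort.InferenceSession("model.onnx", providers=["CUDAExecutionProvider"])
```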

Speedup

The following results are from a Standard_ND96amsr_A100_v4 VM (A100-SXM4-80GB GPU) using benchmark_mha.sh. The metric is TFLOPs per second for the MultiHeadAttention operator.
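
For context, the sketch below shows one common way to convert measured latency into TFLOPs per second, counting 2*B*N*S^2*H FLOPs for Q*K^T and another 2*B*N*S^2*H for the attention-weighted value product (softmax ignored). The exact accounting in benchmark_mha.sh may differ, and the latency in the example is hypothetical.

```python
def attention_tflops_per_second(batch: int, seq_len: int, num_heads: int,
                                head_dim: int, latency_seconds: float) -> float:
    """Approximate forward attention throughput: 2*B*N*S^2*H FLOPs for Q*K^T
    plus 2*B*N*S^2*H for the attention-weighted value product."""
    flops = 4.0 * batch * num_heads * (seq_len ** 2) * head_dim
    return flops / latency_seconds / 1e12

# Hypothetical example: batch=4, seq_len=4096, heads=64, head_dim=32 at 4.5 ms per call
# corresponds to roughly 122 TFLOPs/s.
print(attention_tflops_per_second(4, 4096, 64, 32, 4.5e-3))
```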

There are 3 input formats (a shape sketch follows the list):

  • Q,K,V means separate query, key, and value inputs, each of shape BxSxNH
  • Q,KV means packed KV, where the key input is 5D: BxSxNx2xH
  • QKV means packed QKV, where the query input is 5D: BxSxNx3xH
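
To make the three layouts concrete, here is a small NumPy sketch of the corresponding tensor shapes; the B, S, N, H values are arbitrary examples, not defaults used by the operator:

```python
import numpy as np

B, S, N, H = 4, 4096, 64, 32  # batch, sequence length, number of heads, head size (arbitrary)

# Q,K,V: separate query, key and value, each of shape (B, S, N*H)
q = np.zeros((B, S, N * H), dtype=np.float16)
k = np.zeros((B, S, N * H), dtype=np.float16)
v = np.zeros((B, S, N * H), dtype=np.float16)

# Q,KV: query of shape (B, S, N*H); key and value packed into one 5D tensor (B, S, N, 2, H)
packed_kv = np.zeros((B, S, N, 2, H), dtype=np.float16)

# QKV: query, key and value packed into a single 5D tensor (B, S, N, 3, H)
packed_qkv = np.zeros((B, S, N, 3, H), dtype=np.float16)
```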

Note that flash attention cannot consume the packed QKV format directly, so an extra Transpose is needed. We found that the TensorRT kernel is faster for sequence length <= 512 with packed QKV, possibly because no transpose is needed for the TensorRT kernel in this format.

For the packed QKV format, the TensorRT kernel is faster for Stable Diffusion 512x512 images (see seq_len=4096, heads=8, head_dim=40 below), while flash attention v2 is faster for 1024x1024 images (see seq_len=16384, heads=8, head_dim=40 below).

input format | batch size | sequence length | heads | head dim | flash_v2 (TFLOPs/s) | TensorRT (TFLOPs/s) | Memory Efficient Attention (TFLOPs/s)
-- | -- | -- | -- | -- | -- | -- | --
Q,K,V | 32 | 512 | 64 | 32 | 78.1 | 60.0 | 39.3
Q,K,V | 32 | 512 | 128 | 16 | 46.8 | 44.1 | 21.7
Q,K,V | 16 | 1024 | 64 | 32 | 99.0 | 72.8 | 44.3
Q,K,V | 16 | 1024 | 128 | 16 | 54.7 | 49.2 | 23.4
Q,K,V | 8 | 2048 | 64 | 32 | 113.8 | 81.2 | 47.8
Q,K,V | 8 | 2048 | 128 | 16 | 59.7 | 51.9 | 24.7
Q,K,V | 4 | 4096 | 64 | 32 | 122.5 | 85.6 | 49.7
Q,K,V | 4 | 4096 | 128 | 16 | 62.5 | 53.3 | 25.3
Q,K,V | 2 | 8192 | 64 | 32 | 127.4 | 87.5 | 50.7
Q,K,V | 2 | 8192 | 128 | 16 | 64.0 | 54.2 | 25.6
Q,K,V | 1 | 16384 | 64 | 32 | 129.5 | 91.0 | 51.2
Q,K,V | 1 | 16384 | 128 | 16 | 64.7 | 54.5 | 25.8
Q,K,V | 1 | 4096 | 8 | 40 | 51.0 | 43.6 | 36.8
Q,K,V | 1 | 4096 | 8 | 80 | 97.7 | 77.0 | 55.5
Q,K,V | 1 | 4096 | 8 | 160 | 120.0 | 39.7 | 57.8
Q,K,V | 4 | 4096 | 8 | 40 | 89.0 | 84.4 | 49.2
Q,K,V | 4 | 4096 | 8 | 80 | 133.0 | 92.2 | 63.2
Q,K,V | 4 | 4096 | 8 | 160 | 164.8 | 42.7 | 63.8
Q,K,V | 1 | 16384 | 8 | 40 | 96.9 | 91.3 | 52.1
Q,K,V | 1 | 16384 | 8 | 80 | 142.9 | 101.5 | 65.6
Q,K,V | 1 | 16384 | 8 | 160 | 177.4 | 44.2 | 65.7
Q,K,V | 128 | 128 | 12 | 64 | 29.0 | 26.9 | 25.7
Q,K,V | 64 | 128 | 12 | 64 | 23.1 | 10.8 | 21.3
Q,K,V | 128 | 384 | 12 | 64 | 83.5 | 60.8 | 55.7
Q,K,V | 64 | 384 | 12 | 64 | 72.6 | 40.5 | 52.8
Q,K,V | 128 | 512 | 12 | 64 | 98.9 | 77.9 | 62.1
Q,K,V | 64 | 512 | 12 | 64 | 94.7 | 75.6 | 60.4
Q,KV | 32 | 512 | 64 | 32 | 85.9 | 41.1 | 41.1
Q,KV | 32 | 512 | 128 | 16 | 47.1 | 21.6 | 21.6
Q,KV | 16 | 1024 | 64 | 32 | 104.4 | 45.8 | 45.8
Q,KV | 16 | 1024 | 128 | 16 | 54.7 | 23.6 | 23.6
Q,KV | 8 | 2048 | 64 | 32 | 116.8 | 48.5 | 48.5
Q,KV | 8 | 2048 | 128 | 16 | 59.8 | 24.7 | 24.7
Q,KV | 4 | 4096 | 64 | 32 | 124.2 | 50.1 | 50.1
Q,KV | 4 | 4096 | 128 | 16 | 62.6 | 25.3 | 25.3
Q,KV | 2 | 8192 | 64 | 32 | 128.5 | 50.8 | 50.9
Q,KV | 2 | 8192 | 128 | 16 | 64.1 | 25.6 | 25.6
Q,KV | 1 | 16384 | 64 | 32 | 129.4 | 51.2 | 51.2
Q,KV | 1 | 16384 | 128 | 16 | 64.8 | 25.8 | 25.8
Q,KV | 1 | 4096 | 8 | 40 | 67.5 | 37.7 | 37.5
Q,KV | 1 | 4096 | 8 | 80 | 101.3 | 56.7 | 56.6
Q,KV | 1 | 4096 | 8 | 160 | 124.0 | 58.6 | 58.6
Q,KV | 4 | 4096 | 8 | 40 | 90.8 | 49.8 | 49.8
Q,KV | 4 | 4096 | 8 | 80 | 135.6 | 63.8 | 63.8
Q,KV | 4 | 4096 | 8 | 160 | 166.3 | 64.5 | 64.5
Q,KV | 1 | 16384 | 8 | 40 | 97.5 | 52.3 | 52.3
Q,KV | 1 | 16384 | 8 | 80 | 143.5 | 65.9 | 65.8
Q,KV | 1 | 16384 | 8 | 160 | 178.4 | 65.9 | 65.8
Q,KV | 128 | 128 | 12 | 64 | 26.8 | 48.1 | 30.9
Q,KV | 64 | 128 | 12 | 64 | 28.0 | 38.9 | 25.0
Q,KV | 128 | 384 | 12 | 64 | 97.7 | 61.1 | 61.0
Q,KV | 64 | 384 | 12 | 64 | 89.5 | 57.8 | 57.9
Q,KV | 128 | 512 | 12 | 64 | 111.9 | 66.7 | 66.9
Q,KV | 64 | 512 | 12 | 64 | 107.2 | 64.9 | 64.8
QKV | 32 | 512 | 64 | 32 | 77.2 | 84.7 | 39.3
QKV | 32 | 512 | 128 | 16 | 43.4 | 53.1 | 20.9
QKV | 16 | 1024 | 64 | 32 | 98.8 | 87.4 | 44.6
QKV | 16 | 1024 | 128 | 16 | 52.0 | 54.1 | 23.2
QKV | 8 | 2048 | 64 | 32 | 113.1 | 89.0 | 47.9
QKV | 8 | 2048 | 128 | 16 | 58.2 | 54.6 | 24.5
QKV | 4 | 4096 | 64 | 32 | 120.6 | 89.7 | 49.7
QKV | 4 | 4096 | 128 | 16 | 61.7 | 54.6 | 25.2
QKV | 2 | 8192 | 64 | 32 | 125.9 | 89.5 | 50.7
QKV | 2 | 8192 | 128 | 16 | 63.6 | 54.8 | 25.5
QKV | 1 | 16384 | 64 | 32 | 128.5 | 92.0 | 51.2
QKV | 1 | 16384 | 128 | 16 | 64.6 | 54.8 | 25.7
QKV | 1 | 4096 | 8 | 40 | 60.2 | 69.8 | 38.1
QKV | 1 | 4096 | 8 | 80 | 101.6 | 75.2 | 56.7
QKV | 1 | 4096 | 8 | 160 | 130.2 | 41.2 | 58.4
QKV | 4 | 4096 | 8 | 40 | 90.6 | 91.0 | 49.5
QKV | 4 | 4096 | 8 | 80 | 133.6 | 98.1 | 62.8
QKV | 4 | 4096 | 8 | 160 | 165.3 | 43.7 | 63.9
QKV | 1 | 16384 | 8 | 40 | 97.2 | 92.8 | 52.1
QKV | 1 | 16384 | 8 | 80 | 143.0 | 103.1 | 65.6
QKV | 1 | 16384 | 8 | 160 | 177.6 | 44.5 | 65.7
QKV | 128 | 128 | 12 | 64 | 31.1 | 65.9 | 27.6
QKV | 64 | 128 | 12 | 64 | 26.1 | 49.8 | 23.5
QKV | 128 | 384 | 12 | 64 | 84.6 | 88.5 | 56.1
QKV | 64 | 384 | 12 | 64 | 79.1 | 80.3 | 53.5
QKV | 128 | 512 | 12 | 64 | 97.3 | 114.2 | 62.2
QKV | 64 | 512 | 12 | 64 | 95.9 | 110.7 | 60.6
QKV | 4 | 2048 | 32 | 128 | 125.26 | 44.72 | 78.15
QKV | 4 | 4096 | 32 | 128 | 141.62 | 46.29 | 85.84
QKV | 8 | 2048 | 32 | 128 | 127.40 | 45.49 | 78.75
QKV | 8 | 4096 | 32 | 128 | 144.24 | 46.60 | 86.95

Known Issues

NVCC uses a huge amount of memory while compiling the flash attention CUDA kernels. A Linux build with CUDA might fail when the machine has limited memory and a large number of CPUs. The workaround is to use a build machine with more memory (64 GB or above recommended), or to pass an argument like --nvcc_threads 1 to limit the number of nvcc threads during the build.

Motivation and Context

Increases the speed and efficiency of MHA and Packed MHA.

cmake/CMakeSettings.json: review thread (outdated, resolved)
@tianleiwu tianleiwu requested a review from snnn August 31, 2023 04:02
tianleiwu previously approved these changes Aug 31, 2023
@tianleiwu (Contributor) commented:

@faxu @pranavsharma, this PR changes ThirdPartyNotices.txt, so it requires admin approval. Please take a look.

@yufenglee (Member) left a comment:


:shipit:

@aciddelgado aciddelgado merged commit 44101e8 into main Aug 31, 2023
93 of 95 checks passed
@aciddelgado aciddelgado deleted the flash_v2_packed_mha branch August 31, 2023 20:52
@natke natke added the triage:approved Approved for cherrypicks for release label Sep 1, 2023
centwang pushed a commit that referenced this pull request Sep 1, 2023
snnn pushed a commit that referenced this pull request Sep 7, 2023
Cherry-pick 2nd round for 1.16.0 release.
PR List:

#17201
#17270
#17311
#17315
#17320
#17326
#17355
#17227
#17380
#17386
centwang pushed a commit that referenced this pull request Sep 8, 2023
…17461)

### Description
Remove 52 from CMAKE_CUDA_ARCHITECTURES to reduce the NuGet package size.

### Motivation and Context
PR #17227 increased the binary size by 20%. Right now the package size is about 260 MB, but NuGet has a hard limit of 250 MB. Without this change we cannot publish the package.
kleiti pushed a commit to kleiti/onnxruntime that referenced this pull request Mar 22, 2024