Implement FlashAttention for CPU #20805

Merged (36 commits) on Jul 11, 2024

duanqn (Contributor) commented May 24, 2024

Description

Implement FlashAttention and FlashAttention-2 for MultiHeadAttention on CPU.
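
For readers unfamiliar with the technique, the core idea behind FlashAttention can be sketched in a few lines. The following is a minimal illustration, not the code in this PR: it assumes a single head, plain Python lists, and a hypothetical `kv_block` parameter, and it streams over K/V blocks with a running softmax so the full score matrix is never materialized. The real kernels also tile over query rows and parallelize across batches and heads, which this sketch omits.

```python
import math


def naive_attention(q, k, v):
    """Reference: softmax(q k^T / sqrt(d)) v, building each full score row."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / total
                    for c in range(len(v[0]))])
    return out


def tiled_attention(q, k, v, kv_block=2):
    """Same result, but K/V are visited block by block with a running softmax.

    `kv_block` is a hypothetical tile size chosen for illustration only.
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    dv = len(v[0])
    out = []
    for qi in q:
        running_max = -math.inf   # largest score seen so far
        running_sum = 0.0         # softmax denominator relative to running_max
        acc = [0.0] * dv          # unnormalized output accumulator
        for start in range(0, len(k), kv_block):
            k_blk = k[start:start + kv_block]
            v_blk = v[start:start + kv_block]
            scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k_blk]
            new_max = max(running_max, max(scores))
            # Rescale earlier partial results to the new maximum.
            # On the first block this is exp(-inf) == 0.0.
            correction = math.exp(running_max - new_max)
            running_sum *= correction
            acc = [a * correction for a in acc]
            for s, vj in zip(scores, v_blk):
                w = math.exp(s - new_max)
                running_sum += w
                for c in range(dv):
                    acc[c] += w * vj[c]
            running_max = new_max
        out.append([a / running_sum for a in acc])
    return out
```

Both functions should agree to within floating-point rounding for any finite inputs; the tiled version only ever holds `kv_block` scores per query row at a time.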

Motivation and Context

Accelerate the execution of MultiHeadAttention.

Current performance: 10ms vs 16ms (com.microsoft.MultiHeadAttention) on my Linux machine and 10ms vs 38ms (com.microsoft.MultiHeadAttention) on my Windows machine. May need further optimizations.

duanqn (Contributor, Author) commented Jun 19, 2024

Test failing: MultiHeadAttentionTest.CrossAttention_DiffSequenceLengths

Edit: passed

github-advanced-security bot left a comment

PREfast found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.

@tianleiwu tianleiwu mentioned this pull request Jun 20, 2024
3 tasks
@duanqn duanqn force-pushed the qiduan/flash branch 2 times, most recently from f858430 to 599ac3f Compare June 20, 2024 07:34
@tianleiwu tianleiwu marked this pull request as ready for review June 20, 2024 17:01
@tianleiwu tianleiwu requested a review from a team as a code owner June 20, 2024 17:01
duanqn (Contributor, Author) commented Jun 21, 2024

Environment Variables:
ORT_DISABLE_FLASH_ATTENTION=0

| format | causal | batch | seqlen | heads | h_dim | ms | TFLOPS | kernel |
|---|---|---|---|---|---|---|---|---|
| Q,K,V | False | 1 | 128 | 32 | 128 | 1.59 | 0.17 | CPU:Flash |
| Q,K,V | False | 1 | 256 | 32 | 128 | 2.74 | 0.39 | CPU:Flash |
| Q,K,V | False | 1 | 512 | 32 | 128 | 8.28 | 0.52 | CPU:Flash |
| Q,K,V | False | 1 | 1024 | 32 | 128 | 26.43 | 0.65 | CPU:Flash |
| Q,K,V | False | 1 | 2048 | 32 | 128 | 88.92 | 0.77 | CPU:Flash |
| Q,K,V | False | 1 | 4096 | 8 | 40 | 36.26 | 0.59 | CPU:Flash |
| Q,K,V | False | 1 | 4096 | 8 | 80 | 54.36 | 0.79 | CPU:Flash |
| Q,K,V | False | 1 | 4096 | 8 | 160 | 99.28 | 0.87 | CPU:Flash |
| Q,K,V | False | 4 | 4096 | 8 | 40 | 144.85 | 0.59 | CPU:Flash |
| Q,K,V | False | 4 | 4096 | 8 | 80 | 217.08 | 0.79 | CPU:Flash |
| Q,K,V | False | 4 | 4096 | 8 | 160 | 400.06 | 0.86 | CPU:Flash |
| Q,K,V | False | 1 | 16384 | 8 | 40 | 570.16 | 0.60 | CPU:Flash |
| Q,K,V | False | 1 | 16384 | 8 | 80 | 854.11 | 0.80 | CPU:Flash |
| Q,K,V | False | 1 | 16384 | 8 | 160 | 1511.06 | 0.91 | CPU:Flash |
| Q,K,V | False | 128 | 128 | 12 | 64 | 29.84 | 0.22 | CPU:Flash |
| Q,K,V | False | 64 | 128 | 12 | 64 | 14.82 | 0.22 | CPU:Flash |
| Q,K,V | False | 128 | 384 | 12 | 64 | 131.07 | 0.44 | CPU:Flash |
| Q,K,V | False | 64 | 384 | 12 | 64 | 65.70 | 0.44 | CPU:Flash |
| Q,K,V | False | 128 | 512 | 12 | 64 | 203.86 | 0.51 | CPU:Flash |
| Q,K,V | False | 64 | 512 | 12 | 64 | 99.83 | 0.52 | CPU:Flash |
| Q,K,V | False | 4 | 2048 | 32 | 128 | 350.01 | 0.79 | CPU:Flash |
| Q,K,V | False | 4 | 4096 | 32 | 128 | 1278.42 | 0.86 | CPU:Flash |
| Q,K,V | False | 8 | 2048 | 32 | 128 | 698.98 | 0.79 | CPU:Flash |
| Q,K,V | False | 8 | 4096 | 32 | 128 | 2547.00 | 0.86 | CPU:Flash |
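
A note on reading the TFLOPS column: the reported values appear consistent with counting 4 * batch * seqlen^2 * heads * h_dim floating-point operations per run (two matrix multiplications at 2 FLOPs per multiply-add) and dividing by the measured time. That formula is inferred from the numbers, not stated in this thread, so the benchmark script may compute it differently. A small sketch for checking a row, with the hypothetical helper name `attention_tflops`:

```python
def attention_tflops(batch, seqlen, heads, h_dim, ms):
    """Estimate throughput for non-causal self-attention.

    Assumes 4 * B * S^2 * N * H FLOPs (QK^T and PV, 2 FLOPs per multiply-add).
    This FLOP count is an assumption inferred from the table, not from the PR.
    """
    flops = 4 * batch * seqlen * seqlen * heads * h_dim
    seconds = ms * 1e-3
    return flops / seconds / 1e12


# Rows copied from the ORT_DISABLE_FLASH_ATTENTION=0 results.
for shape_and_ms in [(1, 128, 32, 128, 1.59),
                     (1, 16384, 8, 160, 1511.06),
                     (128, 128, 12, 64, 29.84)]:
    print(shape_and_ms, round(attention_tflops(*shape_and_ms), 2))
```

Under this assumption the three rows come out at 0.17, 0.91, and 0.22 TFLOPS, matching the table.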

Environment Variables:
ORT_DISABLE_FLASH_ATTENTION=1

| format | causal | batch | seqlen | heads | h_dim | ms | TFLOPS | kernel |
|---|---|---|---|---|---|---|---|---|
| Q,K,V | False | 1 | 128 | 32 | 128 | 1.43 | 0.19 | CPU:Unfused |
| Q,K,V | False | 1 | 256 | 32 | 128 | 3.24 | 0.33 | CPU:Unfused |
| Q,K,V | False | 1 | 512 | 32 | 128 | 11.26 | 0.38 | CPU:Unfused |
| Q,K,V | False | 1 | 1024 | 32 | 128 | 36.88 | 0.47 | CPU:Unfused |
| Q,K,V | False | 1 | 2048 | 32 | 128 | 106.25 | 0.65 | CPU:Unfused |
| Q,K,V | False | 1 | 4096 | 8 | 40 | 49.43 | 0.43 | CPU:Unfused |
| Q,K,V | False | 1 | 4096 | 8 | 80 | 75.99 | 0.57 | CPU:Unfused |
| Q,K,V | False | 1 | 4096 | 8 | 160 | 137.47 | 0.62 | CPU:Unfused |
| Q,K,V | False | 4 | 4096 | 8 | 40 | 194.25 | 0.44 | CPU:Unfused |
| Q,K,V | False | 4 | 4096 | 8 | 80 | 298.62 | 0.58 | CPU:Unfused |
| Q,K,V | False | 4 | 4096 | 8 | 160 | 540.00 | 0.64 | CPU:Unfused |
| Q,K,V | False | 1 | 16384 | 8 | 40 | 962.66 | 0.36 | CPU:Unfused |
| Q,K,V | False | 1 | 16384 | 8 | 80 | 1389.89 | 0.49 | CPU:Unfused |
| Q,K,V | False | 1 | 16384 | 8 | 160 | 2605.56 | 0.53 | CPU:Unfused |
| Q,K,V | False | 128 | 128 | 12 | 64 | 33.08 | 0.19 | CPU:Unfused |
| Q,K,V | False | 64 | 128 | 12 | 64 | 16.26 | 0.20 | CPU:Unfused |
| Q,K,V | False | 128 | 384 | 12 | 64 | 149.92 | 0.39 | CPU:Unfused |
| Q,K,V | False | 64 | 384 | 12 | 64 | 75.20 | 0.39 | CPU:Unfused |
| Q,K,V | False | 128 | 512 | 12 | 64 | 234.68 | 0.44 | CPU:Unfused |
| Q,K,V | False | 64 | 512 | 12 | 64 | 117.20 | 0.44 | CPU:Unfused |
| Q,K,V | False | 4 | 2048 | 32 | 128 | 409.42 | 0.67 | CPU:Unfused |
| Q,K,V | False | 4 | 4096 | 32 | 128 | 1561.20 | 0.70 | CPU:Unfused |
| Q,K,V | False | 8 | 2048 | 32 | 128 | 814.60 | 0.67 | CPU:Unfused |
| Q,K,V | False | 8 | 4096 | 32 | 128 | 3112.91 | 0.71 | CPU:Unfused |
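
To compare the two runs directly, the per-shape speedup is simply the unfused time divided by the flash time. The sketch below uses a handful of rows copied from the two result sets; the dictionary layout and variable names are illustrative choices, not part of the benchmark tooling.

```python
# Keys are (batch, seqlen, heads, h_dim); values are milliseconds from the results above.
flash_ms = {
    (1, 128, 32, 128): 1.59,
    (1, 2048, 32, 128): 88.92,
    (1, 16384, 8, 160): 1511.06,
    (8, 4096, 32, 128): 2547.00,
}
unfused_ms = {
    (1, 128, 32, 128): 1.43,
    (1, 2048, 32, 128): 106.25,
    (1, 16384, 8, 160): 2605.56,
    (8, 4096, 32, 128): 3112.91,
}

# Speedup > 1 means the flash kernel was faster for that shape.
speedup = {shape: unfused_ms[shape] / flash_ms[shape] for shape in flash_ms}

for shape, ratio in sorted(speedup.items(), key=lambda item: item[1]):
    print(shape, f"{ratio:.2f}x")
```

On these rows the flash kernel is slightly slower at the shortest sequence length (about 0.90x at seqlen 128) and faster at longer ones, reaching roughly 1.72x at seqlen 16384 with h_dim 160.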

tianleiwu (Contributor) commented

/azp run Windows ARM64 QNN CI Pipeline,Windows x64 QNN CI Pipeline,Windows CPU CI Pipeline,Windows GPU CI Pipeline,Windows GPU TensorRT CI Pipeline,ONNX Runtime Web CI Pipeline,Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline

tianleiwu (Contributor) commented

/azp run Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline,orttraining-amd-gpu-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,onnxruntime-binary-size-checks-ci-pipeline,Big Models,Linux Android Emulator QNN CI Pipeline

tianleiwu (Contributor) commented

/azp run Android CI Pipeline,iOS CI Pipeline,ONNX Runtime React Native CI Pipeline

Azure Pipelines successfully started running 3 pipeline(s).

Azure Pipelines successfully started running 10 pipeline(s).

Azure Pipelines successfully started running 10 pipeline(s).

tianleiwu (Contributor) commented

/azp run Windows ARM64 QNN CI Pipeline,Windows x64 QNN CI Pipeline,Windows CPU CI Pipeline,Windows GPU CI Pipeline,Windows GPU TensorRT CI Pipeline,ONNX Runtime Web CI Pipeline,Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline

tianleiwu (Contributor) commented

/azp run Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline,orttraining-amd-gpu-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,onnxruntime-binary-size-checks-ci-pipeline,Big Models,Linux Android Emulator QNN CI Pipeline

tianleiwu (Contributor) commented

/azp run Android CI Pipeline,iOS CI Pipeline,ONNX Runtime React Native CI Pipeline

Azure Pipelines successfully started running 3 pipeline(s).

Azure Pipelines successfully started running 10 pipeline(s).

Azure Pipelines successfully started running 10 pipeline(s).

@tianleiwu tianleiwu requested a review from yufenglee July 11, 2024 19:00
yufenglee (Member) commented

@duanqn, thank you very much for your contribution, Qingnan!

@yufenglee yufenglee merged commit 80b56fe into microsoft:main Jul 11, 2024
86 of 88 checks passed