
FlashAttention


This is a Julia implementation of the Flash Attention algorithm.

Usage

using FlashAttention, CUDA

Q = CUDA.randn(Float16, 64, 1024, 48, 3);
K = CUDA.randn(Float16, 64, 1024, 48, 3);
V = CUDA.randn(Float16, 64, 1024, 48, 3);

flash_attention(Q,K,V)
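
For a quick sanity check, the output can be compared against a naive attention computed on the CPU. The sketch below is not part of the package; it assumes the head_dim × seqlen × heads × batch layout used above, the usual 1/√d softmax scaling, and that flash_attention returns its output in the same layout.

# Minimal naive reference for sanity-checking the kernel output on the CPU.
# Assumed layout: d × seqlen × heads × batch; assumed scaling: 1/√d.
function naive_attention(Q, K, V)
    d = size(Q, 1)
    O = similar(Q)
    for b in axes(Q, 4), h in axes(Q, 3)
        q = Float32.(Q[:, :, h, b])          # d × N queries
        k = Float32.(K[:, :, h, b])          # d × N keys
        v = Float32.(V[:, :, h, b])          # d × N values
        S = (k' * q) ./ sqrt(Float32(d))     # N × N scores, S[i, j] = kᵢ⋅qⱼ / √d
        P = exp.(S .- maximum(S; dims = 1))  # numerically stable softmax over keys
        P ./= sum(P; dims = 1)
        O[:, :, h, b] = Float16.(v * P)      # weighted sum of values
    end
    return O
end

O_ref = naive_attention(Array(Q), Array(K), Array(V))
isapprox(Array(flash_attention(Q, K, V)), O_ref; rtol = 1e-2)

A loose tolerance is used here because the inputs and accumulators are Float16; tighten it if the kernel accumulates in higher precision.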

Profiling

Please refer to the file flash_attention.ncu-rep. This is not the fastest possible implementation because

  1. it does not use tensor cores, unlike the C++ implementation,
  2. CUDA.jl does not yet support asynchronous copies from global memory to shared memory, and
  3. the kernel's theoretical occupancy (12.5%) is limited by the amount of shared memory it requires.

Future work

I plan to reimplement it using MoYe.jl to achieve competitive performance.
