flash attention

4-5x faster than PyTorch, 2-3x faster training... FlashAttention and its variants seem like magic. But how does it work?

Memory Types

Before diving into FlashAttention, we should first get acquainted with the memory hierarchy of a modern GPU. We'll look at two types of memory in particular: high bandwidth memory (HBM, global memory) and shared memory (SRAM). For context, an A100 provides 40/80GB of HBM with a bandwidth of ~1.5TB/s and 192KB of SRAM per streaming multiprocessor (SM) with a bandwidth of ~19TB/s. We care about these two memories because, as programmers, we can control where we put our data, and the capacity-speed tradeoff between HBM and SRAM opens up possibilities for serious performance gains.

Naive Attention

Regardless of implementation, we will start with $Q, K, V \in \mathbb{R}^{N \times d}$ in HBM, where $N$ is the sequence length and $d$ is the head dimension. This is because $N$ and $d$ tend to fall in the 1000-8000 and 64-128 range respectively, so even one $N \times d$ matrix in fp16 with $N = 8000$ and $d = 128$ would be 2048KB in memory, which cannot possibly fit in SRAM.
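As a quick sanity check on those numbers (a back-of-the-envelope sketch; 2 bytes per element, and the 192KB SRAM figure is per SM):

```python
# Back-of-the-envelope: one N x d activation matrix vs. on-chip SRAM.
# Assumptions: fp16/bf16 (2 bytes per element), A100-class SRAM of 192 KB per SM.
N, d = 8000, 128
bytes_per_elem = 2

matrix_kb = N * d * bytes_per_elem / 1000           # ~2048 KB
sram_kb = 192
print(f"one {N}x{d} matrix: {matrix_kb:.0f} KB")     # 2048 KB
print(f"SRAM per SM:        {sram_kb} KB")           # nowhere near enough
```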

Naive Attention Step 1

In a naive implementation of attention, we would call a general matrix multiplication (GEMM) on $Q$ and $K^T$ and materialize the result $S$ in HBM (we ignore the usual $1/\sqrt{d}$ scaling to keep the notation light).

Naive Attention Step 2

We then load $S$ from HBM, apply a row-wise softmax, and materialize the result $P$ to HBM.

Naive Attention Step 3

Finally, we call a GEMM on $P$ and $V$ and write the result $O$ to HBM.

Naive Attention Step 4
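To make the three steps concrete, here is a minimal NumPy sketch of the naive path (illustrative only; on a GPU these would be three separate kernels, and we again omit the $1/\sqrt{d}$ scaling):

```python
import numpy as np

def naive_attention(Q, K, V):
    """Naive attention: every intermediate (S, P) is fully materialized.
    On a GPU, S and P would round-trip through HBM between kernels.
    """
    S = Q @ K.T                                     # step 1: GEMM, N x N scores
    P = np.exp(S - S.max(axis=-1, keepdims=True))   # step 2: row-wise softmax...
    P = P / P.sum(axis=-1, keepdims=True)           #         ...materialized as P
    O = P @ V                                       # step 3: GEMM, N x d output
    return O

# Toy sizes so it runs instantly.
rng = np.random.default_rng(0)
N, d = 256, 64
Q, K, V = (rng.standard_normal((N, d), dtype=np.float32) for _ in range(3))
print(naive_attention(Q, K, V).shape)  # (256, 64)
```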

In total, naive attention performs $4N^2 + 4Nd$ element reads/writes from HBM and $4N^2d$ FLOPs (for the two GEMMs). The arithmetic intensity is therefore $\frac{4N^2d}{2(4N^2+4Nd)} = \frac{Nd}{2(N+d)}$ FLOPs/byte, assuming bf16/fp16 (2 bytes per element). For $N \gg d$, this simplifies to $d/2$, and with $d = 128$, the intensity of naive attention is 64 FLOPs/byte. Given that an A100 (40GB, SXM4) peaks at 312 TFLOP/s with tensor cores and has an HBM bandwidth of around 1.6TB/s, the roofline ratio (peak compute over peak bandwidth) is $312/1.6 = 195$ FLOPs/byte. Our well-below-roofline intensity signals that the naive attention kernel is (heavily) memory-bound.
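The same arithmetic, spelled out (numbers as assumed above; a back-of-the-envelope calculation, not a measurement):

```python
# Arithmetic intensity of naive attention vs. the A100 roofline ratio.
# Assumptions: N = 8000, d = 128, 2 bytes/element (bf16/fp16),
# 312 TFLOP/s tensor-core peak, ~1.6 TB/s HBM bandwidth (A100 40GB SXM4).
N, d = 8000, 128
flops = 4 * N**2 * d                      # two GEMMs: QK^T and PV
bytes_moved = 2 * (4 * N**2 + 4 * N * d)  # S, P round trips + Q, K, V, O

intensity = flops / bytes_moved
roofline = 312e12 / 1.6e12                # FLOPs/byte needed to be compute-bound
print(f"naive intensity: {intensity:.1f} FLOPs/byte")  # ~63, i.e. ~d/2
print(f"roofline ratio:  {roofline:.0f} FLOPs/byte")   # ~195 -> memory-bound
```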

Notice that the naive way writes $N \times N$ matrices into HBM and retrieves them right after, which seems like a waste (it is!). This is the core motivation behind FlashAttention: instead of materializing $S$ and $P$ in HBM, we want to keep them in on-chip SRAM and only write back the result $O$.

There are some obstacles, though: the obvious issue is that the full matrices do not fit in SRAM. We can introduce tiling to alleviate this, partitioning $Q, K, V$ into smaller blocks and calculating $QK^T$ one block at a time (more detail on this later in the article).

The second issue is that the denominator in the softmax needs to be summed across the entire row.

$$
\text{Softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}
$$

This wouldn't be an issue normally, but with blocks, we don't have access to the full row. Luckily, there's this neat online softmax algorithm that performs a running softmax and rescales previous partial sums. Let's see how this works.

Online Softmax

For simplicity, we will show how online softmax works on $S$ and $V$, both split into two blocks: $S$ into $B_r \times B_c$ blocks column-wise and $V$ into $B_c \times d$ blocks row-wise, as shown below:

Online Softmax Step 1

On the first block:

$$
\begin{aligned}
m^{(1)} &= \operatorname{rowmax}\big(S^{(1)}\big) \in \mathbb{R}^{B_r},\\
\ell^{(1)} &= \operatorname{rowsum}\big(e^{\,S^{(1)} - m^{(1)}}\big) \in \mathbb{R}^{B_r},\\
\tilde{P}^{(1)} &= \operatorname{diag}\big(\ell^{(1)}\big)^{-1}\, e^{\,S^{(1)} - m^{(1)}} \in \mathbb{R}^{B_r \times B_c},\\
O^{(1)} &= \tilde{P}^{(1)} V^{(1)} = \operatorname{diag}\big(\ell^{(1)}\big)^{-1}\, e^{\,S^{(1)} - m^{(1)}}\, V^{(1)} \in \mathbb{R}^{B_r \times d}.
\end{aligned}
$$

or visually,

Online Softmax Step 2

We're computing $e^{S_{i,j}-m^{(1)}}$ for the softmax numerator and adding up $e^{S_{i,j}-m^{(1)}}$ for the denominator, which might seem redundant because the $m^{(1)}$s cancel. Subtracting the max (1) keeps us numerically stable by bounding every exponent at 0, and (2) lets us rescale the normalizer $\ell$ and the partially accumulated output $O$ whenever we find a higher $m$.
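A tiny example of point (1), with values chosen to overflow in fp32:

```python
import numpy as np

x = np.array([1000.0, 1001.0, 1002.0], dtype=np.float32)

# Direct softmax overflows: e^1000 is inf in fp32, so we get nan.
naive = np.exp(x) / np.exp(x).sum()

# Subtracting the row max bounds every exponent at 0, so nothing overflows.
stable = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
print(naive)   # [nan nan nan]
print(stable)  # [0.09003057 0.24472848 0.66524094]
```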

On the second block:

$$
\begin{aligned}
m^{(2)} &= \max\big(m^{(1)},\, \operatorname{rowmax}(S^{(2)})\big) = m,\\
\ell^{(2)} &= e^{\,m^{(1)}-m^{(2)}}\,\ell^{(1)} + \operatorname{rowsum}\big(e^{\,S^{(2)}-m^{(2)}}\big) = \operatorname{rowsum}\big(e^{\,S^{(1)}-m}\big) + \operatorname{rowsum}\big(e^{\,S^{(2)}-m}\big) = \ell,\\
\tilde{P}^{(2)} &= \operatorname{diag}\big(\ell^{(2)}\big)^{-1}\, e^{\,S^{(2)}-m^{(2)}},\\
O^{(2)} &= \operatorname{diag}\big(e^{\,m^{(1)}-m^{(2)}}\,\ell^{(1)}/\ell^{(2)}\big)\, O^{(1)} + \tilde{P}^{(2)} V^{(2)} = \operatorname{diag}\big(\ell^{(2)}\big)^{-1} e^{\,S^{(1)}-m}\, V^{(1)} + \operatorname{diag}\big(\ell^{(2)}\big)^{-1} e^{\,S^{(2)}-m}\, V^{(2)} = O.
\end{aligned}
$$

or visually,

Online Softmax Step 3

Notice the iterative nature and the rescaling. From this simplified two-block example, we can generalize to a row split into any number of blocks.

Online Softmax Step 4
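Here is what that generalization looks like for a single row-block of queries, sketched in NumPy. The function name and block sizes are illustrative; the updates mirror the equations above, just inside a loop:

```python
import numpy as np

def online_softmax_matmul(S_blocks, V_blocks):
    """Compute softmax(S) @ V one column-block at a time.

    S_blocks: list of (Br, Bc) score blocks for one row-block of queries.
    V_blocks: list of (Bc, d) value blocks.
    Only a running max m, normalizer l, and output O are kept around.
    """
    Br = S_blocks[0].shape[0]
    d = V_blocks[0].shape[1]
    m = np.full(Br, -np.inf)            # running row max
    l = np.zeros(Br)                    # running softmax denominator
    O = np.zeros((Br, d))               # running (already normalized) output

    for S_j, V_j in zip(S_blocks, V_blocks):
        m_new = np.maximum(m, S_j.max(axis=1))
        P_j = np.exp(S_j - m_new[:, None])                  # e^{S^{(j)} - m^{(j)}}
        l_new = np.exp(m - m_new) * l + P_j.sum(axis=1)
        # Rescale what we've accumulated so far, then add the new block.
        O = ((l * np.exp(m - m_new) / l_new)[:, None] * O
             + (P_j / l_new[:, None]) @ V_j)
        m, l = m_new, l_new
    return O

# Check against the direct computation on a toy example.
rng = np.random.default_rng(0)
Br, Bc, d, n_blocks = 4, 8, 16, 5
S = rng.standard_normal((Br, Bc * n_blocks))
V = rng.standard_normal((Bc * n_blocks, d))
P = np.exp(S - S.max(1, keepdims=True))
ref = (P / P.sum(1, keepdims=True)) @ V
out = online_softmax_matmul(np.hsplit(S, n_blocks), np.vsplit(V, n_blocks))
print(np.allclose(out, ref))  # True
```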

Flash Attention

Now, let's see FlashAttention in action, with tiling and online softmax. We first tile $Q$ into $B_r \times d$ blocks and $K, V$ into $B_c \times d$ blocks. Our output $O$ will likewise consist of $B_r \times d$ blocks (matching $Q$'s block dimensions).

Flash Attention Step 1

Say we go through $Q$'s blocks with iterator $i$ and $K, V$'s blocks with iterator $j$. For each pair $(i, j)$, we read the $Q$ and $K$ blocks from HBM and multiply them in SRAM to obtain the $S$ block. We then run the online softmax (still in SRAM) to obtain the $P$ block.

Flash Attention Step 2

We now read the $V$ block from HBM and multiply $P, V$ together into an $O$ block, which is rescaled and accumulated into the final $O$ block (à la online softmax) and written to HBM.

Flash Attention Step 3
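Putting tiling and online softmax together, here is a sketch of the whole forward pass in NumPy. The block sizes $B_r, B_c$ and function name are illustrative and untuned; in a real kernel the inner loop body lives entirely in SRAM/registers and everything is fused into one launch, but the loop structure is the point:

```python
import numpy as np

def flash_attention_forward(Q, K, V, Br=64, Bc=64):
    """Tiled attention forward pass (illustrative NumPy sketch).

    Outer loop: blocks of Q (one output block's worth of work).
    Inner loop: blocks of K, V, combined via online softmax so the
    full N x N matrices S and P are never materialized.
    """
    N, d = Q.shape
    O = np.zeros((N, d))

    for i in range(0, N, Br):                       # loop over Q / O row blocks
        Qi = Q[i:i + Br]
        m = np.full(Qi.shape[0], -np.inf)           # running row max
        l = np.zeros(Qi.shape[0])                   # running denominator
        Oi = np.zeros((Qi.shape[0], d))             # running output block

        for j in range(0, N, Bc):                   # loop over K / V column blocks
            Sij = Qi @ K[j:j + Bc].T                # Br x Bc block of scores
            m_new = np.maximum(m, Sij.max(axis=1))
            Pij = np.exp(Sij - m_new[:, None])
            l_new = np.exp(m - m_new) * l + Pij.sum(axis=1)
            # Rescale the accumulated block, then add this block's contribution.
            Oi = ((l * np.exp(m - m_new) / l_new)[:, None] * Oi
                  + (Pij / l_new[:, None]) @ V[j:j + Bc])
            m, l = m_new, l_new

        O[i:i + Br] = Oi                            # one write of the output block to "HBM"
    return O

# Matches the naive result (toy sizes, no 1/sqrt(d) scaling, as in the text).
rng = np.random.default_rng(0)
N, d = 256, 64
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
S = Q @ K.T
P = np.exp(S - S.max(1, keepdims=True))
ref = (P / P.sum(1, keepdims=True)) @ V
print(np.allclose(flash_attention_forward(Q, K, V), ref))  # True
```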

FLOP-wise, FlashAttention is the same as naive attention: $4N^2d$ (two GEMMs). The dominant HBM traffic, in the idealized case, is just reading $Q, K, V$ and writing $O$: $4Nd$ elements, i.e. $8Nd$ bytes for bf16/fp16. The resulting arithmetic intensity is $\frac{4N^2d}{8Nd} = \frac{N}{2}$ FLOPs/byte. For our case, $N \approx 8000$, so our intensity is $4000$ FLOPs/byte and FlashAttention is comfortably compute-bound.

FlashAttention is operator fusion done right. Instead of three kernels writing massive intermediate matrices to slow HBM, FA fuses everything into one kernel that keeps $S$ and $P$ in fast SRAM. The core insight is that modern GPUs are so fast at compute that you're basically always waiting on memory. We can boost attention from 64 to ~4000 FLOPs/byte; we just had to pay attention (...) to the memory hierarchy.

References

Boehm, Simon. “How to Optimize a CUDA Matmul Kernel for cuBLAS-like Performance: a Worklog.” December 31, 2022. https://siboehm.com/articles/22/CUDA-MMM.

Dao, Tri; Fu, Daniel Y.; Ermon, Stefano; Rudra, Atri; Ré, Christopher. “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.” arXiv:2205.14135, 2022. https://doi.org/10.48550/arXiv.2205.14135.

Dao, Tri. “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning.” arXiv:2307.08691, 2023. https://doi.org/10.48550/arXiv.2307.08691.

NVIDIA Corporation. CUDA C++ Programming Guide, Release 13.0. August 1, 2025. https://docs.nvidia.com/cuda/cuda-c-programming-guide/.