How FlashAttention Works


4-5x faster than standard PyTorch attention, 2-3x faster training... FlashAttention and its variants seem like magic. But how does it work?

Memory Types

Before FlashAttention, we should first get acquainted with the memory hierarchy of a modern GPU. We'll look at two types of memory in particular: high bandwidth memory (HBM, global memory) and shared memory (SRAM). For context, an A100 provides 40 or 80GB of HBM with a bandwidth of ~1.5TB/s, and 192KB of SRAM per streaming multiprocessor (SM) with an estimated bandwidth of ~19TB/s. We care about these two memories because as programmers, we can control where we put our data, and the capacity-speed tradeoff between HBM and SRAM opens up possibilities for serious performance gains.

Naive Attention

Regardless of implementation, we will start with $Q, K, V \in \mathbb{R}^{N \times d}$ in HBM, where $N$ is the sequence length and $d$ is the head dimension. This is because $N$ and $d$ tend to fall in the 1000-8000 and 64-128 range respectively, so even one $N \times d$ matrix in fp16 with $N = 8000$ and $d = 128$ would be 2048KB in memory, which cannot possibly fit in SRAM.
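
As a quick sanity check on these sizes (a back-of-the-envelope Python calculation, nothing more):

```python
# fp16/bf16 = 2 bytes per element
N, d = 8000, 128

qkv_kb = N * d * 2 / 1e3      # one N x d matrix
scores_mb = N * N * 2 / 1e6   # one N x N score matrix

print(f"N x d matrix: {qkv_kb:.0f} KB")      # ~2048 KB, already too big for SRAM
print(f"N x N matrix: {scores_mb:.0f} MB")   # ~128 MB, hopeless
```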

Naive Attention Step 1

In a naive implementation of attention, we would call a generalized matrix multiplication (GEMM) on $Q$ and $K^T$ and materialize the result $S$ in HBM.

Naive Attention Step 2

We then load $S$ from HBM, apply a row-wise softmax, and materialize the result $P$ to HBM.

Naive Attention Step 3

Finally, we call a GEMM on $P$ and $V$ and write the result $O$ to HBM.
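
Putting the three steps together, here's a minimal NumPy sketch of naive attention (shapes and naming are illustrative; on a GPU, each step would be its own kernel round-tripping through HBM):

```python
import numpy as np

def naive_attention(Q, K, V):
    S = Q @ K.T                               # step 1: GEMM, materialize N x N scores
    S_max = S.max(axis=-1, keepdims=True)     # step 2: row-wise softmax (max-subtracted
    P = np.exp(S - S_max)                     #         for numerical stability)
    P /= P.sum(axis=-1, keepdims=True)
    return P @ V                              # step 3: GEMM, N x d output

# toy shapes; real workloads have N ~ 1000-8000, d ~ 64-128
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((256, 64)) for _ in range(3))
O = naive_attention(Q, K, V)
```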

Naive Attention Step 4

In total, naive attention consists of $4N^2 + 4Nd$ element reads/writes from HBM and $4N^2d$ FLOPs total (for the two GEMMs). The arithmetic intensity is therefore $\frac{4N^2d}{2(4N^2+4Nd)} = \frac{Nd}{2(N+d)}$, assuming bf16/fp16. For $N \gg d$, this simplifies to $d/2$, and assuming $d = 128$, the intensity of naive attention is 64 FLOPs/byte. Given that an A100 (40GB, SXM4) with tensor cores peaks at 312 TFLOP/s and HBM bandwidth is around 1.6TB/s, the roofline ratio is $312/1.6 = 195$ FLOPs/byte. Our well-below-roofline intensity signals that the naive attention kernel is (heavily) memory-bound.
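
If you want to reproduce the roofline arithmetic (the constants are just the A100 numbers quoted above):

```python
N, d = 8000, 128

flops = 4 * N**2 * d                       # two GEMMs: QK^T and PV
hbm_bytes = 2 * (4 * N**2 + 4 * N * d)     # 2 bytes/element in bf16/fp16

intensity = flops / hbm_bytes              # ~ d/2 = 64 FLOPs/byte for N >> d
roofline = 312e12 / 1.6e12                 # peak FLOP/s over HBM bandwidth

print(f"{intensity:.0f} FLOPs/byte vs. roofline {roofline:.0f}")  # 63 vs. 195: memory-bound
```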

Notice that the naive way writes $N \times N$ matrices into HBM and retrieves them right after, which seems like a waste (it is!). This is the core motivation behind FlashAttention: instead of materializing $S$ and $P$ in HBM, we want to keep them in on-chip SRAM and only write back the result $O$.

There are some obstacles, though: the obvious issue is that the full matrices do not fit in SRAM. We can introduce tiling to alleviate this, partitioning $Q, K, V$ into smaller blocks and calculating $QK^T$ one block at a time (more detail on this later in the article).
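
To make the tiling idea concrete before we get there, here's a minimal sketch of computing $QK^T$ one block at a time (the block sizes and names here are illustrative, not the ones FlashAttention picks):

```python
import numpy as np

def tiled_scores(Q, K, Br=64, Bc=64):
    """Compute S = Q @ K.T one Br x Bc block at a time. Each block only
    needs a Br x d tile of Q and a Bc x d tile of K, which is what lets
    the working set live in SRAM."""
    N, d = Q.shape
    S = np.zeros((N, N), dtype=Q.dtype)   # materialized here only so we can check the result
    for i in range(0, N, Br):
        Qi = Q[i:i + Br]                  # Br x d tile
        for j in range(0, N, Bc):
            Kj = K[j:j + Bc]              # Bc x d tile
            S[i:i + Br, j:j + Bc] = Qi @ Kj.T
    return S

rng = np.random.default_rng(0)
Q, K = rng.standard_normal((256, 64)), rng.standard_normal((256, 64))
assert np.allclose(tiled_scores(Q, K), Q @ K.T)
```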

The second issue is that the denominator in the softmax needs to be summed across the entire row.

\text{Softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}

This wouldn't be an issue normally, but with blocks, we don't have access to the full row. Luckily, there's this neat online softmax algorithm that performs a running softmax and rescales previous partial sums. Let's see how this works.

Online Softmax

For simplicity, we will show how online softmax works on S and V, both split into two blocks as shown below:

Online Softmax Step 1

On the first block:

\begin{aligned}
m^{(1)} &= \operatorname{rowmax}\big(S^{(1)}\big) \in \mathbb{R}^{B_r},\\
\ell^{(1)} &= \operatorname{rowsum}\big(e^{S^{(1)} - m^{(1)}}\big) \in \mathbb{R}^{B_r},\\
\tilde{P}^{(1)} &= \operatorname{diag}\big(\ell^{(1)}\big)^{-1} e^{S^{(1)} - m^{(1)}} \in \mathbb{R}^{B_r \times B_c},\\
O^{(1)} &= \tilde{P}^{(1)} V^{(1)} = \operatorname{diag}\big(\ell^{(1)}\big)^{-1} e^{S^{(1)} - m^{(1)}} V^{(1)} \in \mathbb{R}^{B_r \times d}.
\end{aligned}

or visually,

Online Softmax Step 2

We're computing $e^{S_{i,j}-m^{(1)}}$ for the softmax numerator and summing up $e^{S_{i,j}-m^{(1)}}$ for the denominator, which might seem redundant because the $m^{(1)}$s cancel. We do it anyway because it (1) gives numerical stability, bounding all exponential arguments to be at most 0, and (2) lets us rescale the normalizer $\ell$ and the partially accumulated output $O$ when we find a higher $m$.

On the second block:

\begin{aligned}
m^{(2)} &= \max\big(m^{(1)},\, \operatorname{rowmax}(S^{(2)})\big) = m,\\
\ell^{(2)} &= e^{m^{(1)}-m^{(2)}} \ell^{(1)} + \operatorname{rowsum}\big(e^{S^{(2)}-m^{(2)}}\big) = \operatorname{rowsum}\big(e^{S^{(1)}-m}\big) + \operatorname{rowsum}\big(e^{S^{(2)}-m}\big) = \ell,\\
\tilde{P}^{(2)} &= \operatorname{diag}\big(\ell^{(2)}\big)^{-1} e^{S^{(2)}-m^{(2)}},\\
O^{(2)} &= \operatorname{diag}\big(e^{m^{(1)}-m^{(2)}}\,\ell^{(1)}/\ell^{(2)}\big)\, O^{(1)} + \tilde{P}^{(2)} V^{(2)} = \operatorname{diag}\big(\ell^{(2)}\big)^{-1} e^{S^{(1)}-m} V^{(1)} + \operatorname{diag}\big(\ell^{(2)}\big)^{-1} e^{S^{(2)}-m} V^{(2)} = O.
\end{aligned}

or visually,

Online Softmax Step 3

Notice the iterative nature and the rescaling. From this simplified example, we can generalize to a scenario with N blocks in every row.

Online Softmax Step 4
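
Here's a hedged NumPy sketch of that generalization for a single row of scores processed block by block; `m`, `l`, and `o` play the roles of $m^{(j)}$, $\ell^{(j)}$, and $O^{(j)}$ above (the function name and block split are mine, purely for illustration):

```python
import numpy as np

def online_softmax_row(s_blocks, v_blocks):
    """Running softmax for one row of scores, one block at a time."""
    d = v_blocks[0].shape[1]
    m, l, o = -np.inf, 0.0, np.zeros(d)
    for s_j, v_j in zip(s_blocks, v_blocks):       # s_j: (Bc,), v_j: (Bc, d)
        m_new = max(m, s_j.max())                  # new running max
        p_j = np.exp(s_j - m_new)                  # e^{S^(j) - m^(j)}
        l_new = np.exp(m - m_new) * l + p_j.sum()  # rescale old normalizer, add new mass
        # rescale the previously accumulated output, then add this block's contribution
        o = (np.exp(m - m_new) * l / l_new) * o + (p_j / l_new) @ v_j
        m, l = m_new, l_new
    return o

# agrees with an ordinary softmax over the full row
rng = np.random.default_rng(0)
s, V = rng.standard_normal(128), rng.standard_normal((128, 16))
p = np.exp(s - s.max()); p /= p.sum()
assert np.allclose(online_softmax_row(np.split(s, 4), np.split(V, 4)), p @ V)
```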

Flash Attention

Now, let's see FlashAttention in action, with tiling and online softmax. We first tile $Q$ into $B_r \times d$ blocks and $K, V$ into $B_c \times d$ blocks. Our output, $O$, will also be split into $B_r \times d$ blocks (matching $Q$'s block dimensions).

Flash Attention Step 1

Say we go through $Q$'s blocks with iterator $i$ and $K, V$'s blocks with iterator $j$. For each pair $(i, j)$, we read the $Q, K$ blocks from HBM and multiply them in SRAM to obtain the $S$ block. We then run the online softmax (still in SRAM) to obtain the $P$ block.

Flash Attention Step 2

We now read the $V$ block from HBM and multiply $P, V$ together; the result is rescaled and accumulated into the final $O$ block (à la online softmax), which is then written to HBM.
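
Putting tiling and online softmax together, here's a hedged NumPy sketch of the whole forward pass. One natural loop nesting (outer over $Q$ blocks, inner over $K, V$ blocks, the FlashAttention-2 ordering) is shown; it defers the division by $\ell$ to the end of the inner loop, which is algebraically equivalent to renormalizing every step. Block sizes, names, and the absence of softmax scaling/masking are all simplifications.

```python
import numpy as np

def flash_attention_forward(Q, K, V, Br=64, Bc=64):
    """Tiled attention with online softmax: the full N x N matrices S and P
    are never materialized; only O (plus per-row stats) leaves the inner loop.
    Assumes Br and Bc divide N evenly."""
    N, d = Q.shape
    O = np.empty_like(Q)
    for i in range(0, N, Br):                     # one Q block at a time
        Qi = Q[i:i + Br]                          # Br x d, lives in "SRAM"
        m = np.full(Br, -np.inf)                  # running row max
        l = np.zeros(Br)                          # running normalizer
        Oi = np.zeros((Br, d))                    # unnormalized output accumulator
        for j in range(0, N, Bc):                 # stream K, V blocks through "SRAM"
            Kj, Vj = K[j:j + Bc], V[j:j + Bc]     # Bc x d each
            Sij = Qi @ Kj.T                       # Br x Bc block of scores
            m_new = np.maximum(m, Sij.max(axis=1))
            Pij = np.exp(Sij - m_new[:, None])    # unnormalized probabilities
            scale = np.exp(m - m_new)             # rescales everything accumulated so far
            l = scale * l + Pij.sum(axis=1)
            Oi = scale[:, None] * Oi + Pij @ Vj
            m = m_new
        O[i:i + Br] = Oi / l[:, None]             # normalize once, "write" block to HBM
    return O

# agrees with the naive implementation
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((256, 64)) for _ in range(3))
S = Q @ K.T
P = np.exp(S - S.max(axis=-1, keepdims=True)); P /= P.sum(axis=-1, keepdims=True)
assert np.allclose(flash_attention_forward(Q, K, V), P @ V)
```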

Flash Attention Step 3

FLOP-wise, FlashAttention is the same as naive attention: $4N^2d$ (two GEMMs). The dominant HBM traffic, in the idealized case, is just reading $Q, K, V$ and writing $O$: $4Nd$ elements, i.e. $8Nd$ bytes for bf16/fp16. The resulting arithmetic intensity is $\frac{4N^2d}{8Nd} = \frac{N}{2}$ FLOPs/byte. For our case, $N \approx 8000$, so our intensity is $4000$ FLOPs/byte and FlashAttention is comfortably compute-bound.

FlashAttention is operator fusion done right. Instead of three kernels writing massive intermediate matrices to slow HBM, FA fuses everything into one kernel that keeps $S$ and $P$ in fast SRAM. The core insight is that modern GPUs are so fast at compute that you're basically always waiting on memory. We can boost attention from 64 to 4000 FLOPs/byte; we just had to pay attention (...) to the memory hierarchy.

Derivatives for the Backward Pass

Let's look at the math behind backpropagating through FlashAttention. Let $Q, K, V \in \mathbb{R}^{N \times d}$, $S = QK^{\top}$, $P = \mathrm{softmax}(S)$ applied row-wise, and $O = PV$. Here $\delta(\cdot)$ denotes a differential (an infinitesimal perturbation), $dX$ denotes the gradient of the loss $L$ with respect to $X$, and $\langle\cdot,\,\cdot\rangle$ denotes the Frobenius inner product, so that

\delta L = \langle dX,\,\delta X \rangle = \operatorname{tr}(dX^{\top}\,\delta X)

An identity that we will use frequently:

\langle A,\,BC\rangle = \operatorname{tr}(A^{\top}BC) = \operatorname{tr}(CA^{\top}B) = \langle B^{\top}A,\,C\rangle = \langle AC^{\top},\,B\rangle.

From $O = PV$,

\begin{aligned}
\delta L &= \langle dO,\,\delta O\rangle \\
&= \langle dO,\,\delta P\,V + P\,\delta V\rangle \\
&= \operatorname{tr}(dO^{\top}\,\delta P\,V) + \operatorname{tr}(dO^{\top}\,P\,\delta V) \\
&= \operatorname{tr}\big((dO\,V^{\top})^{\top}\,\delta P\big) + \operatorname{tr}\big((P^{\top} dO)^{\top}\,\delta V\big) \\
&= \langle dO\,V^{\top},\,\delta P\rangle + \langle P^{\top} dO,\,\delta V\rangle.
\end{aligned}
\delta O = \delta P\,V + P\,\delta V \quad\Rightarrow\quad dV = P^{\top} dO,\;\; dP = dO\,V^{\top}.

From $P = \mathrm{softmax}(S)$ (row-wise)

For a row $s \in \mathbb{R}^N$ with $p = \mathrm{softmax}(s)$, the Jacobian is symmetric:

J(p) = \frac{\partial p}{\partial s} = \operatorname{diag}(p) - pp^{\top}

Applied row-wise, this gives

dS_i = J(p_i)\, dP_i \quad \text{for each row } i.

From $S = QK^{\top}$

\begin{aligned}
\delta L &= \langle dS,\,\delta S\rangle \\
&= \langle dS,\,\delta Q\,K^{\top} + Q\,\delta K^{\top}\rangle \\
&= \operatorname{tr}(dS^{\top}\,\delta Q\,K^{\top}) + \operatorname{tr}(dS^{\top}\,Q\,\delta K^{\top}) \\
&= \operatorname{tr}(K^{\top}\,dS^{\top}\,\delta Q) + \operatorname{tr}(\delta K^{\top}\,dS^{\top}\,Q) \\
&= \langle dS\,K,\,\delta Q\rangle + \langle dS^{\top} Q,\,\delta K\rangle.
\end{aligned}
\delta S = \delta Q\,K^{\top} + Q\,\delta K^{\top} \quad\Rightarrow\quad dQ = dS\,K,\;\; dK = dS^{\top} Q.

Collected results (starting from upstream $dO \in \mathbb{R}^{N \times d}$):

\begin{aligned}
dV &= P^{\top} dO,\\
dP &= dO\,V^{\top},\\
dS &= \mathrm{dsoftmax}(dP) \quad (\text{row-wise } J(p) = \operatorname{diag}(p) - pp^{\top}),\\
dQ &= dS\,K,\\
dK &= dS^{\top} Q.
\end{aligned}

This is the math that FA's backward kernel must implement.
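
As a sanity check, here's a hedged NumPy sketch that applies these formulas directly (it materializes $P$, which the real kernel avoids by recomputing blocks of $S$ and $P$ during the backward pass), plus a finite-difference check on one entry of $dQ$. The row-wise Jacobian product is written in its standard expanded form $p \odot dP_i - p\,(p^{\top} dP_i)$; all function names here are mine.

```python
import numpy as np

def attention_forward(Q, K, V):
    S = Q @ K.T
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P /= P.sum(axis=-1, keepdims=True)
    return P, P @ V

def attention_backward(Q, K, V, P, dO):
    dV = P.T @ dO                                     # dV = P^T dO
    dP = dO @ V.T                                     # dP = dO V^T
    # row-wise dS_i = (diag(p_i) - p_i p_i^T) dP_i, written in vectorized form
    dS = P * (dP - (P * dP).sum(axis=-1, keepdims=True))
    dQ = dS @ K                                       # dQ = dS K
    dK = dS.T @ Q                                     # dK = dS^T Q
    return dQ, dK, dV

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((32, 8)) for _ in range(3))
G = rng.standard_normal((32, 8))                      # plays the role of upstream dO
P, O = attention_forward(Q, K, V)
dQ, dK, dV = attention_backward(Q, K, V, P, G)

# finite-difference check of dQ[3, 2] for the scalar loss L = sum(O * G)
eps = 1e-5
Qp, Qm = Q.copy(), Q.copy()
Qp[3, 2] += eps; Qm[3, 2] -= eps
num = ((attention_forward(Qp, K, V)[1] * G).sum()
       - (attention_forward(Qm, K, V)[1] * G).sum()) / (2 * eps)
assert np.isclose(num, dQ[3, 2], rtol=1e-4, atol=1e-6)
```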

What's next for me: actually reading the kernels. The only way to truly understand the forward and backward passes is to go into the code, written in CUTLASS and CUDA C++. That means mapping the pseudocode onto warps/blocks and going into each detail until the shapes and shuffles feel natural. More to come!

References

Boehm, Simon. “How to Optimize a CUDA Matmul Kernel for cuBLAS-like Performance: a Worklog.” December 31, 2022. https://siboehm.com/articles/22/CUDA-MMM.

Dao, Tri; Fu, Daniel Y.; Ermon, Stefano; Rudra, Atri; Ré, Christopher. “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.” arXiv:2205.14135, 2022. https://doi.org/10.48550/arXiv.2205.14135.

Dao, Tri. “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning.” arXiv:2307.08691, 2023. https://doi.org/10.48550/arXiv.2307.08691.

NVIDIA Corporation. CUDA C++ Programming Guide, Release 13.0. Aug 1, 2025. https://docs.nvidia.com/cuda/cuda-c-programming-guide/.