flash attention
4-5x faster than standard PyTorch attention, 2-3x faster end-to-end training... FlashAttention and its variants seem like magic. But how do they work?
Memory Types
Before getting to FlashAttention, we should first get acquainted with the memory hierarchy of a modern GPU. We'll look at two types of memory in particular: high bandwidth memory (HBM, global memory) and shared memory (SRAM). For context, an A100 provides 40 or 80GB of HBM with a bandwidth of ~1.5TB/s, and 192KB of SRAM per streaming multiprocessor with an aggregate bandwidth of ~19TB/s. We care about these two memories because, as programmers, we can control where we put our data, and the capacity-speed tradeoff between HBM and SRAM opens up possibilities for serious performance gains.
Naive Attention
Regardless of implementation, we will start with $Q, K, V \in \mathbb{R}^{N \times d}$ in HBM, where $N$ is the sequence length and $d$ is the head dimension. This is because $N$ and $d$ tend to fall in the 1000-8000 and 64-128 range respectively, so even one $N \times d$ matrix in fp16 with $N = 8192$ and $d = 128$ would be 2048KB in memory, which cannot possibly fit in SRAM.
In a naive implementation of attention, we would call a generalized matrix multiplication (GEMM) on $Q$ and $K^T$ and materialize the result $S = QK^T \in \mathbb{R}^{N \times N}$ in HBM.
We then load $S$ from HBM, apply a row-wise softmax, and materialize the result $P = \mathrm{softmax}(S)$ to HBM.
Finally, we call a GEMM on $P$ and $V$ and write the result $O = PV \in \mathbb{R}^{N \times d}$ to HBM.
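Concretely, here is a minimal NumPy sketch of those three steps (not the actual kernel; the $1/\sqrt{d}$ scale factor is the usual attention scaling, left implicit in the prose above):

```python
import numpy as np

def naive_attention(Q, K, V):
    """Naive attention: materializes the full N x N matrices S and P.

    Q, K, V have shape (N, d). In the kernel view, S and P would each be
    written out to HBM and read back in between the three steps.
    """
    d = Q.shape[-1]
    S = (Q @ K.T) / np.sqrt(d)                     # GEMM 1: (N, N) scores
    P = np.exp(S - S.max(axis=-1, keepdims=True))  # row-wise softmax
    P = P / P.sum(axis=-1, keepdims=True)
    O = P @ V                                      # GEMM 2: (N, d) output
    return O
```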
In total, naive attention consists of $4N^2 + 4Nd$ element reads/writes from HBM and $4N^2d$ FLOPs total (for the two GEMMs). The arithmetic intensity is therefore $\frac{4N^2d}{2(4N^2+4Nd)} = \frac{Nd}{2(N+d)}$ FLOPs/byte, assuming bf16/fp16 (2 bytes per element). For $N \gg d$, this simplifies to $\approx \frac{d}{2}$, and assuming $d = 128$, the intensity of naive attention is about 64 FLOPs/byte. Given that an A100 (40GB, SXM4) with tensor cores peaks at 312 TFLOP/s and HBM bandwidth is around 1.6TB/s, the roofline ratio is $312/1.6 \approx 195$ FLOPs/byte. Our well-below-roofline intensity signals that the naive attention kernel is (heavily) memory-bound.
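As a back-of-the-envelope check on those numbers (assuming the $N = 8192$, $d = 128$ example and the A100 peak figures quoted above):

```python
N, d = 8192, 128
bytes_per_elem = 2                   # bf16 / fp16

flops = 4 * N**2 * d                 # two GEMMs: Q @ K^T and P @ V
hbm_elems = 4 * N**2 + 4 * N * d     # S and P written then re-read, plus Q, K, V, O
intensity = flops / (hbm_elems * bytes_per_elem)

roofline = 312e12 / 1.6e12           # peak FLOP/s over HBM bytes/s
print(intensity, roofline)           # ~63 FLOPs/byte (d/2 = 64 for N >> d) vs. ~195
```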
Notice that the naive way writes $S$ and $P$ into HBM and retrieves them right after, which seems like a waste (it is!). This is the core motivation behind flash attention: instead of materializing $S$ and $P$ in HBM, we want to keep them in on-chip SRAM and only write back the result $O$.
There are some obstacles, though: the obvious issue is that the full matrices do not fit in SRAM. We can introduce tiling to alleviate this, partitioning $Q$, $K$, and $V$ into smaller blocks and calculating the output one block at a time (more detail on this later in the article).
The second issue is that the softmax denominator $\sum_k e^{S_{ik}}$ needs to be summed across the entire row.
This wouldn't be an issue normally, but with blocks, we don't have access to the full row. Luckily, there's this neat online softmax algorithm that performs a running softmax and rescales previous partial sums. Let's see how this works.
Online Softmax
For simplicity, we will show how online softmax works on $S$ and $V$, with $S$ split column-wise into two blocks $S^{(1)}, S^{(2)}$ and $V$ split row-wise into the two matching blocks $V^{(1)}, V^{(2)}$.
On the first block:

$$m^{(1)} = \mathrm{rowmax}\left(S^{(1)}\right), \qquad \ell^{(1)} = \mathrm{rowsum}\left(e^{S^{(1)} - m^{(1)}}\right), \qquad \tilde{O}^{(1)} = e^{S^{(1)} - m^{(1)}}\, V^{(1)}$$
We're computing $e^{S^{(1)} - m^{(1)}}$ for the softmax numerator and adding up $\ell^{(1)}$ for the denominator, which might seem redundant because the $e^{-m^{(1)}}$ factors cancel. Subtracting the max (1) gives numerical stability by bounding every exponential argument at 0 (e.g., $e^{12}$ already overflows fp16), and (2) allows us to rescale the normalizer and the partially accumulated output when we find a higher $m$.
On the second block:

$$m^{(2)} = \max\left(m^{(1)},\ \mathrm{rowmax}\left(S^{(2)}\right)\right)$$

$$\ell^{(2)} = e^{m^{(1)} - m^{(2)}}\,\ell^{(1)} + \mathrm{rowsum}\left(e^{S^{(2)} - m^{(2)}}\right)$$

$$\tilde{O}^{(2)} = e^{m^{(1)} - m^{(2)}}\,\tilde{O}^{(1)} + e^{S^{(2)} - m^{(2)}}\, V^{(2)}$$

$$O = \mathrm{diag}\left(\ell^{(2)}\right)^{-1}\, \tilde{O}^{(2)}$$
Notice the iterative nature and the rescaling. From this simplified example, we can generalize to a scenario with any number of blocks in every row.
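Here is what that bookkeeping can look like in code: a NumPy sketch (not a fused kernel) that walks over column blocks of $S$ and the matching row blocks of $V$, carrying the running max, normalizer, and unnormalized output from the equations above:

```python
import numpy as np

def online_softmax_attention(S, V, block_size):
    """Computes softmax(S) @ V without ever seeing a full row of S at once."""
    N = S.shape[0]
    m = np.full((N, 1), -np.inf)       # running row-wise max
    l = np.zeros((N, 1))               # running softmax denominator
    o = np.zeros((N, V.shape[1]))      # running unnormalized output

    for start in range(0, S.shape[1], block_size):
        S_blk = S[:, start:start + block_size]   # block of score columns
        V_blk = V[start:start + block_size]      # matching block of V rows

        m_new = np.maximum(m, S_blk.max(axis=1, keepdims=True))
        alpha = np.exp(m - m_new)                # rescales previous partial sums
        p = np.exp(S_blk - m_new)                # this block's softmax numerators

        l = alpha * l + p.sum(axis=1, keepdims=True)
        o = alpha * o + p @ V_blk
        m = m_new

    return o / l                                 # normalize once at the end
```

With two blocks this reduces exactly to the update rules above; the same three running quantities handle any number of blocks, and the result matches an ordinary softmax-then-matmul up to floating-point error.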
Flash Attention
Now, let's see flash attention in action, with tiling and online softmax. We first tile $Q$ into $N/B_r$ blocks of size $B_r \times d$ and $K, V$ into $N/B_c$ blocks of size $B_c \times d$. Our output, $O$, will have $N/B_r$ blocks of size $B_r \times d$ (due to the previous block dimensions).
Say we go through $Q$'s blocks with iterator $i$ and $K$'s (and $V$'s) blocks with iterator $j$. For each distinct $i$ and $j$, we read the $Q_i$ and $K_j$ blocks from HBM and multiply them in SRAM to obtain the $S_{ij} = Q_i K_j^T$ block. We then run the online softmax (still in SRAM) to obtain the $\tilde{P}_{ij}$ block.
We now read the $V_j$ block from HBM and multiply $\tilde{P}_{ij}$ and $V_j$ together into the $\tilde{P}_{ij} V_j$ block, which is rescaled and accumulated into the final $O_i$ block (à la online softmax) and written to HBM.
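Putting tiling and the online softmax together, a reference version of the forward pass might look like the sketch below (pure NumPy standing in for the fused CUDA kernel; the default block sizes and the $1/\sqrt{d}$ scaling are illustrative additions):

```python
import numpy as np

def flash_attention_forward(Q, K, V, Br=64, Bc=64):
    """Reference FlashAttention-style forward pass for Q, K, V of shape (N, d).

    S and P only ever exist as (Br, Bc) tiles inside the loops; the only
    full-size array written back is O.
    """
    N, d = Q.shape
    O = np.zeros((N, d))
    scale = 1.0 / np.sqrt(d)

    for i in range(0, N, Br):                    # loop over Q's row blocks
        Q_i = Q[i:i + Br]
        m = np.full((Q_i.shape[0], 1), -np.inf)  # running row max
        l = np.zeros((Q_i.shape[0], 1))          # running softmax denominator
        acc = np.zeros((Q_i.shape[0], d))        # running unnormalized O_i

        for j in range(0, N, Bc):                # loop over K/V's blocks
            K_j, V_j = K[j:j + Bc], V[j:j + Bc]
            S_ij = (Q_i @ K_j.T) * scale         # (Br, Bc) tile, stays "on-chip"

            m_new = np.maximum(m, S_ij.max(axis=1, keepdims=True))
            P_ij = np.exp(S_ij - m_new)
            alpha = np.exp(m - m_new)            # rescale previous partial sums

            l = alpha * l + P_ij.sum(axis=1, keepdims=True)
            acc = alpha * acc + P_ij @ V_j
            m = m_new

        O[i:i + Br] = acc / l                    # single write of the O_i block
    return O
```

On random inputs this agrees with the naive implementation above up to floating-point error, while never forming an $N \times N$ matrix.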
FLOP-wise, FlashAttention is the same as naive attention: $4N^2d$ FLOPs (two GEMMs). The dominant HBM traffic, in the idealized case, is just reading $Q$, $K$, $V$ and writing $O$: $4Nd$ elements, i.e. $8Nd$ bytes for bf16/fp16. The resulting arithmetic intensity is $\frac{4N^2d}{8Nd} = \frac{N}{2}$ FLOPs/byte. For our case, $N = 8192$, so our intensity is about 4096 FLOPs/byte and FlashAttention is comfortably compute-bound.
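The same back-of-the-envelope script as before, updated for the idealized FlashAttention traffic:

```python
N, d = 8192, 128
flops = 4 * N**2 * d            # unchanged: still two GEMMs' worth of work
hbm_bytes = 2 * (4 * N * d)     # read Q, K, V and write O once, in bf16/fp16
print(flops / hbm_bytes)        # -> 4096.0 FLOPs/byte, i.e. N / 2
```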
FlashAttention is operator fusion done right. Instead of three kernels writing massive intermediate matrices to slow HBM, FA fuses everything into one kernel that keeps $S$ and $P$ in fast SRAM. The core insight is that modern GPUs are so fast at compute that you're basically always waiting on memory. We can boost attention from 64 to over 4000 FLOPs/byte; we just had to pay attention (...) to the memory hierarchy.
References
Boehm, Simon. “How to Optimize a CUDA Matmul Kernel for cuBLAS-like Performance: a Worklog.” December 31, 2022. https://siboehm.com/articles/22/CUDA-MMM.
Dao, Tri; Fu, Daniel Y.; Ermon, Stefano; Rudra, Atri; Ré, Christopher. “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.” arXiv:2205.14135, 2022. https://doi.org/10.48550/arXiv.2205.14135.
Dao, Tri. “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning.” arXiv:2307.08691, 2023. https://doi.org/10.48550/arXiv.2307.08691.
NVIDIA Corporation. CUDA C++ Programming Guide, Release 13.0. Aug 1, 2025. https://docs.nvidia.com/cuda/cuda-c-programming-guide/.