understanding attention
We examine the transformer architecture from a technical perspective, focusing on the mathematical foundations of attention. The original transformer (Vaswani et al., 2017) comprises six encoder layers, six decoder layers, a linear projection, and a softmax activation. Unlike sequential RNNs, transformers enable massive parallelization by expressing operations as batched GEMMs, achieving superior GPU utilization.
Each encoder layer consists of self-attention followed by a feed-forward network. We examine each component systematically.
self-attention
For simplicity, we begin with a single sequence (batch size B = 1). The input sentence undergoes tokenization (e.g., BPE or SentencePiece), yielding T tokens. Each token is mapped to an embedding of dimension D (typically 512) through a learned embedding matrix, resulting in input embeddings X with shape T × D.
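As a minimal sketch of this step, assuming NumPy, a toy vocabulary, and a randomly initialized embedding matrix standing in for learned parameters (the token ids are made up for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)

vocab_size, D = 1000, 512                     # toy vocabulary size, model dimension
embedding = rng.normal(size=(vocab_size, D))  # learned embedding matrix (random stand-in)

token_ids = np.array([12, 847, 3, 56, 921, 7])  # pretend a tokenizer produced T = 6 ids
X = embedding[token_ids]                         # embedding lookup: one row per token

print(X.shape)  # (6, 512), i.e. T x D
```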
The attention mechanism transforms embeddings through three learned projection matrices: query projection WQ (shape D × d), key projection WK (shape D × d), and value projection WV (shape D × d), where d represents the attention dimension. The dimension d may equal D for full-rank attention, or be much smaller for bottleneck attention, which compresses information to reduce memory and computation.
These projections yield the following (a short code sketch follows the list):
- Q = X × WQ (queries, shape T × d)
- K = X × WK (keys, shape T × d)
- V = X × WV (values, shape T × d)
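A minimal sketch of the three projections, with randomly initialized weights standing in for learned parameters and a bottleneck attention dimension d = 64 chosen for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
T, D, d = 6, 512, 64               # sequence length, model dimension, attention dimension
X = rng.normal(size=(T, D))        # input embeddings (stand-in for the lookup above)

W_Q = rng.normal(size=(D, d))      # query projection
W_K = rng.normal(size=(D, d))      # key projection
W_V = rng.normal(size=(D, d))      # value projection

Q = X @ W_Q                        # queries, shape (T, d)
K = X @ W_K                        # keys,    shape (T, d)
V = X @ W_V                        # values,  shape (T, d)
```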
the core computation
The attention mechanism computes context through:
Attention(Q, K, V) = softmax(QK^T / sqrt(d)) V
The matrix multiplication QK^T computes all pairwise dot products between query and key vectors; after scaling and softmax, this yields the attention weights A with shape T × T. This pairwise computation makes attention quadratic in both memory and computation: we calculate an interaction between every pair of tokens.
Think of attention like a database lookup. The query asks "what information do I need?" The key advertises "what information do I contain?" The value provides "here's my actual content." When computing QK^T, we ask: "how well does what I seek match what you offer?" The attention weights determine how much to focus on each token's value, and the resulting weighted sum of value vectors produces the context vector z (shape T × d).
The scaling factor sqrt(d) prevents the attention scores from becoming too large. As dimension d increases, the variance of the dot products grows proportionally (for components with zero mean and unit variance, a dot product of two d-dimensional vectors has variance d). Large values push the softmax into saturation, where gradients vanish. Dividing by sqrt(d) keeps the scores at a reasonable scale regardless of dimension.
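Putting the equation together in a short sketch, with a numerically stable softmax written out by hand (the function assumes the single-sequence shapes above):

```python
import numpy as np

def softmax(scores):
    # Numerically stable softmax over the last axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    exp = np.exp(scores)
    return exp / exp.sum(axis=-1, keepdims=True)

def attention(Q, K, V):
    d = Q.shape[-1]
    A = softmax(Q @ K.T / np.sqrt(d))  # attention weights, shape (T, T)
    return A @ V                       # context vectors z, shape (T, d)
```

Each row of A sums to 1, so row i of z is a weighted combination of the value vectors, weighted by how well query i matches each key.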
output and integration
The context vector z undergoes final transformation through output projection WO (shape d × D), restoring model dimension D. This output passes through residual connections and layer normalization before the feed-forward network.
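A minimal sketch of this integration step; the layer_norm helper below is a simplified stand-in that omits the learned gain and bias:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize each row to zero mean and unit variance (learned gain/bias omitted).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def attention_output(X, z, W_O):
    # X: layer input (T, D); z: context (T, d); W_O: output projection (d, D).
    return layer_norm(X + z @ W_O)   # project back to D, add residual, normalize
```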
The feed-forward component follows: Linear -> ReLU -> Linear with expansion factor F (typically 4, giving a hidden dimension of 4D), creating transformations D -> FD -> D. This expand-then-contract structure forces the model to build a richer intermediate representation and then compress it back, learning more abstract features.
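A sketch of the feed-forward block under these shapes, with the typical expansion factor F = 4 and biases omitted for brevity (the weights are random stand-ins):

```python
import numpy as np

rng = np.random.default_rng(0)
T, D, F = 6, 512, 4
W1 = rng.normal(size=(D, F * D))      # expansion:   D -> FD
W2 = rng.normal(size=(F * D, D))      # contraction: FD -> D

def feed_forward(x):
    # x: (T, D) output of the attention sub-layer.
    hidden = np.maximum(0.0, x @ W1)  # Linear + ReLU
    return hidden @ W2                # Linear back to model dimension
```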
multi-head attention
Vaswani et al. introduced multi-head attention so models could "jointly attend to information from different representation subspaces at different positions." This uses H attention heads, each acting as a specialized expert that examines the input through a different learned lens.
Multi-head attention modifies projections to accommodate H heads:
- WQ: D × Hd (query projection for all heads)
- WK: D × Hd (key projection for all heads)
- WV: D × Hd (value projection for all heads)
Our projected representations become Q, K, V: T × Hd.
We perform attention H times in parallel, yielding H attention weight matrices (each T × T) and H context vectors z_i (each T × d). These contexts are concatenated horizontally to form the aggregate context z with shape T × Hd.
The output projection WO must have shape Hd × D to map concatenated context back to model dimension.
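The pieces above combine into one multi-head sketch; the heads are carved out of the Hd-wide projections by reshaping, and the weight matrices are assumed to be passed in (random stand-ins work for shape checks):

```python
import numpy as np

def multi_head_attention(X, W_Q, W_K, W_V, W_O, H):
    # X: (T, D); W_Q, W_K, W_V: (D, H*d); W_O: (H*d, D). Assumes Hd divisible by H.
    T, _ = X.shape
    d = W_Q.shape[1] // H

    def split_heads(M):                    # (T, H*d) -> (H, T, d)
        return M.reshape(T, H, d).transpose(1, 0, 2)

    Q, K, V = split_heads(X @ W_Q), split_heads(X @ W_K), split_heads(X @ W_V)

    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d)    # (H, T, T)
    scores -= scores.max(axis=-1, keepdims=True)
    A = np.exp(scores)
    A /= A.sum(axis=-1, keepdims=True)                # H attention weight matrices

    z = (A @ V).transpose(1, 0, 2).reshape(T, H * d)  # concatenate head contexts
    return z @ W_O                                    # back to model dimension: (T, D)
```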
batch processing
Extending to batch processing with batch size B turns the matrices into tensors while preserving the operations. Final tensor dimensions (a shape check follows the list):
- X (input embeddings): B × T × D
- WQ, WK, WV (projections): D × Hd
- Q, K, V (projected representations): B × T × Hd
- A (attention weights): B × H × T × T
- z (concatenated context): B × T × Hd
- WO (output projection): Hd × D
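A quick shape check of the batched tensors, with sizes picked arbitrarily for illustration; NumPy broadcasting handles the batch dimension, so the weight matrices are unchanged:

```python
import numpy as np

rng = np.random.default_rng(0)
B, T, D, H, d = 32, 128, 512, 8, 64

X = rng.normal(size=(B, T, D))                           # batch of input embeddings
W_Q, W_K = (rng.normal(size=(D, H * d)) for _ in range(2))

# (B, T, D) @ (D, H*d) -> (B, T, H*d), then split into heads: (B, H, T, d)
Q = (X @ W_Q).reshape(B, T, H, d).transpose(0, 2, 1, 3)
K = (X @ W_K).reshape(B, T, H, d).transpose(0, 2, 1, 3)

scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d)
print(scores.shape)  # (32, 8, 128, 128), i.e. B x H x T x T
```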
encoder sum-up
The complete encoder layer integrates self-attention with feed-forward processing through residual connections and layer normalization. This combination, repeated across L encoder layers, forms the transformer's representational foundation.
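A compact sketch of one encoder layer, reusing the multi_head_attention and layer_norm helpers sketched above (post-norm placement as in the original paper; biases omitted):

```python
import numpy as np

def encoder_layer(X, W_Q, W_K, W_V, W_O, W1, W2, H):
    # Self-attention sub-layer with residual connection and layer normalization.
    h = layer_norm(X + multi_head_attention(X, W_Q, W_K, W_V, W_O, H))
    # Feed-forward sub-layer (D -> FD -> D), again with residual + norm.
    return layer_norm(h + np.maximum(0.0, h @ W1) @ W2)
```

Stacking L copies of this layer yields the encoder stack described above.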
Part 2 will examine decoder architecture, including cross-attention and masked attention.