understanding attention

We examine the transformer architecture from a technical perspective, focusing on the mathematical foundations of attention. The original transformer (Vaswani et al., 2017) comprises six encoder layers, six decoder layers, and a final linear projection with a softmax that produces output token probabilities. Unlike RNNs, which process tokens sequentially, transformers express their operations as batched matrix multiplications (GEMMs), enabling massive parallelization and far better GPU utilization.

Each encoder layer consists of self-attention followed by a feed-forward network. We examine each component systematically.

self-attention

For simplicity, we begin with a single sequence (batch size B = 1). The input sentence undergoes tokenization (e.g., BPE or SentencePiece), yielding T tokens. Each token is mapped to an embedding of dimension D (typically 512) through a learned embedding matrix, resulting in input embeddings X with shape T × D.
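
Here is a minimal NumPy sketch of this step; the vocabulary size, the token ids, and the embedding table are random stand-ins rather than values from a trained model:

```python
import numpy as np

T, D = 6, 512                      # sequence length, model dimension
vocab_size = 32000                 # assumed vocabulary size

rng = np.random.default_rng(0)
embedding_table = rng.normal(size=(vocab_size, D))   # learned in a real model
token_ids = rng.integers(0, vocab_size, size=T)      # stand-in for tokenizer output

X = embedding_table[token_ids]     # input embeddings, shape (T, D)
print(X.shape)                     # (6, 512)
```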

The attention mechanism transforms embeddings through three learned projection matrices: query projection WQ (shape D × d), key projection WK (shape D × d), and value projection WV (shape D × d), where d represents the attention dimension. The dimension d may equal D for full-rank attention, or be much smaller for bottleneck attention, which compresses information to reduce memory and computation.

These projections yield:

  • Q = X × WQ (queries, shape T × d)
  • K = X × WK (keys, shape T × d)
  • V = X × WV (values, shape T × d)
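
Continuing the sketch with random weights in place of the learned projections (d = 64 is an assumed attention dimension, not a value fixed above):

```python
import numpy as np

T, D, d = 6, 512, 64               # d = 64 is an assumed attention dimension
rng = np.random.default_rng(0)

X = rng.normal(size=(T, D))        # input embeddings from the previous step
W_Q = rng.normal(size=(D, d))      # query projection (learned)
W_K = rng.normal(size=(D, d))      # key projection (learned)
W_V = rng.normal(size=(D, d))      # value projection (learned)

Q = X @ W_Q                        # queries, shape (T, d)
K = X @ W_K                        # keys,    shape (T, d)
V = X @ W_V                        # values,  shape (T, d)
```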

the core computation

The attention mechanism computes context through:

Attention(Q, K, V) = softmax(QK^T / sqrt(d)) V

The matrix multiplication QK^T computes all pairwise dot products between query and key vectors, producing a T × T score matrix; the softmax turns these scores into attention weights A (also T × T). This pairwise computation is what makes attention quadratic in memory and compute: we calculate an interaction between every pair of tokens.
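
A minimal single-head implementation of this formula might look as follows; the softmax helper and the random Q, K, V are our own stand-ins, not from any library or trained model:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)      # pairwise dot products, shape (T, T)
    A = softmax(scores, axis=-1)       # attention weights, each row sums to 1
    return A @ V                       # context vectors z, shape (T, d)

rng = np.random.default_rng(0)
T, d = 6, 64
Q, K, V = (rng.normal(size=(T, d)) for _ in range(3))
z = attention(Q, K, V)                 # shape (6, 64)
```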

Think of attention like a database lookup. The query asks "what information do I need?" The key advertises "what information do I contain?" The value provides "here's my actual content." When computing QK^T, we ask: "how well does what I seek match what you offer?" The attention weights determine how much to focus on each token's value; weighting the values accordingly and summing produces the context vectors z (shape T × d).

The scaling factor sqrt(d) prevents the attention scores from becoming too large. As the dimension d increases, the variance of the dot products grows proportionally: for roughly unit-variance components, the dot product of two random d-dimensional vectors has variance about d. Large scores push the softmax into saturation, where gradients vanish. Dividing by sqrt(d) keeps the scores in a reasonable range regardless of dimension.
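
A quick numerical check of this claim, sampling random unit-variance vectors (the exact numbers will vary with the seed):

```python
import numpy as np

rng = np.random.default_rng(0)
for d in (16, 64, 256, 1024):
    q = rng.normal(size=(10_000, d))
    k = rng.normal(size=(10_000, d))
    dots = (q * k).sum(axis=1)                 # 10,000 sample dot products
    print(d, dots.var().round(1), (dots / np.sqrt(d)).var().round(2))
    # the raw dot-product variance grows roughly like d;
    # after dividing by sqrt(d) it stays near 1
```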

output and integration

The context vector z undergoes final transformation through output projection WO (shape d × D), restoring model dimension D. This output passes through residual connections and layer normalization before the feed-forward network.
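
A sketch of the tail end of the attention sub-layer, assuming a bare-bones layer_norm helper without the learned scale and shift parameters:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # normalize each token's features; the learned scale and shift are omitted
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

T, D, d = 6, 512, 64
rng = np.random.default_rng(0)
X = rng.normal(size=(T, D))        # sub-layer input
z = rng.normal(size=(T, d))        # context vectors from attention
W_O = rng.normal(size=(d, D))      # output projection (learned)

attn_out = z @ W_O                 # back to model dimension, shape (T, D)
X = layer_norm(X + attn_out)       # residual connection, then layer norm
```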

The feed-forward component follows: Linear -> ReLU -> Linear with an expansion factor F (typically 4, giving a hidden dimension of FD = 2048 when D = 512), creating the transformation D -> FD -> D. This expansion and projection back down forces the model to compress and decompress information, learning abstract features.
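
A sketch of this feed-forward block with random stand-in weights (the feed_forward helper name is ours):

```python
import numpy as np

def feed_forward(x, W1, b1, W2, b2):
    h = np.maximum(x @ W1 + b1, 0.0)   # Linear -> ReLU: expand D to F*D
    return h @ W2 + b2                 # Linear: project F*D back to D

T, D, F = 6, 512, 4
rng = np.random.default_rng(0)
x = rng.normal(size=(T, D))
W1, b1 = rng.normal(size=(D, F * D)), np.zeros(F * D)
W2, b2 = rng.normal(size=(F * D, D)), np.zeros(D)

y = feed_forward(x, W1, b1, W2, b2)    # shape (T, D)
```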

multi-head attention

Vaswani et al. introduced multi-head attention so models could "jointly attend to information from different representation subspaces at different positions." This uses H attention heads, each acting as a specialized expert that examines the input through its own learned lens.

Multi-head attention modifies projections to accommodate H heads:

  • WQ: D × Hd (query projection for all heads)
  • WK: D × Hd (key projection for all heads)
  • WV: D × Hd (value projection for all heads)

Our projected representations Q, K, V now each have shape T × Hd.

We perform attention H times in parallel, yielding H attention weight matrices (each T × T) and H context matrices z_i (each T × d). These contexts are concatenated horizontally to form the aggregate context z with shape T × Hd.

The output projection WO must have shape Hd × D to map concatenated context back to model dimension.
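
Putting the pieces together, here is a sketch of multi-head attention that projects once into Hd columns and then splits them into H heads of width d (the helper names and sizes are ours; all weights are random stand-ins):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_Q, W_K, W_V, W_O, H):
    T, D = X.shape
    d = W_Q.shape[1] // H                                # per-head dimension
    # project once, then split the Hd columns into H heads of width d
    Q = (X @ W_Q).reshape(T, H, d).transpose(1, 0, 2)    # (H, T, d)
    K = (X @ W_K).reshape(T, H, d).transpose(1, 0, 2)
    V = (X @ W_V).reshape(T, H, d).transpose(1, 0, 2)

    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d)       # (H, T, T)
    A = softmax(scores, axis=-1)                         # H attention maps
    z = A @ V                                            # (H, T, d)
    z = z.transpose(1, 0, 2).reshape(T, H * d)           # concatenate heads: (T, Hd)
    return z @ W_O                                       # back to (T, D)

T, D, H, d = 6, 512, 8, 64
rng = np.random.default_rng(0)
X = rng.normal(size=(T, D))
W_Q, W_K, W_V = (rng.normal(size=(D, H * d)) for _ in range(3))
W_O = rng.normal(size=(H * d, D))
out = multi_head_attention(X, W_Q, W_K, W_V, W_O, H)     # shape (6, 512)
```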

batch processing

Extending to batch processing with batch size B transforms matrices into tensors while preserving operations. Final tensor dimensions:

  • X (input embeddings): B × T × D
  • WQ, WK, WV (projections): D × Hd
  • Q, K, V (projected representations): B × T × Hd
  • A (attention weights): B × H × T × T
  • z (concatenated context): B × T × Hd
  • WO (output projection): Hd × D
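
A batched sketch using einsum, with small assumed sizes, that asserts exactly these shapes (the split_heads helper is ours):

```python
import numpy as np

B, T, D, H, d = 2, 6, 512, 8, 64
rng = np.random.default_rng(0)

X = rng.normal(size=(B, T, D))                         # input embeddings
W_Q, W_K, W_V = (rng.normal(size=(D, H * d)) for _ in range(3))
W_O = rng.normal(size=(H * d, D))

def split_heads(M):                                    # (B, T, Hd) -> (B, H, T, d)
    return M.reshape(B, T, H, d).transpose(0, 2, 1, 3)

Q, K, V = split_heads(X @ W_Q), split_heads(X @ W_K), split_heads(X @ W_V)

scores = np.einsum("bhqd,bhkd->bhqk", Q, K) / np.sqrt(d)
A = np.exp(scores - scores.max(axis=-1, keepdims=True))
A = A / A.sum(axis=-1, keepdims=True)                  # attention weights
z = np.einsum("bhqk,bhkd->bhqd", A, V)                 # per-head context
z = z.transpose(0, 2, 1, 3).reshape(B, T, H * d)       # concatenate heads

out = z @ W_O
assert A.shape == (B, H, T, T) and z.shape == (B, T, H * d) and out.shape == (B, T, D)
```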

encoder sum-up

The complete encoder layer integrates self-attention with feed-forward processing through residual connections and layer normalization. This combination, repeated across L encoder layers, forms the transformer's representational foundation.
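
To tie the section together, here is a compact, self-contained sketch of one encoder layer built from the pieces above (post-layer-norm as in the original paper; weights are random, and the same parameters are reused across layers purely for shape checking):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def layer_norm(x, eps=1e-5):
    return (x - x.mean(axis=-1, keepdims=True)) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def encoder_layer(X, params, H):
    # multi-head self-attention + feed-forward, each wrapped in a residual
    # connection followed by layer normalization (post-norm)
    W_Q, W_K, W_V, W_O, W_1, b_1, W_2, b_2 = params
    B, T, D = X.shape
    d = W_Q.shape[1] // H

    def heads(M):                                            # (B, T, Hd) -> (B, H, T, d)
        return M.reshape(B, T, H, d).transpose(0, 2, 1, 3)

    Q, K, V = heads(X @ W_Q), heads(X @ W_K), heads(X @ W_V)
    A = softmax(Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d))    # (B, H, T, T)
    z = (A @ V).transpose(0, 2, 1, 3).reshape(B, T, -1)      # (B, T, Hd)
    X = layer_norm(X + z @ W_O)                              # attention sub-layer

    h = np.maximum(X @ W_1 + b_1, 0.0)                       # feed-forward sub-layer
    return layer_norm(X + h @ W_2 + b_2)

B, T, D, H, d, F = 2, 6, 512, 8, 64, 4
rng = np.random.default_rng(0)
params = (rng.normal(size=(D, H * d)), rng.normal(size=(D, H * d)),
          rng.normal(size=(D, H * d)), rng.normal(size=(H * d, D)),
          rng.normal(size=(D, F * D)), np.zeros(F * D),
          rng.normal(size=(F * D, D)), np.zeros(D))

out = rng.normal(size=(B, T, D))
for _ in range(6):                     # stack L = 6 encoder layers
    out = encoder_layer(out, params, H)
print(out.shape)                       # (2, 6, 512)
```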

Part 2 will examine decoder architecture, including cross-attention and masked attention.