transformer from scratch
tl;dr: implementing a transformer from scratch took a while and lots of pytorch documentation. but this was a good learning opportunity!
Callum McDougall's Transformer from Scratch template notebook is a great resource. I implemented LayerNorm, token embeddings, positional embeddings, and attention, and used them to replicate GPT-2. There are tests at every stage so you can check that you're doing it right. I'll go through what I learned, but I highly recommend trying it yourself first if you haven't already, because it's super illuminating!
LayerNorm
layernorm centers, normalizes, scales, and translates the residual stream.
in a bit more detail, we make the average 0, make the variance 1, scale it by a learned parameter gamma, and translate it by a learned parameter beta.
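mathematically:

$$\mathrm{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

where $\mu$ and $\sigma^2$ are the mean and population variance of $x$ over the $d_{model}$ dimension, and $\epsilon$ is just some small number to avoid division by zero (1e-5 in the code).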
In the code, we take in an N x T x D tensor from the residual stream, average over the D dimension to get one mean per position, and compute the variance of each D-dimensional vector. Then it suffices to plug these values into the formula; the learned $\gamma$ and $\beta$ are `self.w` and `self.b`.
import torch as t
from torch import Tensor, nn
from jaxtyping import Float, Int

# Config (d_model, init_range, ...) is defined in the notebook
class LayerNorm(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(t.ones(cfg.d_model))   # gamma: learned scale
        self.b = nn.Parameter(t.zeros(cfg.d_model))  # beta: learned shift

    def forward(self, residual: Float[Tensor, "batch posn d_model"]) -> Float[Tensor, "batch posn d_model"]:  # N x T x D tensor
        dims = (-1,)  # dims is d_model, or D
        mean = residual.mean(dim=dims, keepdim=True)  # keepdim keeps the last dimension, so mean is an N x T x 1 tensor
        var = residual.var(dim=dims, keepdim=True, correction=0)  # also N x T x 1; correction=0 means population variance, not Bessel's
        center_and_norm = (residual - mean) / t.sqrt(var + 1e-5)  # 1e-5 is epsilon
        scale_and_translate = center_and_norm * self.w + self.b
        return scale_and_translate
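As a sanity check on the formula, here's a small sketch (not from the notebook) that runs the same steps on a random tensor and compares the result with PyTorch's built-in `torch.nn.functional.layer_norm`. The toy shapes and the random `gamma`/`beta` are arbitrary choices for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, T, D = 2, 4, 8          # toy batch size, sequence length, and d_model
eps = 1e-5
x = torch.randn(N, T, D)
gamma = torch.randn(D)     # plays the role of self.w
beta = torch.randn(D)      # plays the role of self.b

# Hand-rolled LayerNorm: center and normalize over the last dim, then scale and shift
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, correction=0)  # population variance
normed = (x - mean) / torch.sqrt(var + eps)
manual = normed * gamma + beta

# PyTorch's functional LayerNorm over the last dimension
builtin = F.layer_norm(x, normalized_shape=(D,), weight=gamma, bias=beta, eps=eps)

print(torch.allclose(manual, builtin, atol=1e-5))
```

Before the affine step, `normed` has mean about 0 and variance about 1 along the last dimension, which is exactly the "center and normalize" part.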
Embedding
We want to map each token in a batch of input sequences to its embedding vector. We're given an embedding matrix, $W_E$, of dimension $d_{vocab} \times d_{model}$. The process is: convert each length-T input sequence into a T x $d_{vocab}$ one-hot matrix (each row has a 1 at that token's index in the vocabulary and 0s everywhere else), then multiply by $W_E$. Row t of the result is then the embedding of token t. This diagram below should hopefully clarify it:
In the code, we do exactly that, and matrix multiplication broadcasts over the batch dimension B.
class Embed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_E = nn.Parameter(t.empty((cfg.d_vocab, cfg.d_model)))
        nn.init.normal_(self.W_E, std=self.cfg.init_range)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_model"]:
        one_hot = nn.functional.one_hot(tokens, num_classes=self.cfg.d_vocab)  # (B, T, d_vocab), int64; PyTorch functional helper for one-hots
        one_hot = one_hot.float()  # cast to float so it can be multiplied by W_E
        return one_hot @ self.W_E  # (B, T, d_vocab) @ (d_vocab, d_model) -> (B, T, d_model)
rand_int_test(Embed, [2, 4])
load_gpt2_test(Embed, reference_gpt2.embed, tokens)
There's a PyTorch-specific one-liner: indexing `W_E` with the token tensor picks out the matching rows directly. This makes sense intuitively, but I didn't know it existed until I implemented it the honest way.
return self.W_E[tokens]
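To convince myself the two versions agree, here's a toy sketch (made-up sizes, not GPT-2's) comparing the one-hot matmul with direct indexing:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_vocab, d_model = 10, 4  # toy sizes for illustration
W_E = torch.randn(d_vocab, d_model)
tokens = torch.tensor([[1, 3, 3, 7], [0, 9, 2, 5]])  # (batch=2, position=4)

one_hot = F.one_hot(tokens, num_classes=d_vocab).float()  # (2, 4, d_vocab)
via_matmul = one_hot @ W_E  # matmul broadcasts over the batch dim -> (2, 4, d_model)
via_index = W_E[tokens]     # advanced indexing picks out rows -> (2, 4, d_model)

print(via_matmul.shape)  # torch.Size([2, 4, 4])
print(torch.allclose(via_matmul, via_index))
```

Multiplying by a one-hot row just selects one row of `W_E`, so indexing skips building the large one-hot tensor entirely.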
Positional Embedding
My understanding of positional embedding was murky before, and this explanation from the notebook really helped: "positional embedding can also be thought of as a lookup table, but rather than the indices being our token IDs, the indices are just the numbers 0, 1, 2, ... seq_len-1 (i.e. the position indices of the tokens in the sequence)."
Since positional embeddings are so similar to the token embeddings above, I used a slightly modified version of the one-liner here.
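class PosEmbed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(t.empty((cfg.n_ctx, cfg.d_model)))
        nn.init.normal_(self.W_pos, std=self.cfg.init_range)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_model"]:
        batch, seq_len = tokens.shape
        idx = t.arange(seq_len, device=tokens.device)  # position indices 0..seq_len-1, on the input's device
        # (seq_len, d_model) -> (batch, seq_len, d_model) to match the annotated output shape
        return self.W_pos[idx].unsqueeze(0).expand(batch, -1, -1)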
rand_int_test(PosEmbed, [2, 4])
load_gpt2_test(PosEmbed, reference_gpt2.pos_embed, tokens)
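One consequence of the lookup-table view is that the token IDs themselves never matter, only how many there are. A toy sketch (made-up sizes; `pos_embed` here is a hypothetical standalone function, not the notebook's class):

```python
import torch

torch.manual_seed(0)
n_ctx, d_model = 16, 4  # toy sizes for illustration
W_pos = torch.randn(n_ctx, d_model)

def pos_embed(tokens: torch.Tensor) -> torch.Tensor:
    batch, seq_len = tokens.shape
    # look up rows 0..seq_len-1; the token values are never read
    pos = W_pos[torch.arange(seq_len, device=tokens.device)]  # (seq_len, d_model)
    return pos.unsqueeze(0).expand(batch, -1, -1)              # (batch, seq_len, d_model)

a = pos_embed(torch.tensor([[5, 6, 7]]))
b = pos_embed(torch.tensor([[100, 200, 300]]))
print(torch.equal(a, b))  # True: same positions, different tokens
```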
Attention
Fortunately, this largely matched my previous understanding of attention! I forgot to include the biases before, but they don't change much, implementation-wise. Most of the time I spent on attention was learning einsum, which seems like an insanely convenient tool for generalized tensor operations.
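In the code, we compute Q, K, V from x with the weight matrices W_Q, W_K, W_V (plus biases). Then we multiply Q by K transpose and scale by $1/\sqrt{d_{head}}$. Notice the index setup in that einsum, `'NHTk, NHSk -> NHTS'`: both inputs carry the head dimension k last and it gets summed over, so the result is $QK^\top$ without ever explicitly transposing K.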
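Here's a quick sketch with random toy tensors checking that the einsum matches the explicit `q @ k.transpose(-2, -1)` version:

```python
import math
import torch

torch.manual_seed(0)
N, H, T, d_head = 2, 3, 5, 4  # toy batch, heads, positions, head dim
q = torch.randn(N, H, T, d_head)
k = torch.randn(N, H, T, d_head)

# einsum sums over the shared head dimension k, so no explicit transpose is needed
scores_einsum = torch.einsum("NHTk, NHSk -> NHTS", q, k) / math.sqrt(d_head)
# the same computation written as Q @ K^T
scores_matmul = q @ k.transpose(-2, -1) / math.sqrt(d_head)

print(scores_einsum.shape)  # torch.Size([2, 3, 5, 5])
print(torch.allclose(scores_einsum, scores_matmul, atol=1e-6))
```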
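The causal mask was new for me. I finally get why it's a triangle of 1s above the diagonal and 0s everywhere else: the entry for (query position, key position) is 1 exactly when the key position is greater than the query position, i.e. when the key is a future token. Those scores get overwritten with a large negative number (IGNORE) so that softmax sends them to essentially zero. It sounds harder and looks harder than it is!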
After applying the mask, we softmax over the last dimension of the attention scores, which are the key positions. This is because for every fixed (N, H, query position) slice, we want a probability distribution over the keys.
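A tiny sketch of the mask plus softmax on a single 4 x 4 score matrix. The all-zero scores are an arbitrary choice so the effect of the mask is easy to read, and -1e5 mirrors the IGNORE value in the code:

```python
import torch

T = 4
scores = torch.zeros(T, T)  # pretend every query-key score is equal
# True strictly above the diagonal: key position > query position (the future)
mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
masked = scores.masked_fill(mask, -1e5)  # same large negative value as IGNORE
attn = masked.softmax(dim=-1)            # normalize over key positions
print(attn)
```

Each row is a probability distribution over key positions, and row i spreads its weight evenly over positions 0..i: the first query can only see itself, the last sees all four tokens with weight 0.25 each.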
import math

class Attention(nn.Module):
    IGNORE: Float[Tensor, ""]

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_Q = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_K = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_V = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_O = nn.Parameter(t.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        self.b_Q = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_K = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_V = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_O = nn.Parameter(t.zeros((cfg.d_model)))
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        # large negative score for masked positions; `device` is set up earlier in the notebook
        self.register_buffer("IGNORE", t.tensor(-1e5, dtype=t.float32, device=device))

    def forward(
        self, normalized_resid_pre: Float[Tensor, "batch posn d_model"]
    ) -> Float[Tensor, "batch posn d_model"]:
        x = normalized_resid_pre
        N, T, D = x.shape
        H, d_head = self.cfg.n_heads, self.cfg.d_head
        # per-head queries, keys, values: (N, H, T, d_head)
        q = t.einsum('NTD, HDk -> NHTk', x, self.W_Q) + self.b_Q[None, :, None, :]
        k = t.einsum('NTD, HDk -> NHTk', x, self.W_K) + self.b_K[None, :, None, :]
        v = t.einsum('NTD, HDk -> NHTk', x, self.W_V) + self.b_V[None, :, None, :]
        # scores for query positions T against key positions S, scaled by sqrt(d_head)
        scores = t.einsum('NHTk, NHSk -> NHTS', q, k) / math.sqrt(d_head)
        scores = self.apply_causal_mask(scores)
        attn = scores.softmax(dim=-1)  # distribution over key positions
        z = t.einsum('NHTS, NHSk -> NHTk', attn, v)  # attention-weighted sum of values
        out = t.einsum('NHTk, HkD -> NTD', z, self.W_O) + self.b_O  # project and sum over heads into d_model
        return out

    def apply_causal_mask(
        self, attn_scores: Float[Tensor, "batch n_heads query_pos key_pos"]
    ) -> Float[Tensor, "batch n_heads query_pos key_pos"]:
        '''
        Applies a causal mask to attention scores, and returns masked scores.
        '''
        _, _, T_q, T_k = attn_scores.shape
        mask = t.triu(t.ones(T_q, T_k, dtype=t.bool, device=attn_scores.device), diagonal=1)  # upper triangle of 1s, excluding the diagonal
        attn_scores.masked_fill_(mask, self.IGNORE)
        return attn_scores
rand_float_test(Attention, [2, 4, 768])
load_gpt2_test(Attention, reference_gpt2.blocks[0].attn, cache["normalized", 0, "ln1"])
MLP
Two linear layers: we project up from d_model into the hidden dimension d_mlp, apply a GELU, and project back down to d_model.
from transformer_lens.utils import gelu_new  # GPT-2's tanh-approximation GELU

class MLP(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_in = nn.Parameter(t.empty((cfg.d_model, cfg.d_mlp)))
        self.W_out = nn.Parameter(t.empty((cfg.d_mlp, cfg.d_model)))
        self.b_in = nn.Parameter(t.zeros((cfg.d_mlp)))
        self.b_out = nn.Parameter(t.zeros((cfg.d_model)))
        nn.init.normal_(self.W_in, std=self.cfg.init_range)
        nn.init.normal_(self.W_out, std=self.cfg.init_range)

    def forward(
        self, normalized_resid_mid: Float[Tensor, "batch posn d_model"]
    ) -> Float[Tensor, "batch posn d_model"]:
        hidden = normalized_resid_mid @ self.W_in + self.b_in  # up-project to d_mlp
        hidden = gelu_new(hidden)
        out = hidden @ self.W_out + self.b_out  # down-project back to d_model
        return out
rand_float_test(MLP, [2, 4, 768])
load_gpt2_test(MLP, reference_gpt2.blocks[0].mlp, cache["normalized", 0, "ln2"])
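For reference, here's a sketch of the same up-project / GELU / down-project shape with toy sizes. My understanding is that `gelu_new` is the tanh approximation of GELU that GPT-2 uses; `gelu_tanh` below is my own reimplementation of that formula (an assumption, not the library's code), checked against PyTorch's `F.gelu(..., approximate="tanh")`:

```python
import math
import torch
import torch.nn.functional as F

def gelu_tanh(x: torch.Tensor) -> torch.Tensor:
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))))

torch.manual_seed(0)
d_model, d_mlp = 8, 32  # toy sizes; GPT-2 small uses 768 and 3072
x = torch.randn(2, 4, d_model)
W_in, b_in = torch.randn(d_model, d_mlp) * 0.02, torch.zeros(d_mlp)
W_out, b_out = torch.randn(d_mlp, d_model) * 0.02, torch.zeros(d_model)

hidden = gelu_tanh(x @ W_in + b_in)  # up-project, then nonlinearity
out = hidden @ W_out + b_out         # down-project back to d_model

print(out.shape)  # torch.Size([2, 4, 8])
print(torch.allclose(gelu_tanh(x), F.gelu(x, approximate="tanh"), atol=1e-5))
```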
Transformer Block
This was nice for understanding skip connections. Each sublayer reads a LayerNorm'd copy of the residual stream and adds its output back: the attention output is added to the original residual (resid_pre) to give resid_mid, and the MLP output is added to resid_mid to give resid_post. This reinforces the idea that the residual stream is the main communication channel of the transformer: layers read from it and write to it.
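One way to see the "read from and write to the stream" picture: if both sublayers write zeros, the block is exactly the identity. This is a toy sketch; `ToyBlock` is a hypothetical class that swaps in PyTorch's built-in `nn.LayerNorm` and plain `nn.Linear` layers as stand-ins for attention and the MLP.

```python
import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    """Pre-LN residual block with nn.Linear stand-ins for attention and the MLP."""
    def __init__(self, d_model: int):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.attn_stub = nn.Linear(d_model, d_model)  # hypothetical stand-in for attention
        self.mlp_stub = nn.Linear(d_model, d_model)   # hypothetical stand-in for the MLP

    def forward(self, resid_pre: torch.Tensor) -> torch.Tensor:
        resid_mid = resid_pre + self.attn_stub(self.ln1(resid_pre))  # "attention" writes into the stream
        return resid_mid + self.mlp_stub(self.ln2(resid_mid))       # "MLP" reads the updated stream, writes again

torch.manual_seed(0)
block = ToyBlock(8)
for sub in (block.attn_stub, block.mlp_stub):
    nn.init.zeros_(sub.weight)
    nn.init.zeros_(sub.bias)

x = torch.randn(2, 4, 8)
print(torch.equal(block(x), x))  # True: sublayers that write zeros leave the stream unchanged
```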
class TransformerBlock(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)

    def forward(
        self, resid_pre: Float[Tensor, "batch position d_model"]
    ) -> Float[Tensor, "batch position d_model"]:
        norm1 = self.ln1(resid_pre)
        attn_out = self.attn(norm1)
        resid_mid = resid_pre + attn_out  # skip connection around attention
        norm2 = self.ln2(resid_mid)
        mlp_out = self.mlp(norm2)
        resid_post = resid_mid + mlp_out  # skip connection around the MLP
        return resid_post
rand_float_test(TransformerBlock, [2, 4, 768])
load_gpt2_test(TransformerBlock, reference_gpt2.blocks[0], cache["resid_pre", 0])
Unembedding
A linear map from d_model back out to d_vocab: weights and biases!
class Unembed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_U = nn.Parameter(t.empty((cfg.d_model, cfg.d_vocab)))
        nn.init.normal_(self.W_U, std=self.cfg.init_range)
        # note: nn.Parameter defaults to requires_grad=True, so the flag passed to t.zeros has no effect here
        self.b_U = nn.Parameter(t.zeros((cfg.d_vocab), requires_grad=False))

    def forward(
        self, normalized_resid_final: Float[Tensor, "batch position d_model"]
    ) -> Float[Tensor, "batch position d_vocab"]:
        return normalized_resid_final @ self.W_U + self.b_U
rand_float_test(Unembed, [2, 4, 768])
load_gpt2_test(Unembed, reference_gpt2.unembed, cache["ln_final.hook_normalized"])
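The logits are unnormalized scores over the vocabulary at every position. A toy sketch (random stand-ins, made-up sizes) of turning them into a next-token distribution and a greedy prediction from the last position:

```python
import torch

torch.manual_seed(0)
batch, seq_len, d_model, d_vocab = 2, 4, 8, 10  # toy sizes
resid_final = torch.randn(batch, seq_len, d_model)  # stand-in for the ln_final output
W_U = torch.randn(d_model, d_vocab)
b_U = torch.zeros(d_vocab)

logits = resid_final @ W_U + b_U            # (batch, seq_len, d_vocab)
probs = logits.softmax(dim=-1)              # distribution over the vocabulary at each position
next_token = logits[:, -1].argmax(dim=-1)   # greedy pick from the last position, shape (batch,)
print(next_token)
```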
Demo Transformer
It's nice to see this in code: we start with a batch of token sequences, add their token embeddings and positional embeddings together, process them through n_layers transformer blocks, apply a final LayerNorm, and unembed the result into logits.
class DemoTransformer(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.ln_final = LayerNorm(cfg)
        self.unembed = Unembed(cfg)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_vocab"]:
        tok_emb = self.embed(tokens)
        pos_emb = self.pos_embed(tokens)
        resid = tok_emb + pos_emb
        for block in self.blocks:
            resid = block(resid)
        resid = self.ln_final(resid)
        logits = self.unembed(resid)
        return logits
rand_int_test(DemoTransformer, [2, 4])
load_gpt2_test(DemoTransformer, reference_gpt2, tokens)
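Since the logits at position i are a prediction for token i+1, the usual language-modeling loss lines them up with a one-position shift. A sketch with random stand-in logits (not output from the model above), computing the loss by hand and via `F.cross_entropy`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, d_vocab = 2, 5, 10  # toy sizes
logits = torch.randn(batch, seq_len, d_vocab)  # stand-in for DemoTransformer output
tokens = torch.randint(0, d_vocab, (batch, seq_len))

# position i predicts token i+1: drop the last prediction and the first target
log_probs = logits[:, :-1].log_softmax(dim=-1)
targets = tokens[:, 1:]
loss_manual = -log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1).mean()

# the same loss via PyTorch's cross_entropy on flattened (positions, vocab) inputs
loss_builtin = F.cross_entropy(logits[:, :-1].reshape(-1, d_vocab), targets.reshape(-1))
print(torch.allclose(loss_manual, loss_builtin))
```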
onwards
I think a gears-level model of the transformer is an amazing thing to have; you can't go wrong, and it's rewarding in and of itself. but in particular, I want to have it as a tool for mechanistic interpretability. from what I've read so far, understanding the transformer well is practically essential. At worst, it's great to know, and at best, it's fundamental.