transformer from scratch

tl;dr: implementing a transformer from scratch took a while and lots of pytorch documentation. but this was a good learning opportunity!

Callum McDougall's Transformer from Scratch template notebook is a great resource. I implemented LayerNorm, token embeddings, positional embeddings, attention, and used them to replicate GPT-2. There are tests at every stage so you can check that you're doing it right. I'll go through what I learned, but I highly recommend trying it for yourself first if you haven't already, because it's super illuminating!

LayerNorm

layernorm centers, normalizes, scales, and translates the residual stream.

in a bit more detail, we make the mean 0, make the variance 1, scale by a learned parameter gamma, and translate by a learned parameter beta.

mathematically:

\text{LayerNorm}(x) = \frac{x - \mathbb{E}[x]}{\sqrt{\text{Var}[x] + \epsilon}} \cdot \gamma + \beta

where ε is just a small number to avoid division by zero.

In the code, we take in an N x T x D tensor from the residual stream, average over the D dimension to get the mean, and compute the variance of each D-dimensional vector. Then we just plug the values into the formula.

class LayerNorm(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(t.ones(cfg.d_model))
        self.b = nn.Parameter(t.zeros(cfg.d_model))

    def forward(self, residual: Float[Tensor, "batch posn d_model"]) -> Float[Tensor, "batch posn d_model"]: # N x T x D tensor
        dims = (-1,) # dims is d_model, or D
        mean = residual.mean(dim=dims, keepdim=True) # keepdim keeps the last dimension, so mean is an N x T x 1 tensor
        var = residual.var(dim=dims, keepdim=True, correction=0) # also N x T x 1; correction=0 gives population variance rather than Bessel-corrected sample variance
        center_and_norm = (residual - mean)/(t.sqrt(var + 1e-5))
        scale_and_translate = center_and_norm * self.w + self.b
        return scale_and_translate
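
A quick sanity check I found useful (a minimal sketch, assuming the notebook's cfg and imports are in scope): with the default w = 1 and b = 0, every D-dimensional vector coming out of LayerNorm should have mean roughly 0 and variance roughly 1, no matter how the input is scaled or shifted.

ln = LayerNorm(cfg)
x = t.randn(2, 4, cfg.d_model) * 3 + 5          # arbitrary scale and shift
out = ln(x)
print(out.mean(dim=-1).abs().max())             # ~0 for every (batch, posn) slice
print(out.var(dim=-1, correction=0).mean())     # ~1 on average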

Embedding

We want to embed each token in a batch of input sequences into its embedding vector. We're given an embedding matrix W_E of shape d_vocab x d_model. The process is converting each input sequence into T one-hot vectors (1 at that token's index in the vocabulary, 0 everywhere else) and then multiplying by W_E.

In the code, we do exactly this, with broadcasting handling the B (batch) dimension.

class Embed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_E = nn.Parameter(t.empty((cfg.d_vocab, cfg.d_model)))
        nn.init.normal_(self.W_E, std=self.cfg.init_range)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_model"]:
        one_hot = nn.functional.one_hot(tokens, num_classes=self.cfg.d_vocab)  # (B, T, d_vocab), int64 # pytorch functional helper for one-hots
        one_hot = one_hot.float() # cast into floats for W_E
        return one_hot @ self.W_E

rand_int_test(Embed, [2, 4])
load_gpt2_test(Embed, reference_gpt2.embed, tokens)

There's a PyTorch-specific one-liner solution: index W_E with the token IDs directly. It makes intuitive sense, but I didn't know it existed until I'd implemented it the honest way.

return self.W_E[tokens]
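
To convince myself the one-liner really is the same operation, here's a small equivalence check (a sketch, assuming cfg is defined as in the notebook): multiplying one-hot rows by W_E just selects rows of W_E, which is exactly what indexing with the token IDs does.

embed = Embed(cfg)
tokens = t.randint(0, cfg.d_vocab, (2, 4))
one_hot = nn.functional.one_hot(tokens, num_classes=cfg.d_vocab).float()
assert t.allclose(one_hot @ embed.W_E, embed.W_E[tokens])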

Positional Embedding

My understanding of positional embedding was murky before, and this explanation from the notebook really helped: "positional embedding can also be thought of as a lookup table, but rather than the indices being our token IDs, the indices are just the numbers 0, 1, 2, ... seq_len-1 (i.e. the position indices of the tokens in the sequence)."

Since they're so similar to the embeddings we did above, I used a slightly modified version of the one-liner here.

class PosEmbed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(t.empty((cfg.n_ctx, cfg.d_model)))
        nn.init.normal_(self.W_pos, std=self.cfg.init_range)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_model"]:
        seq_len = tokens.size(1)
        idx = t.arange(seq_len, device=tokens.device)  # position indices 0, 1, ..., seq_len-1
        return self.W_pos[idx]  # (seq_len, d_model); broadcasts over the batch dimension when added to the token embeddings

rand_int_test(PosEmbed, [2, 4])
load_gpt2_test(PosEmbed, reference_gpt2.pos_embed, tokens)
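
One way to see the "lookup table over positions" idea concretely (a small sketch, assuming cfg is in scope): two completely different token sequences of the same length get identical positional embeddings, because only the position indices 0, 1, ..., seq_len-1 are used.

pos_embed = PosEmbed(cfg)
tokens_a = t.randint(0, cfg.d_vocab, (1, 6))
tokens_b = t.randint(0, cfg.d_vocab, (1, 6))
assert t.equal(pos_embed(tokens_a), pos_embed(tokens_b))  # output depends only on positions, not token IDs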

Attention

Fortunately, this largely matched my previous understanding of attention! I forgot to include the biases before, but they don't change much, implementation-wise. Most of the time I spent on attention was learning einsum, which seems like an insanely convenient tool for generalized tensor operations.
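
If einsum is new to you too, here's the mental model that helped me, as a toy sketch (the shapes are made up, not from the notebook): name each axis with a letter, repeat a letter across inputs to multiply and sum over it, and the output pattern says which axes survive.

x = t.randn(2, 4, 8)                      # (batch N, posn T, d_model D)
W = t.randn(3, 8, 5)                      # (heads H, d_model D, d_head k)
q = t.einsum('NTD, HDk -> NHTk', x, W)    # contract over D; keep N, H, T, k
print(q.shape)                            # torch.Size([2, 3, 4, 5])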

In the code, we compute Q, K, and V from x and the weight matrices W_Q, W_K, W_V. Then we multiply Q by K-transpose and scale by sqrt(d_head) (notice the index setup in the einsum: it contracts over d_head, so we never have to explicitly transpose K).

The causal mask was new for me: I finally get why it's a triangle of 1s above the diagonal and 0s everywhere else. Each query is only allowed to attend to key positions at or before it, so every entry where the key position comes after the query position gets masked out. It sounds and looks harder than it is!

After applying the mask, we softmax over the last dimension of the attention scores, which is the key-position dimension: for every fixed (N, H, query position) slice, we want a probability distribution over the keys.
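
To make the mask and the softmax direction concrete, here's a tiny standalone sketch for a sequence of length 4 (toy scores, not from the notebook): entries where the key position comes after the query position get masked, and after the softmax each query row is a probability distribution over the keys it's allowed to see.

T = 4
scores = t.randn(T, T)                                  # (query_pos, key_pos)
mask = t.triu(t.ones(T, T, dtype=t.bool), diagonal=1)   # True where key_pos > query_pos
masked = scores.masked_fill(mask, -1e5)
attn = masked.softmax(dim=-1)                           # softmax over key positions
print(attn.sum(dim=-1))                                 # every row sums to 1
print(attn[0])                                          # query 0 can only attend to key 0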

class Attention(nn.Module):
    IGNORE: Float[Tensor, ""]

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_Q = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_K = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_V = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_O = nn.Parameter(t.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        self.b_Q = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_K = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_V = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_O = nn.Parameter(t.zeros((cfg.d_model)))
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        self.register_buffer("IGNORE", t.tensor(-1e5, dtype=t.float32, device=device))

    def forward(
        self, normalized_resid_pre: Float[Tensor, "batch posn d_model"]
    ) -> Float[Tensor, "batch posn d_model"]:
        x = normalized_resid_pre
        N, T, D = x.shape
        H, d_head = self.cfg.n_heads, self.cfg.d_head

        q = t.einsum('NTD, HDk -> NHTk', x, self.W_Q) + self.b_Q[None, :, None, :]
        k = t.einsum('NTD, HDk -> NHTk', x, self.W_K) + self.b_K[None, :, None, :]
        v = t.einsum('NTD, HDk -> NHTk', x, self.W_V) + self.b_V[None, :, None, :]

        scores = t.einsum('NHTk, NHSk -> NHTS', q, k) / math.sqrt(d_head)
        scores = self.apply_causal_mask(scores)
        attn = scores.softmax(dim=-1)

        z = t.einsum('NHTS, NHSk -> NHTk', attn, v)
        out = t.einsum('NHTk, HkD -> NTD', z, self.W_O) + self.b_O
        return out

    def apply_causal_mask(
        self, attn_scores: Float[Tensor, "batch n_heads query_pos key_pos"]
    ) -> Float[Tensor, "batch n_heads query_pos key_pos"]:
        '''
        Applies a causal mask to attention scores, and returns masked scores.
        '''
        _, _, T_q, T_k = attn_scores.shape
        mask = t.triu(t.ones(T_q, T_k, dtype=t.bool, device=attn_scores.device), diagonal=1) # upper half triangle of 1s without the diagonal.
        attn_scores.masked_fill_(mask, self.IGNORE)
        return attn_scores

rand_float_test(Attention, [2, 4, 768])
load_gpt2_test(Attention, reference_gpt2.blocks[0].attn, cache["normalized", 0, "ln1"])

MLP

Two linear layers: we project into the hidden dimension, hit a GELU, and project back out.

class MLP(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_in = nn.Parameter(t.empty((cfg.d_model, cfg.d_mlp)))
        self.W_out = nn.Parameter(t.empty((cfg.d_mlp, cfg.d_model)))
        self.b_in = nn.Parameter(t.zeros((cfg.d_mlp)))
        self.b_out = nn.Parameter(t.zeros((cfg.d_model)))
        nn.init.normal_(self.W_in, std=self.cfg.init_range)
        nn.init.normal_(self.W_out, std=self.cfg.init_range)

    def forward(
        self, normalized_resid_mid: Float[Tensor, "batch posn d_model"]
    ) -> Float[Tensor, "batch posn d_model"]:
        hidden = normalized_resid_mid @ self.W_in + self.b_in
        hidden = gelu_new(hidden)
        out = hidden @ self.W_out + self.b_out
        return out

rand_float_test(MLP, [2, 4, 768])
load_gpt2_test(MLP, reference_gpt2.blocks[0].mlp, cache["normalized", 0, "ln2"])

Transformer Block

This was nice for understanding skip connections. The attention output is added back onto the original residual stream, and the MLP output is added onto that intermediate result, so each layer's contribution is a delta on top of what was already there. This again reinforces the idea that the residual stream is the main communication channel of the transformer: layers read from it and write back to it.

class TransformerBlock(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)

    def forward(
        self, resid_pre: Float[Tensor, "batch position d_model"]
    ) -> Float[Tensor, "batch position d_model"]:
        norm1      = self.ln1(resid_pre) 
        attn_out   = self.attn(norm1)  
        resid_mid  = resid_pre + attn_out

        norm2      = self.ln2(resid_mid)  
        mlp_out    = self.mlp(norm2)  
        resid_post = resid_mid + mlp_out 

        return resid_post

rand_float_test(TransformerBlock, [2, 4, 768])
load_gpt2_test(TransformerBlock, reference_gpt2.blocks[0], cache["resid_pre", 0])

Unembedding

Just a linear map from d_model back to d_vocab: weights and biases!

class Unembed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_U = nn.Parameter(t.empty((cfg.d_model, cfg.d_vocab)))
        nn.init.normal_(self.W_U, std=self.cfg.init_range)
        self.b_U = nn.Parameter(t.zeros((cfg.d_vocab), requires_grad=False))

    def forward(
        self, normalized_resid_final: Float[Tensor, "batch position d_model"]
    ) -> Float[Tensor, "batch position d_vocab"]:
        return normalized_resid_final @ self.W_U + self.b_U

rand_float_test(Unembed, [2, 4, 768])
load_gpt2_test(Unembed, reference_gpt2.unembed, cache["ln_final.hook_normalized"])

Demo Transformer

It's nice to see this all in code: we start with a batch of token sequences, add their token embeddings and positional embeddings together, process them through a stack of transformer blocks, apply the final LayerNorm, and unembed into logits.

class DemoTransformer(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.ln_final = LayerNorm(cfg)
        self.unembed = Unembed(cfg)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_vocab"]:
        tok_emb = self.embed(tokens)
        pos_emb = self.pos_embed(tokens)
        resid = tok_emb + pos_emb

        for block in self.blocks:
            resid = block(resid)  

        resid = self.ln_final(resid)
        logits = self.unembed(resid)

        return logits

rand_int_test(DemoTransformer, [2, 4])
load_gpt2_test(DemoTransformer, reference_gpt2, tokens)
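
As a closing sketch of how the finished model gets used for generation (this assumes a trained or weight-loaded model and a Hugging Face-style tokenizer with `__call__` and `decode`, e.g. the notebook's `reference_gpt2.tokenizer`; names may differ in your setup), greedy decoding is just: run the tokens through, take the argmax of the last position's logits, append, and repeat.

def greedy_generate(model, tokenizer, prompt: str, max_new_tokens: int = 20) -> str:
    tokens = tokenizer(prompt, return_tensors="pt")["input_ids"]   # (1, seq_len)
    with t.inference_mode():
        for _ in range(max_new_tokens):
            logits = model(tokens)                       # (1, seq_len, d_vocab)
            next_token = logits[0, -1].argmax()          # most likely token at the final position
            tokens = t.cat([tokens, next_token[None, None]], dim=-1)
    return tokenizer.decode(tokens[0])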

onwards

I think a gears-level model of the transformer is an amazing thing to have; you can't go wrong, and it's rewarding in and of itself. But in particular, I want it as a tool for mechanistic interpretability. From what I've read so far, understanding the transformer well is practically essential there: at worst, it's great to know, and at best, it's fundamental.