Building GPT from Scratch
Over the past few months, I've been doing a deep dive on the transformer. In this project, I've implemented a modern LLaMA-style version of GPT. In this post, I'll share some of the key insights and challenges I encountered, specifically focusing on pretraining and model architecture.
Architecture Overview
The core of the model follows the transformer architecture introduced in "Attention Is All You Need", but with several modern improvements:
- Multi-head self-attention with scaled dot-product attention
- RoPE (Rotary Position Embeddings) instead of learned positional encodings
- SwiGLU (a gated linear unit) in the feed-forward network instead of a ReLU MLP
- RMSNorm instead of Layer Normalization
- FlashAttention for memory-efficient attention computation
Multi-Head Attention
The attention mechanism is the heart of the transformer. This part of the model remains nearly identical to the original transformer. For the sake of completeness, I'll discuss multi-head attention here.
The key idea of this component is to enrich each token's representation by letting it draw on information from earlier tokens. Attention is computed by the following formula:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
In the attention mechanism, $Q$ (Query) represents the search intent, $K$ (Key) acts as the indexing attribute, and $V$ (Value) contains the actual information. This effectively functions as a soft dictionary lookup where the dot product $QK^T$ determines the relevance weight for each token's value.
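To make the "soft dictionary lookup" picture concrete, here is a minimal single-head sketch in plain Python. The helper names (`softmax`, `causal_attention`) and the toy vectors are my own illustration rather than code from the project, and it trades all efficiency for readability.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def causal_attention(Q, K, V):
    """Single-head scaled dot-product attention with a causal mask.

    Q, K, V are lists of T vectors (lists of floats); Q and K share d_k.
    Returns (outputs, weights), where weights[t] covers positions 0..t.
    """
    d_k = len(K[0])
    outputs, weights = [], []
    for t, q in enumerate(Q):
        # Token t may only look at keys 0..t (itself and the past).
        scores = [
            sum(qi * ki for qi, ki in zip(q, K[s])) / math.sqrt(d_k)
            for s in range(t + 1)
        ]
        w = softmax(scores)
        # The output is the relevance-weighted average of the visible values.
        out = [
            sum(w[s] * V[s][j] for s in range(t + 1))
            for j in range(len(V[0]))
        ]
        outputs.append(out)
        weights.append(w)
    return outputs, weights

# Toy inputs (illustrative values only).
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]
outputs, weights = causal_attention(Q, K, V)
for t, w in enumerate(weights):
    print(t, [round(x, 3) for x in w])
```

The first token can only see itself, so its weight vector is a single 1 and its output is exactly its own value. Later tokens spread their weight over everything they are allowed to see.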
The raw scores are scaled by $1/\sqrt{d_k}$ because the dot product sums $d_k$ products of independent random variables. If every component of $q$ and $k$ has zero mean and unit variance, each product has variance 1, so the sum has variance $d_k$. Since variance scales quadratically ($\text{Var}(aX) = a^2\text{Var}(X)$), dividing the logits by $\sqrt{d_k}$ counteracts this growth and brings the variance back to 1, which keeps the softmax from saturating.
Finally, Multi-Head Attention splits this process across parallel subspaces. This lets the model capture distinct linguistic relationships at the same time (local positioning, syntactic dependencies, semantic context) without averaging conflicting signals into a single, muddy representation.
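The variance argument is easy to check empirically. The sketch below uses only the standard library and assumes the components of $q$ and $k$ are independent standard normal draws; the function name and sample sizes are illustrative choices of mine.

```python
import math
import random
import statistics

def dot_product_variance(d_k, n_samples=2000, scale=False, seed=0):
    """Estimate Var(q . k) for q, k with i.i.d. N(0, 1) components.

    With scale=True the dot product is divided by sqrt(d_k) first.
    """
    rng = random.Random(seed)
    dots = []
    for _ in range(n_samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        s = sum(a * b for a, b in zip(q, k))
        dots.append(s / math.sqrt(d_k) if scale else s)
    return statistics.pvariance(dots)

for d_k in (16, 64):
    raw = dot_product_variance(d_k)
    scaled = dot_product_variance(d_k, scale=True)
    print(f"d_k={d_k:3d}  raw variance={raw:7.2f}  scaled variance={scaled:5.2f}")
```

The raw variance should land near $d_k$ and the scaled variance near 1, up to sampling noise.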
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.n_head = config.n_head
        self.d_model = config.d_model
        self.d_head = config.d_head  # dimension per head (d_model // n_head)

        # Project to q, k, v (full projection)
        self.q = nn.Linear(self.d_model, self.d_model, bias=False)
        self.k = nn.Linear(self.d_model, self.d_model, bias=False)
        self.v = nn.Linear(self.d_model, self.d_model, bias=False)

        # Output projection back to d_model
        self.Wo = nn.Linear(self.d_model, self.d_model, bias=False)

        # RoPE
        self.rope = RotaryEmbedding(config.d_head)

    def forward(self, x):
        B, T, D = x.shape

        # Compute q, k, v
        q = self.q(x)  # (B, T, D)
        k = self.k(x)
        v = self.v(x)

        # Reshape into heads: (B, n_head, T, d_head)
        q = q.view(B, T, self.n_head, self.d_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.d_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.d_head).transpose(1, 2)

        q, k = self.rope(q, k)  # apply RoPE after reshaping

        # PyTorch 2.0+ built-in SDPA (uses a FlashAttention kernel when available)
        out = F.scaled_dot_product_attention(
            q, k, v,
            is_causal=True  # autoregressive mask
        )  # shape: (B, n_head, T, d_head)

        # Merge heads back: (B, T, D)
        out = out.transpose(1, 2).contiguous().view(B, T, D)
        return self.Wo(out)
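The module above reads `n_head`, `d_model`, and `d_head` from a `GPTConfig` that this post never shows. The following is a hypothetical minimal version, assuming `d_head` is derived as `d_model // n_head`; the default values are placeholders rather than the settings I trained with.

```python
from dataclasses import dataclass

@dataclass
class GPTConfig:
    """Hypothetical config sketch; defaults are placeholders."""
    d_model: int = 512
    n_head: int = 8

    def __post_init__(self):
        # The head reshape in attention requires n_head * d_head == d_model.
        if self.d_model % self.n_head != 0:
            raise ValueError("d_model must be divisible by n_head")
        # RoPE rotates features in pairs, so each head needs an even width.
        if (self.d_model // self.n_head) % 2 != 0:
            raise ValueError("d_head must be even for RoPE")

    @property
    def d_head(self) -> int:
        return self.d_model // self.n_head

config = GPTConfig(d_model=256, n_head=4)
print(config.d_head)  # 64
```

Validating the divisibility up front turns an opaque `view` shape error deep inside `forward` into a clear message at construction time.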
RMSNorm
Instead of Layer Normalization, I used RMSNorm (Root Mean Square Layer Normalization), which is simpler and more efficient.
Layer Normalization (LN)
$$y_i = \gamma \hat{x}_i + \beta, \qquad \hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$$
where $\mu = \frac{1}{n}\sum_{i=1}^{n} x_i$
and $\sigma^2 = \frac{1}{n}\sum_{i=1}^{n} (x_i - \mu)^2$
RMSNorm
$$y_i = \gamma \hat{x}_i, \qquad \hat{x}_i = \frac{x_i}{\text{RMS}(x) + \epsilon}$$
where $\text{RMS}(x) = \sqrt{\frac{1}{n}\sum_{i=1}^{n} x_i^2}$
What are the main differences between LN and RMSNorm? Layer Normalization subtracts the mean and divides by the standard deviation, so each vector comes out with zero mean and unit variance (before $\gamma$ and $\beta$ are applied). RMSNorm skips the centering step. It only divides by the root mean square, so the mean of the squared entries becomes 1, i.e. $\frac{1}{n}\sum_i \hat{x}_i^2 = 1$, while the mean itself is left wherever it was.
Isn't it bad to not have zero mean and unit variance? In practice, centering appears to contribute little to training stability. The RMSNorm paper argues that LayerNorm's benefit comes mainly from its invariance to re-scaling, not from re-centering. Additionally, with SwiGLU replacing ReLU, having a non-zero mean matters even less. RMSNorm is also cheaper to compute because it skips the mean subtraction and the bias term, so if quality doesn't suffer, the choice is clear.
import math
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        # learnable scale parameter (gamma)
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Compute RMS along last dimension:
        # ||x||_2 / sqrt(n) == sqrt(mean(x^2))
        rms = x.norm(2, dim=-1, keepdim=True) / math.sqrt(x.size(-1))
        return self.weight * (x / (rms + self.eps))
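To see the difference between the two normalizations on actual numbers, here is a small plain-Python comparison. It leaves out the learned parameters (it assumes $\gamma = 1$ and $\beta = 0$), and the function names are my own.

```python
import math

def layer_norm(x, eps=1e-6):
    """LayerNorm without learned gain/bias: center, then divide by the std."""
    n = len(x)
    mu = sum(x) / n
    var = sum((v - mu) ** 2 for v in x) / n
    return [(v - mu) / math.sqrt(var + eps) for v in x]

def rms_norm(x, eps=1e-6):
    """RMSNorm without learned gain: divide by the root mean square only."""
    n = len(x)
    rms = math.sqrt(sum(v * v for v in x) / n)
    return [v / (rms + eps) for v in x]

def mean(x):
    return sum(x) / len(x)

x = [2.0, 4.0, 6.0, 8.0]
ln, rn = layer_norm(x), rms_norm(x)
print("LayerNorm mean, mean-square:",
      round(mean(ln), 4), round(mean([v * v for v in ln]), 4))
print("RMSNorm   mean, mean-square:",
      round(mean(rn), 4), round(mean([v * v for v in rn]), 4))
```

Both outputs have a mean-square of about 1, but only LayerNorm pulls the mean to 0. RMSNorm keeps the sign and relative offset of the input and still removes its overall scale, so `rms_norm([6, 12, 18, 24])` gives essentially the same result as `rms_norm(x)`.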
RoPE Positional Embeddings
Instead of learned positional embeddings, I implemented Rotary Position Embeddings (RoPE). RoPE rotates each query and key by an angle proportional to its absolute position, right before the attention dot product. Because the dot product of two rotated vectors depends only on the difference between their angles, the attention score between two tokens ends up being a function of their relative position instead of their absolute positions.
How do we apply rotation to our vectors? We split each $d$-dimensional query and key vector into $d/2$ two-dimensional pairs and rotate pair $i$ at position $m$ by the angle $m\theta_i$, where $\theta_i = \text{base}^{-2i/d}$. Each pair therefore rotates at its own frequency: the first pairs spin quickly as the position advances, and the last pairs spin slowly. We want this spread of frequencies so that the model can capture both long-range and short-range dependencies.
For example, if every pair rotated only slightly per position, neighboring tokens would look almost identical and the model would have a hard time telling them apart. If every pair rotated a lot per position, nearby tokens would be easy to distinguish, but the angles would wrap around many times over long distances, making far-apart positions ambiguous. Using a range of frequencies gives us both: fast pairs resolve nearby positions and slow pairs track distant ones.
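The relative-position property can be verified in a few lines of plain Python. This sketch assumes the interleaved pairing (features 0 and 1 form the first pair, 2 and 3 the second, and so on) and a base of 10,000; `rope_rotate` and the toy vectors are illustrative names and values.

```python
import math

def rope_rotate(x, pos, base=10000.0):
    """Rotate consecutive pairs (x[0], x[1]), (x[2], x[3]), ... by pos * theta."""
    d = len(x)
    out = []
    for i in range(0, d, 2):
        theta = base ** (-i / d)  # fastest for the first pair, slowest for the last
        angle = pos * theta
        c, s = math.cos(angle), math.sin(angle)
        out.append(x[i] * c - x[i + 1] * s)
        out.append(x[i] * s + x[i + 1] * c)
    return out

def dot(a, b):
    return sum(p * q for p, q in zip(a, b))

q = [0.3, -1.2, 0.7, 0.5]
k = [1.1, 0.4, -0.6, 0.9]

# Same offset (3 positions apart) at two very different absolute positions.
score_near = dot(rope_rotate(q, 5), rope_rotate(k, 2))
score_far = dot(rope_rotate(q, 105), rope_rotate(k, 102))
print(abs(score_near - score_far) < 1e-9)

# A different offset (1 position apart) gives a different score.
score_other = dot(rope_rotate(q, 5), rope_rotate(k, 4))
print(abs(score_near - score_other) > 1e-3)

# Positions per full turn for each pair: short for the first, long for the last.
wavelengths = [2 * math.pi * 10000.0 ** (i / len(q)) for i in range(0, len(q), 2)]
print([round(w, 1) for w in wavelengths])
```

Shifting both tokens by 100 positions leaves the score unchanged, while changing the gap between them does not. Only the offset matters.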
import torch
import torch.nn as nn

class RotaryEmbedding(nn.Module):
    """
    Standard RoPE (Rotary Positional Embedding)
    Used in LLaMA / Qwen / Mistral (without NTK scaling).
    """

    def __init__(self, dim, base=10000):
        """
        dim: rotary dimension (must be even), head dimension
        base: frequency base (default 10,000, as in RoFormer and LLaMA)
        """
        super().__init__()
        assert dim % 2 == 0
        self.dim = dim
        # Create inverse frequency: shape [dim/2]
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)  # stays on same device as model

    def _compute_angles(self, seq_len, device, dtype):
        """
        Compute cos/sin for all positions.
        returns: cos, sin with shape [seq_len, dim/2]
        """
        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)  # pos * inv_freq
        cos = freqs.cos().to(dtype)
        sin = freqs.sin().to(dtype)
        return cos, sin

    def apply_rotary(self, x, cos, sin):
        """
        Apply RoPE to x.
        x: [batch, heads, seq, dim]
        cos: [seq, dim/2]
        sin: [seq, dim/2]
        """
        # Split even / odd features: each (even, odd) pair is one 2D vector
        x_part = x[..., ::2]
        y_part = x[..., 1::2]
        # Broadcast over batch and heads
        cos = cos[None, None, :, :]
        sin = sin[None, None, :, :]
        # Standard 2D rotation of each pair
        x_rot = x_part * cos - y_part * sin
        y_rot = x_part * sin + y_part * cos
        # Re-interleave the rotated pairs back into [batch, heads, seq, dim]
        x_out = torch.stack((x_rot, y_rot), dim=-1)
        return x_out.flatten(-2)

    def forward(self, q, k):
        seq_len = q.size(-2)
        cos, sin = self._compute_angles(seq_len, q.device, q.dtype)
        return self.apply_rotary(q, cos, sin), self.apply_rotary(k, cos, sin)
SwiGLU Feed-Forward Network
The feed-forward network uses SwiGLU (Swish-Gated Linear Unit) activation instead of ReLU. How does SwiGLU work? It's a combination of a linear layer and a gate. The gate uses the SiLU (Sigmoid Linear Unit) activation function applied to one linear transformation, which is then multiplied element-wise with another linear transformation.
Why do we prefer SwiGLU over ReLU? ReLU is $\max(0, x)$, so every negative input produces both a zero output and a zero gradient. A neuron whose pre-activation stays negative stops receiving updates and becomes "dead." The SiLU inside SwiGLU is smooth and still passes a small signal for negative inputs, so gradients can keep flowing through them.
In addition, SwiGLU acts as a gate for the linear layer, meaning that it can control what information we want to pass through the network. We can think of $W_1$ as the activation gate and $W_3$ as the regular linear layer. $W_2$ is the output layer primarily used for (1) dimensionality reduction and (2) information retrieval, similar to the value projection in attention.
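The gradient difference is easy to see numerically. This sketch works on scalars with plain Python; `swiglu_unit` models a single hidden unit given its two pre-activations (the outputs of $W_1$ and $W_3$ for that unit), and all the names are illustrative.

```python
import math

def relu_grad(x):
    """Derivative of max(0, x): exactly zero for every negative input."""
    return 1.0 if x > 0 else 0.0

def silu(x):
    """SiLU / Swish: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))

def silu_grad(x):
    """Derivative of SiLU: s * (1 + x * (1 - s)) with s = sigmoid(x)."""
    s = 1.0 / (1.0 + math.exp(-x))
    return s * (1.0 + x * (1.0 - s))

def swiglu_unit(gate_pre, value_pre):
    """One hidden unit of SwiGLU: silu(gate) * value."""
    return silu(gate_pre) * value_pre

for x in (-2.0, -0.5, 0.5, 2.0):
    print(f"x={x:+.1f}  relu'={relu_grad(x):.3f}  silu'={silu_grad(x):+.3f}")
```

For negative inputs the ReLU gradient is exactly zero, while the SiLU gradient is small but generally non-zero. The gate also scales the value path smoothly: a gate pre-activation near zero suppresses the unit, and a large positive one passes `value_pre` through almost unchanged.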
class FeedForward(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        hidden = 4 * config.d_model  # standard hidden expansion
        # SwiGLU = silu(W1x) * (W3x)
        self.w1 = nn.Linear(config.d_model, hidden, bias=False)
        self.w3 = nn.Linear(config.d_model, hidden, bias=False)
        self.w2 = nn.Linear(hidden, config.d_model, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
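One consequence of the gate is worth spelling out: SwiGLU has three weight matrices where a plain MLP has two, so at the same `hidden` width it carries 1.5x the parameters. The arithmetic below is a sketch with an illustrative `d_model`. The "matched" variant uses the $\frac{2}{3} \cdot 4d$ hidden size that LLaMA adopts to keep the parameter count comparable; the code above keeps the simpler `4 * d_model`.

```python
def ffn_params(d_model, hidden, gated):
    """Weight count of a bias-free feed-forward block."""
    n_matrices = 3 if gated else 2  # gated: W1 and W3 (in) plus W2 (out)
    return n_matrices * d_model * hidden

d_model = 512  # illustrative size
relu_ffn = ffn_params(d_model, 4 * d_model, gated=False)
swiglu_4x = ffn_params(d_model, 4 * d_model, gated=True)
swiglu_matched = ffn_params(d_model, (8 * d_model) // 3, gated=True)

print(f"ReLU MLP, hidden=4d:   {relu_ffn:,}")
print(f"SwiGLU,   hidden=4d:   {swiglu_4x:,}")
print(f"SwiGLU,   hidden=8d/3: {swiglu_matched:,}")
```

Neither choice is wrong. The wider layer buys capacity at the cost of memory and compute, which matters when comparing against a baseline of a fixed size.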
Key Learnings & Reflection
This project taught me a lot about the practical challenges of training language models:
- Architecture choices matter: Modern improvements like RMSNorm, RoPE, and SwiGLU make a significant difference in both training stability and final performance.
- Memory efficiency is paramount: FlashAttention was essential for training longer sequences and managing GPU memory constraints.
- Implementation details: Small details like proper initialization, gradient clipping, and learning rate scheduling can make or break training stability.
- Pretraining is complex: Managing large-scale data pipelines, distributed training, and monitoring training progress requires careful engineering.
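
As an illustration of the kind of detail the third bullet refers to, here is a sketch of linear warmup followed by cosine decay, plus global-norm gradient clipping, in plain Python. This is a common recipe, not necessarily the exact one used in this project, and every number in it (learning rates, step counts, `max_norm`) is a placeholder.

```python
import math

def lr_at_step(step, max_lr=3e-4, min_lr=3e-5, warmup_steps=100, total_steps=1000):
    """Linear warmup to max_lr, then cosine decay to min_lr."""
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    if step >= total_steps:
        return min_lr
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + cosine * (max_lr - min_lr)

def clip_by_global_norm(grads, max_norm=1.0):
    """Scale a flat list of gradient values so their L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads)
    return [g * max_norm / norm for g in grads]

for step in (0, 50, 100, 550, 1000):
    print(step, f"{lr_at_step(step):.2e}")
print(clip_by_global_norm([3.0, 4.0]))
```

Warmup avoids large updates while the optimizer statistics are still noisy, and clipping caps the occasional gradient spike without changing the update direction.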
Building this from scratch has given me a deep appreciation for the engineering and research that goes into modern language models. Every component, from the attention mechanism to the training infrastructure, requires careful design and implementation.