Karpathy Series - Let's reproduce GPT-2
- Project goal: reproduce GPT-2 124M model
- Hugging Face conversion used to obtain PyTorch weights and state dict
- Token and positional embeddings shapes and semantics
- Verify generation from released GPT-2 weights
- GPT-2 architecture differences from the original Transformer
- Block-level computation: pre-normalization, residuals, attention vs MLP
- MLP design and the GELU/Gated nonlinearity choice
- Multi-head self-attention implementation and tensor reshaping tricks
- Loading Hugging Face weights into custom GPT class and TF->PyTorch transpositions
- Forward pass and logits shape semantics
- Sampling loop, top-k sampling, and pipeline differences
- Device autodetection and cross-backend compatibility (CUDA, MPS, CPU)
- Batch shaping for Transformer training: converting long streams to B x T tensors
- Loss computation for next-token prediction and flattening for CrossEntropy
- Simple optimization loop and debugging overfitting a single batch
- Simple sequential data loader (sharded chunk iteration)
- Weight tying of token embeddings and LM head (shared embedding)
- Initialization conventions: fixed std and residual scaling for deep stacks
- Floating-point precision tradeoffs and TF32 impact
- Mixed precision (BF16) and autocast benefits
- torch.compile: reducing Python overhead and enabling kernel fusion
- FlashAttention: fused attention kernel that avoids materializing full attention matrices
- Padding vocabulary and tensor sizes to ‘nice’ multiples for CUDA kernels
- Optimizer choices: AdamW, fused optimizer kernels, and parameter grouping
- Gradient clipping and learning-rate scheduling (cosine decay with warmup)
- Gradual batch scaling considerations and practical omission
- Gradient accumulation semantics and correct loss scaling
- Distributed Data Parallel (DDP) fundamentals and synchronization control
- Training data choices and high-quality CommonCrawl subsets (FineWeb/edu)
- Training runs, throughput tuning and producing checkpoints
- Downstream evaluation: validation loss, HellaSwag implementation and caveats
- Production considerations: logging, checkpoints, alternative C/CUDA implementations and summary
Project goal: reproduce GPT-2 124M model
Objective: reproduce the GPT-2 124M model (the smallest GPT-2 checkpoint) by implementing the model, loading reference weights, and training from scratch until validation and downstream metrics match or exceed the released checkpoint.
The reproduction relies on published papers, released weights, and Hugging Face conversions to ensure architectural fidelity, while using modern tooling for training and evaluation.
Target model specs to match exactly:
- 12 Transformer layers
- Hidden dimension 768 (12 attention heads of size 64)
- Context length of 1024 tokens
- Vocabulary consistent with the GPT-2 tokenizer (50,257 BPE tokens)
Strict attention to hyperparameters and initialization is required to replicate performance.
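As a sanity check on the spec list, the sketch below counts parameters for the standard GPT-2 small layout. It assumes details not spelled out above: 1024 learned positions, fused Q/K/V and output projections with biases, a 4x MLP expansion, LayerNorms with gain and bias, and a bias-free LM head tied to the token embeddings. Under those assumptions the count comes out at 124,439,808, which is where the "124M" name comes from.

```python
# GPT-2 124M hyperparameters (assumed standard GPT-2 small configuration)
VOCAB_SIZE = 50257
BLOCK_SIZE = 1024
N_LAYER = 12
N_EMBD = 768

def linear_params(fan_in: int, fan_out: int, bias: bool = True) -> int:
    """Weights plus optional bias of a dense layer."""
    return fan_in * fan_out + (fan_out if bias else 0)

def gpt2_param_count(vocab: int = VOCAB_SIZE, ctx: int = BLOCK_SIZE,
                     n_layer: int = N_LAYER, d: int = N_EMBD, tied: bool = True) -> int:
    wte = vocab * d                      # token embeddings
    wpe = ctx * d                        # learned position embeddings
    per_block = (
        2 * d                            # ln_1 gain + bias
        + linear_params(d, 3 * d)        # attn.c_attn (fused Q/K/V)
        + linear_params(d, d)            # attn.c_proj
        + 2 * d                          # ln_2 gain + bias
        + linear_params(d, 4 * d)        # mlp.c_fc (4x expansion)
        + linear_params(4 * d, d)        # mlp.c_proj
    )
    ln_f = 2 * d                         # final LayerNorm
    lm_head = 0 if tied else vocab * d   # no bias; shares storage with wte when tied
    return wte + wpe + n_layer * per_block + ln_f + lm_head

print(gpt2_param_count())  # 124439808
```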
Hugging Face conversion used to obtain PyTorch weights and state dict
Hugging Face Transformers provides a PyTorch implementation and converted state dictionaries that map the original TensorFlow GPT-2 weights into PyTorch tensors.
Practical uses:
- Load a HF GPT2LMHeadModel and inspect its state_dict to learn parameter names and shapes.
- The state_dict reveals token embeddings, learned position embeddings, per-layer attention and MLP weights, and LM head weights.
These raw tensors guide:
- Parameter shapes and initialization for a from-scratch model
- Key/name mapping and any necessary transpositions from original formats
Using the HF model simplifies access to the canonical 124M parameters while enabling reimplementation and verification.
Token and positional embeddings shapes and semantics
Token embeddings form a matrix of shape (vocab_size x d_model); for GPT-2 124M this is 50257 x 768, giving a 768-dimensional vector per token.
Positional embeddings are a learned table with one row per position up to the maximum context (1024 x 768 for GPT-2) and are added to token embeddings before the Transformer blocks.
Key points:
- Embeddings produce distributed representations for tokens and absolute positions.
- Plotting individual channels of the learned position embeddings across positions reveals sinusoidal-like curves and channel-specific patterns; these arise from optimization, not from explicit sinusoidal initialization.
- Understanding these shapes is essential for exact parameter mapping and architecture reconstruction.
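The shape bookkeeping can be sketched with plain lists: a (B, T) batch of token ids becomes a (B, T, C) tensor where each entry is the token's row of wte plus the position's row of wpe, and the same wpe row is reused for every batch row (the broadcast). The toy sizes and the `make_table`/`embed` helpers are illustrative assumptions, not GPT-2's real tables.

```python
import random

def make_table(rows: int, cols: int, std: float, rng: random.Random) -> list[list[float]]:
    """Random normal table standing in for a learned embedding matrix."""
    return [[rng.gauss(0.0, std) for _ in range(cols)] for _ in range(rows)]

def embed(idx: list[list[int]], wte, wpe) -> list[list[list[float]]]:
    """idx is (B, T) token ids; returns (B, T, C) = wte[token] + wpe[position]."""
    out = []
    for row in idx:
        out.append([
            [t + p for t, p in zip(wte[tok], wpe[pos])]  # same wpe row for every batch row
            for pos, tok in enumerate(row)
        ])
    return out

rng = random.Random(0)
V, T_MAX, C = 10, 8, 4            # toy sizes; GPT-2 uses 50257, 1024, 768
wte = make_table(V, C, 0.02, rng)
wpe = make_table(T_MAX, C, 0.02, rng)
x = embed([[1, 2, 3], [4, 5, 6]], wte, wpe)  # B=2, T=3
print(len(x), len(x[0]), len(x[0][0]))       # 2 3 4
```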
Verify generation from released GPT-2 weights
Loading the official GPT-2 124M weights via Hugging Face and running the text-generation pipeline provides a functional correctness check.
What to expect and check:
- Sampled continuations should be coherent and consistent with a pretrained LM.
- Differences in outputs can come from RNG state, sampling defaults (top-k/top-p), and tokenization choices.
Why this matters:
- Coherent generations confirm correct weight loading and token/position mapping.
- This establishes a baseline target for models retrained from scratch and provides an empirical comparison point.
GPT-2 architecture differences from the original Transformer
GPT-2 is a decoder-only Transformer with two main departures from the original “Attention Is All You Need” design:
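- Pre-norm vs post-norm: LayerNorm sits at the input of each attention and MLP sublayer (pre-norm) instead of after the residual addition (post-norm) as in the original Transformer.
- Extra final LayerNorm: an additional LayerNorm is applied after the last block, before the LM head.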
Other structural notes:
- The model has no encoder and therefore no encoder-decoder cross-attention; it contains only masked self-attention and feed-forward sublayers repeated for the specified depth (e.g., 12 layers).
Reimplementations must:
- Mirror these structural differences exactly
- Follow the naming/schema used by reference implementations (e.g., Hugging Face) for exact parameter correspondence and easy weight loading.
Block-level computation: pre-normalization, residuals, attention vs MLP
Each Transformer block is composed of:
- A pre-normalized attention sublayer with residual connection
- A pre-normalized MLP (feed-forward) sublayer with residual connection
Functional roles:
- Attention is a reduce operation: a weighted aggregation across tokens that enables inter-token communication.
- MLP is a per-token map: it processes each token independently to transform its representation.
Implementation implications:
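- The pre-norm layout places LayerNorm inside each residual branch, so the residual stream itself is an unnormalized path from the embeddings to the output; gradients flowing back through the additions reach early layers directly, which helps training stability.
- Match the exact ordering (ln_1, attention, add; then ln_2, MLP, add) to reproduce GPT-2's training behavior.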
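A minimal sketch of the pre-norm block on plain lists, assuming a LayerNorm without learnable gain/bias and a toy "causal mean" in place of real attention (both simplifications for illustration). It shows the two points above: the residual path alone is the identity, and only the attention sublayer mixes tokens.

```python
import math
from typing import Callable

Vec = list[float]

def layer_norm(x: Vec, eps: float = 1e-5) -> Vec:
    """Normalize one token's channels (no learnable gain/bias, for brevity)."""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]

def block(xs: list[Vec],
          attn: Callable[[list[Vec]], list[Vec]],
          mlp: Callable[[Vec], Vec]) -> list[Vec]:
    # x = x + attn(ln_1(x)): the only place tokens exchange information
    a = attn([layer_norm(x) for x in xs])
    xs = [[u + d for u, d in zip(x, ax)] for x, ax in zip(xs, a)]
    # x = x + mlp(ln_2(x)): applied to each token on its own
    return [[u + d for u, d in zip(x, mlp(layer_norm(x)))] for x in xs]

def causal_mean(xs: list[Vec]) -> list[Vec]:
    """Toy stand-in for attention: token t averages tokens 0..t."""
    return [[sum(col) / (t + 1) for col in zip(*xs[: t + 1])] for t in range(len(xs))]

def zero_attn(xs: list[Vec]) -> list[Vec]:
    return [[0.0] * len(x) for x in xs]

def zero_mlp(x: Vec) -> Vec:
    return [0.0] * len(x)

tokens = [[1.0, 2.0, 3.0], [0.5, -1.0, 2.0]]
print(block(tokens, zero_attn, zero_mlp) == tokens)  # True: the residual path alone is the identity
```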
MLP design and the GELU/Gated nonlinearity choice
The MLP (feed-forward) consists of:
- Two linear projections with a nonlinearity between them, expanding 768 to 3072 channels (4x) and projecting back to 768.
Activation detail:
- GPT-2 uses the tanh approximation of GELU, adopted historically because the exact erf-based GELU was slow to evaluate in TensorFlow at the time.
- GELU has non-zero gradients for negative inputs, where ReLU's gradient is exactly zero, which mitigates the dead-ReLU issue and improves optimization stability.
Reproduction guidance:
- For an exact reproduction use the same tanh approximation (in PyTorch, nn.GELU(approximate='tanh')); the exact GELU differs only slightly, but activation shape and gradient behavior affect training dynamics and final performance.
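The two GELU forms can be compared directly. This sketch evaluates exact GELU (x times the standard normal CDF, via `math.erf`) next to the tanh approximation and ReLU; the printed table shows that both GELUs stay slightly negative for negative inputs where ReLU is flat zero, and that the approximation tracks the exact curve closely.

```python
import math

def gelu_exact(x: float) -> float:
    """GELU(x) = x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x: float) -> float:
    """tanh approximation used by the original GPT-2 code."""
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def relu(x: float) -> float:
    return max(0.0, x)

for x in (-3.0, -1.0, 0.0, 1.0, 3.0):
    print(f"{x:+.1f}  relu={relu(x):+.4f}  exact={gelu_exact(x):+.4f}  tanh={gelu_tanh(x):+.4f}")
```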
Multi-head self-attention implementation and tensor reshaping tricks
Multi-headed attention implementation steps:
- Linearly project input to combined Q/K/V tensors.
- Reshape (B, T, C) to (B, n_head, T, head_size) so that heads act as an extra batch dimension.
- Perform batched scaled dot-product attention with causal masking.
- Concatenate head outputs and apply a final linear projection.
Performance notes:
- Efficient PyTorch implementations treat the head dimension as an additional batch dimension to use parallel kernels and reduce Python overhead.
- The batched result must be mathematically identical to running each head separately and concatenating the outputs.
- Matching naming and parameter layout to Hugging Face conventions simplifies weight transfer and ensures functional equivalence.
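The per-head logic that the batched PyTorch version vectorizes can be written out explicitly. This sketch assumes q, k and v have already been produced by the Q/K/V projection (the projection itself is omitted) and handles one sequence; each head owns a contiguous slice of head_size channels, which is exactly what a `view(B, T, n_head, head_size).transpose(1, 2)` exposes as a batch dimension in PyTorch.

```python
import math
import random

def softmax(xs: list[float]) -> list[float]:
    m = max(xs)                                   # subtract max for numerical stability
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

def causal_mhsa(q, k, v, n_head: int):
    """q, k, v: (T, C) lists for one sequence; returns (T, C) with head outputs concatenated."""
    T, C = len(q), len(q[0])
    assert C % n_head == 0
    hs = C // n_head                              # head size
    out = [[0.0] * C for _ in range(T)]
    for h in range(n_head):                       # in PyTorch, heads become a batch dim
        sl = slice(h * hs, (h + 1) * hs)          # channel slice owned by head h
        for t in range(T):
            # causal mask: query t only scores keys 0..t
            scores = [sum(a * b for a, b in zip(q[t][sl], k[j][sl])) / math.sqrt(hs)
                      for j in range(t + 1)]
            w = softmax(scores)
            for c in range(hs):
                out[t][h * hs + c] = sum(w[j] * v[j][h * hs + c] for j in range(t + 1))
    return out

rng = random.Random(0)
T, C, n_head = 3, 4, 2
q = [[rng.gauss(0, 1) for _ in range(C)] for _ in range(T)]
k = [[rng.gauss(0, 1) for _ in range(C)] for _ in range(T)]
v = [[rng.gauss(0, 1) for _ in range(C)] for _ in range(T)]
y = causal_mhsa(q, k, v, n_head)
print(len(y), len(y[0]))  # 3 4
```

Position 0 can only attend to itself, so its output equals v[0]; changing keys or values at later positions never changes earlier outputs.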
Loading Hugging Face weights into custom GPT class and TF->PyTorch transpositions
Transferring parameters from Hugging Face (or TF-origin) to a from-scratch PyTorch GPT class requires careful mapping and conversion:
Recommended procedure:
- Iterate over HF state_dict keys and map them to local module names.
- Optionally ignore non-parameter buffers (e.g., static causal mask buffers).
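- Transpose the weights stored as (in_features, out_features): the HF GPT-2 port keeps the original layout through a Conv1D module for attn.c_attn, attn.c_proj, mlp.c_fc and mlp.c_proj, whereas nn.Linear expects (out_features, in_features).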
- Verify shape equality after mapping.
Encapsulation:
- Implement a robust from_pretrained class method that performs these conversions and returns a PyTorch model whose state tensors exactly match the reference numerics for generation and evaluation.
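The mapping step can be sketched independently of PyTorch. Below, tensors are nested lists and `expected_shapes` stands in for the local model's state_dict shapes; both the helper names and the toy keys are illustrative assumptions. A real `from_pretrained` would instead call `.t()` on the torch tensors and `copy_` them into the local parameters under `torch.no_grad()`.

```python
# HF's GPT-2 uses a Conv1D module whose weight is stored as (in, out);
# nn.Linear stores (out, in), so these four weights must be transposed on load.
TRANSPOSED_SUFFIXES = (
    "attn.c_attn.weight",
    "attn.c_proj.weight",
    "mlp.c_fc.weight",
    "mlp.c_proj.weight",
)
IGNORED_SUFFIXES = (".attn.bias", ".attn.masked_bias")  # causal-mask buffers, not parameters

def transpose(m):
    return [list(col) for col in zip(*m)]

def convert_state_dict(sd_hf: dict, expected_shapes: dict) -> dict:
    """Map HF tensors (nested lists here) onto a local model's expected shapes."""
    out = {}
    for key, w in sd_hf.items():
        if key.endswith(IGNORED_SUFFIXES):
            continue
        if key.endswith(TRANSPOSED_SUFFIXES):
            w = transpose(w)
        shape = (len(w), len(w[0])) if isinstance(w[0], list) else (len(w),)
        assert shape == expected_shapes[key], f"{key}: {shape} != {expected_shapes[key]}"
        out[key] = w
    return out

sd_hf = {
    "transformer.h.0.mlp.c_fc.weight": [[1, 2, 3], [4, 5, 6]],  # (in=2, out=3) Conv1D layout
    "transformer.h.0.attn.bias": [[1]],                          # mask buffer, skipped
    "transformer.ln_f.weight": [1.0, 1.0],
}
expected = {"transformer.h.0.mlp.c_fc.weight": (3, 2), "transformer.ln_f.weight": (2,)}
out = convert_state_dict(sd_hf, expected)
print(out["transformer.h.0.mlp.c_fc.weight"])  # [[1, 4], [2, 5], [3, 6]]
```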
Forward pass and logits shape semantics
The forward pass semantics:
- Input: token indices shaped (B x T).
- Output: logits shaped (B x T x V), where V is vocabulary size.
Computation flow:
- Sum token and positional embeddings (positional broadcast across batch rows).
- Pass through the Transformer blocks.
- Apply the final layer norm and the LM head linear projection to produce logits.
- Convert logits to probabilities via softmax for sampling or use directly for cross-entropy loss.
Implementation must ensure:
- Correct tensor shapes and broadcasting.
- Device-consistent tensors to avoid runtime errors during forward and loss computation.
Sampling loop, top-k sampling, and pipeline differences
Autoregressive generation pattern:
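- Loop and append one sampled token at a time; only the logits at the last position are needed, since they predict the next token (without a KV cache the whole sequence is still re-processed at every step).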
- Apply top-k filtering (e.g., k=50), renormalize, and sample to avoid very rare tokens and improve coherence.
Practical tips:
- Use torch.no_grad to avoid saving intermediate tensors for backward passes.
- Carefully manage RNG seeds and torch.Generator objects to isolate sampling randomness from training RNG state.
- Differences in HF pipeline defaults (top-k, top-p, temperature) can cause the same seed to produce different outputs across implementations.
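A framework-free sketch of the top-k loop. `next_logits` is a hypothetical stand-in for a model call that returns the last-position logits; the sampler keeps the k largest logits, softmaxes over them only, and draws with a dedicated `random.Random` so sampling randomness stays separate from any other RNG. In PyTorch the same steps would typically use `torch.topk`, `F.softmax` and `torch.multinomial` with a `torch.Generator`.

```python
import math
import random

def top_k_sample(logits: list[float], k: int, rng: random.Random, temperature: float = 1.0) -> int:
    """Sample a token id from the k highest-scoring logits of the last position."""
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    m = max(logits[i] for i in top)
    weights = [math.exp((logits[i] - m) / temperature) for i in top]  # softmax over top-k only
    return rng.choices(top, weights=weights, k=1)[0]                  # choices renormalizes

def generate(next_logits, prompt: list[int], max_new: int, k: int = 50, seed: int = 42) -> list[int]:
    rng = random.Random(seed)          # dedicated RNG, independent of any training RNG
    tokens = list(prompt)
    for _ in range(max_new):
        logits = next_logits(tokens)   # logits for the position after tokens[-1]
        tokens.append(top_k_sample(logits, k, rng))
    return tokens

print(generate(lambda t: [0.0, 1.0, 2.0, 10.0], [0], 5, k=1))  # [0, 3, 3, 3, 3, 3]
```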
Device autodetection and cross-backend compatibility (CUDA, MPS, CPU)
Training code should detect and use available devices (CUDA GPU, Apple MPS, or CPU) and move model tensors and inputs to the same device to avoid mismatches.
Best practices:
- Use device-aware tensor creation (e.g., torch.arange(…, device=idx.device)) so forward logic stays device-agnostic.
- When GPUs are unavailable, code should still run on CPU or MPS for debugging (slower but functional).
- Log the chosen device and size micro-batches to fit each backend's memory; keeping the effective batch size fixed (e.g., via gradient accumulation) keeps results comparable across hardware backends.
Batch shaping for Transformer training: converting long streams to B x T tensors
Construct training batches from a token stream by reshaping contiguous token arrays into (B x T) tensors, where each row is a context sequence up to block size T.
Label construction:
- Load B*T + 1 contiguous tokens (one extra token in total, not per row) and slice into inputs X (all except the last) and targets Y (all except the first), each viewed as (B x T), so every input position has a next-token label.
Benefits:
- Efficient batched training with B independent sequences for parallel computation.
- Ensures last-token targets are present for loss computation.
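The B*T + 1 slicing trick in plain Python: the sketch below mirrors `buf[:-1].view(B, T)` and `buf[1:].view(B, T)` from a PyTorch version using list slices, with a toy token stream of 13 ids for B=3, T=4.

```python
def make_batch(tokens: list[int], B: int, T: int, start: int = 0):
    """Slice B*T + 1 contiguous tokens and split into inputs x and next-token targets y."""
    buf = tokens[start : start + B * T + 1]
    assert len(buf) == B * T + 1, "not enough tokens left for a full batch"
    x = [buf[r * T : (r + 1) * T] for r in range(B)]          # buf[:-1] viewed as (B, T)
    y = [buf[r * T + 1 : (r + 1) * T + 1] for r in range(B)]  # buf[1:] viewed as (B, T)
    return x, y

x, y = make_batch(list(range(100, 113)), B=3, T=4)
print(x)  # [[100, 101, 102, 103], [104, 105, 106, 107], [108, 109, 110, 111]]
print(y)  # [[101, 102, 103, 104], [105, 106, 107, 108], [109, 110, 111, 112]]
```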
Loss computation for next-token prediction and flattening for CrossEntropy
For language-model cross-entropy:
- Flatten logits from (B x T x V) to (BT x V) and targets from (B x T) to (BT), because PyTorch's F.cross_entropy expects 2D logits and 1D targets.
Forward API:
- When targets are provided, return both logits and loss from forward to centralize computation.
Sanity checks:
- Initial random-model loss should be close to -log(1/V) = ln(V) (≈ 10.82 for V = 50257), which helps validate initialization.
Care:
- Correct flattening, masking out-of-range positions (if used), and reduction semantics are crucial to stable training.
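A pure-Python sketch of the flattened loss, equivalent in spirit to `F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))`. With all-equal logits (a stand-in for a well-initialized random model) every token gets probability 1/V, so the loss is exactly ln(V), the sanity value quoted above.

```python
import math

def cross_entropy(logits_btv, targets_bt) -> float:
    """Mean next-token loss; flattens (B, T, V) -> (B*T, V) and (B, T) -> (B*T,)."""
    flat_logits = [row for seq in logits_btv for row in seq]   # (B*T, V)
    flat_targets = [t for seq in targets_bt for t in seq]      # (B*T,)
    total = 0.0
    for row, tgt in zip(flat_logits, flat_targets):
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))  # log-sum-exp
        total += log_z - row[tgt]                                 # -log softmax(row)[tgt]
    return total / len(flat_targets)

V = 50257
uniform = [[[0.0] * V for _ in range(2)]]            # B=1, T=2, all-equal logits
print(round(cross_entropy(uniform, [[7, 42]]), 4))   # 10.8249, i.e. ln(50257)
```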
Simple optimization loop and debugging overfitting a single batch
A minimal training loop:
- Create an optimizer (recommend AdamW).
- Zero gradients.
- Compute loss and call loss.backward().
- Call optimizer.step() to update parameters.
Debugging checks:
- Verify the model can overfit a small batch — this confirms forward/loss/backward pathways are correct.
- Start without gradient accumulation to simplify debugging; initial LR defaults (e.g., 3e-4) are reasonable for quick overfit tests.
Watch for device mismatches — ensure all buffers and tensors live on the same device.
Simple sequential data loader (sharded chunk iteration)
A minimal data loader approach:
- Iterate through a tokenized corpus in fixed-size chunks of B*T tokens.
- Return (X, Y) pairs by advancing a read pointer by B*T each time and loop to the start when exhausted.
Properties:
- Deterministic, epoch-based iteration without replacement until a full pass completes.
- Sharding the corpus into fixed-size shards simplifies I/O and parallel processing.
Extensions:
- For multi-epoch training, shuffle document order and permute shards per epoch to avoid ordering effects.
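A minimal version of the sequential loader, assuming the whole token list fits in memory (a sharded loader would additionally switch files when one is exhausted). The class name is hypothetical. Note that the tail that does not fill a complete batch is skipped each epoch.

```python
class SequentialLoader:
    """Walk a token list in B*T chunks; wrap to the start when the next batch won't fit."""

    def __init__(self, tokens: list[int], B: int, T: int):
        assert len(tokens) >= B * T + 1, "corpus smaller than one batch"
        self.tokens, self.B, self.T = tokens, B, T
        self.pos = 0

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.pos : self.pos + B * T + 1]
        x = [buf[r * T : (r + 1) * T] for r in range(B)]
        y = [buf[r * T + 1 : (r + 1) * T + 1] for r in range(B)]
        self.pos += B * T                              # advance by exactly B*T
        if self.pos + B * T + 1 > len(self.tokens):    # next batch would overrun: new epoch
            self.pos = 0
        return x, y

loader = SequentialLoader(list(range(10)), B=2, T=2)
first = loader.next_batch()
second = loader.next_batch()
third = loader.next_batch()   # wrapped around to the start
print(first, second, third == first)
```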
Weight tying of token embeddings and LM head (shared embedding)
Weight tying reuses the same matrix for input token embeddings and the output LM head projection (pre-softmax), yielding:
- Substantial parameter savings and an inductive bias tying input/output representations.
Implementation in PyTorch:
- Assign the LM head weight tensor to reference the embedding weight tensor (share the same storage pointer) so gradients accumulate into the same parameter.
Effects:
- Reduces parameter count by 50257 x 768 ≈ 38.6M, roughly 30% of the 124M model.
- Often improves sample efficiency and generalization.
Initialization conventions: fixed std and residual scaling for deep stacks
GPT-2 reference initializations and variance control:
- Most linear weights use normal initialization with std=0.02 and zero biases.
- Embedding stds: token embeddings 0.02; OpenAI's original GPT-2 code used 0.01 for position embeddings, though many reimplementations simply use 0.02.
Residual-sum scaling:
- Each block adds two contributions to the residual stream (attention and MLP), so its variance grows with depth; compensate by scaling down the weights of the output projections that write into the residual stream.
- Implement by multiplying the initialization std by 1/sqrt(2*L), where L is the number of Transformer layers.
Why this matters:
- These initialization details stabilize optimization and mirror original training dynamics.
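The residual scaling rule and the variance growth it compensates for can both be checked numerically. The simulation below is an assumption-laden toy: it treats each residual branch as adding an independent unit-variance term, so 24 additions (2 x 12 layers) give a std near sqrt(24) ≈ 4.9, while scaling each addition by 1/sqrt(24) brings it back near 1.

```python
import math
import random

def init_std(n_layer: int, is_residual_proj: bool, base_std: float = 0.02) -> float:
    """std for a Linear weight; projections that add into the residual stream are scaled."""
    # each block adds to the residual stream twice (attention and MLP)
    return base_std / math.sqrt(2 * n_layer) if is_residual_proj else base_std

def residual_stream_std(n_adds: int, scale: float, trials: int = 5000, seed: int = 0) -> float:
    """Empirical std after n_adds independent unit-variance branches, each multiplied by scale."""
    rng = random.Random(seed)
    samples = []
    for _ in range(trials):
        x = 0.0
        for _ in range(n_adds):
            x += scale * rng.gauss(0.0, 1.0)
        samples.append(x)
    mean = sum(samples) / trials
    return math.sqrt(sum((s - mean) ** 2 for s in samples) / trials)

print(init_std(12, False))                         # 0.02
print(round(init_std(12, True), 6))                # 0.004082
N = 24                                             # 2 * 12 layers
print(residual_stream_std(N, 1.0))                 # roughly sqrt(24) ≈ 4.9
print(residual_stream_std(N, 1 / math.sqrt(N)))    # roughly 1.0
```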
Floating-point precision tradeoffs and TF32 impact
Modern GPU precisions overview:
- FP32: baseline full precision (8 exponent bits, 23 mantissa bits), stable but limited throughput.
- TF32: a reduced-precision FP32 mode on Ampere and newer GPUs in which tensor-core matmuls keep the 8-bit exponent but only 10 mantissa bits, so they run much faster while preserving FP32's range.
Practical notes:
- Enable TF32 in PyTorch (e.g., torch.set_float32_matmul_precision('high')) to transparently use faster tensor-core kernels.
- Gains depend on whether workloads are compute-bound vs memory-bound; expect modest numerical impact but potentially large throughput improvements.
Mixed precision (BF16) and autocast benefits
BFloat16 (BF16) halves storage and memory-transfer cost by keeping FP32's 8 exponent bits but only 7 mantissa bits; preserving the exponent range avoids many of FP16's overflow and underflow issues.
Best practice:
- Use torch.autocast to lower eligible ops (matmuls, convolutions) to BF16 while keeping sensitive ops (layernorm, softmax, loss) in FP32.
Requirements and benefits:
- Requires hardware support (e.g., Ampere GPUs).
- Unlike FP16, BF16 does not require gradient scaling, and it yields substantial throughput and memory improvements when used carefully.
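The precision-versus-range trade-off of TF32 and BF16 can be illustrated by zeroing low mantissa bits of an FP32 bit pattern with `struct`. This sketch truncates (rounds toward zero), whereas real hardware conversions usually round to nearest, so treat it only as an illustration: a 10-bit mantissa keeps 1 + 2^-9 exactly, a 7-bit mantissa loses it, and a value near 3e38 stays finite at 7 bits even though FP16 tops out at 65504.

```python
import struct

def keep_mantissa_bits(x: float, bits: int) -> float:
    """Round an FP32 value toward zero to `bits` mantissa bits, keeping the 8-bit exponent.

    bits=10 mimics TF32's precision, bits=7 mimics BF16 (truncation, not round-to-nearest).
    """
    (u,) = struct.unpack(">I", struct.pack(">f", x))   # raw FP32 bit pattern
    drop = 23 - bits                                     # FP32 has 23 mantissa bits
    u &= ~((1 << drop) - 1) & 0xFFFFFFFF                 # zero the dropped mantissa bits
    (y,) = struct.unpack(">f", struct.pack(">I", u))
    return y

print(keep_mantissa_bits(1.0 + 2**-9, 10))    # 1.001953125 (TF32-like keeps it)
print(keep_mantissa_bits(1.0 + 2**-9, 7))     # 1.0 (BF16-like drops it)
print(keep_mantissa_bits(3.0e38, 7) > 65504)  # True: FP32 range survives; FP16 would overflow
```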
torch.compile: reducing Python overhead and enabling kernel fusion
torch.compile (PyTorch compilation) analyzes the model graph and generates optimized kernels to reduce Python overhead and enable operator/kernel fusion.
Benefits:
- Removes repeated operator dispatch, fuses elementwise sequences, and can reduce global memory round-trips.
- Typically yields multi-fold speedups for repeated training iterations after an upfront compilation cost.
FlashAttention: fused attention kernel that avoids materializing full attention matrices
FlashAttention is a fused GPU kernel for masked scaled-dot-product attention that avoids materializing the full T x T attention matrix in HBM.
How it works and why it helps:
- Uses an online softmax normalization trick and block-wise accumulation to reduce expensive HBM reads/writes.
- Orchestrates shared-memory usage and scaling to compute attention in streaming blocks.
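- Preserves functional semantics for causal attention while lowering memory footprint and improving runtime compared to naive attention implementations. In PyTorch, F.scaled_dot_product_attention(q, k, v, is_causal=True) can dispatch to a FlashAttention kernel on supported GPUs.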
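The online softmax trick at the core of FlashAttention can be shown for a single query row in plain Python. This sketch only reproduces the math (running max, running normalizer, rescaled accumulator, one division at the end); the real kernel also tiles over queries and keeps blocks in on-chip SRAM. The result matches the naive "materialize all weights, then normalize" computation for any block size.

```python
import math

def naive_attention_row(scores: list[float], values: list[list[float]]) -> list[float]:
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[c] for wi, v in zip(w, values)) / z for c in range(len(values[0]))]

def online_attention_row(scores, values, block: int = 2):
    """Process keys in blocks, keeping only a running max m, normalizer l and accumulator acc."""
    m, l = -math.inf, 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(scores), block):
        s_blk = scores[start : start + block]
        v_blk = values[start : start + block]
        m_new = max(m, max(s_blk))
        correction = math.exp(m - m_new)      # rescale earlier partial sums to the new max
        p = [math.exp(s - m_new) for s in s_blk]
        l = l * correction + sum(p)
        acc = [a * correction + sum(pi * v[c] for pi, v in zip(p, v_blk))
               for c, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]               # normalize once at the end

scores = [0.3, 2.0, -1.0, 0.5, 1.2]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 0.5], [0.5, -0.5]]
print(naive_attention_row(scores, values))
print(online_attention_row(scores, values, block=2))
```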
Padding vocabulary and tensor sizes to ‘nice’ multiples for CUDA kernels
Hardware kernel considerations:
- Many CUDA kernels and tensor-core implementations are optimized for tile sizes and powers-of-two factors.
- Awkward dimensions (e.g., vocab size 50257) can trigger slow boundary kernels and reduce throughput.
Mitigation:
- Pad the vocabulary to a nearby friendly number (e.g., 50304 = 393 x 128) to align internal loops with preferred tile sizes.
Trade-offs:
- Small extra memory used for padded rows; functionally benign if token indices never reference padded rows.
- Optimizer must learn to drive unused-token logits down, which is usually a negligible cost.
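The padding arithmetic, assuming a target multiple of 128 (the exact multiple a given kernel prefers is hardware-dependent): 50257 is odd, so its largest power-of-two factor is 1, while rounding up to 50304 adds 47 unused rows and makes the size divisible by 128.

```python
def round_up(n: int, multiple: int) -> int:
    """Smallest multiple of `multiple` that is >= n."""
    return ((n + multiple - 1) // multiple) * multiple

def largest_pow2_factor(n: int) -> int:
    return n & -n  # value of the lowest set bit

vocab = 50257
padded = round_up(vocab, 128)
print(padded, padded - vocab)                                    # 50304 47
print(largest_pow2_factor(vocab), largest_pow2_factor(padded))   # 1 128
```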
Optimizer choices: AdamW, fused optimizer kernels, and parameter grouping
AdamW (Adam with decoupled weight decay) is the recommended optimizer for Transformer training due to adaptive moments and per-parameter scaling.
Performance tips:
- Use fused implementations when available (e.g., torch.optim.AdamW(..., fused=True) on CUDA) to reduce kernel-launch overhead by consolidating updates into single kernels.
- Apply parameter grouping to separate parameters that should receive weight decay (e.g., 2D weight matrices, embeddings) from those that should not (biases, LayerNorm gain/bias).
Result:
- Better optimization behavior and improved runtime when grouped and fused updates are used.
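A sketch of the grouping rule, assuming the common heuristic "decay every tensor with two or more dimensions" and an assumed weight decay of 0.1. Parameter names stand in for tensors here; with real tensors, the resulting list of dicts is the format `torch.optim.AdamW` accepts as parameter groups.

```python
def group_params(shapes: dict[str, tuple[int, ...]], weight_decay: float = 0.1):
    """Decay tensors with >= 2 dims (matmul weights, embeddings); skip biases and LayerNorm params."""
    decay = [n for n, s in shapes.items() if len(s) >= 2]
    no_decay = [n for n, s in shapes.items() if len(s) < 2]
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

shapes = {
    "transformer.wte.weight": (50257, 768),
    "transformer.h.0.ln_1.weight": (768,),
    "transformer.h.0.attn.c_attn.weight": (2304, 768),
    "transformer.h.0.attn.c_attn.bias": (2304,),
}
groups = group_params(shapes)
for g in groups:
    print(g["weight_decay"], g["params"])
```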
Gradient clipping and learning-rate scheduling (cosine decay with warmup)
Stability and scheduling:
- Clip the global gradient norm (e.g., to 1.0) to prevent runaway updates from outlier batches.
- Use a linear warmup followed by cosine decay: ramp LR from near zero to peak over a warmup token budget, then decay via cosine to a lower fraction (e.g., 10%) across a token horizon.
Implementation:
- Make warmup steps and cosine decay configurable to reproduce reference regimes and allow experimentation.
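A sketch of both pieces, counted in optimizer steps rather than tokens. The peak LR of 6e-4, the 10% floor, and the 10/50-step horizon are assumed illustrative values. In PyTorch, the clipping half corresponds to `torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)`, which likewise returns the pre-clip total norm.

```python
import math

def get_lr(step: int, max_lr: float, min_lr: float, warmup_steps: int, max_steps: int) -> float:
    if step < warmup_steps:                                       # linear warmup
        return max_lr * (step + 1) / warmup_steps
    if step > max_steps:                                          # past the horizon: hold the floor
        return min_lr
    ratio = (step - warmup_steps) / (max_steps - warmup_steps)    # 0 -> 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * ratio))               # 1 -> 0
    return min_lr + coeff * (max_lr - min_lr)

def clip_global_norm(grads: list[list[float]], max_norm: float = 1.0) -> float:
    """Scale all gradients in place so their joint L2 norm is at most max_norm; return the pre-clip norm."""
    total = math.sqrt(sum(g * g for grad in grads for g in grad))
    if total > max_norm:
        scale = max_norm / (total + 1e-6)
        for grad in grads:
            for i in range(len(grad)):
                grad[i] *= scale
    return total

MAX_LR, MIN_LR, WARMUP, MAX_STEPS = 6e-4, 6e-5, 10, 50   # assumed values
for s in (0, 9, 30, 50):
    print(s, get_lr(s, MAX_LR, MIN_LR, WARMUP, MAX_STEPS))
```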
Gradual batch scaling considerations and practical omission
Batch-size strategies:
- Large-scale papers sometimes recommend ramping batch size early for optimizer stability, but this complicates bookkeeping and token-based scheduling.
- For single-node GPU constraints, gradient accumulation is an effective practical alternative to simulate larger global batch sizes while keeping micro-batches small.
Recommendation:
- For reproducible and simpler experiments, prefer fixed micro-batch sizes with gradient accumulation unless infrastructure supports dynamic batch scheduling.
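Choosing the number of accumulation steps is simple arithmetic once a token budget per optimizer step is fixed. The target of 2**19 ≈ 0.5M tokens and the micro-batch of 16 x 1024 below are assumed example values; the function also folds in the DDP world size, since every GPU contributes a micro-batch per step.

```python
def grad_accum_steps(total_batch_tokens: int, micro_batch: int, seq_len: int, world_size: int = 1) -> int:
    tokens_per_micro_step = micro_batch * seq_len * world_size
    assert total_batch_tokens % tokens_per_micro_step == 0, "pick sizes that divide evenly"
    return total_batch_tokens // tokens_per_micro_step

TOTAL = 2 ** 19  # 524,288 tokens per optimizer step (assumed target)
print(grad_accum_steps(TOTAL, micro_batch=16, seq_len=1024))                # 32
print(grad_accum_steps(TOTAL, micro_batch=16, seq_len=1024, world_size=8))  # 4
```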
Gradient accumulation semantics and correct loss scaling
Gradient accumulation mechanics:
- Sum per-micro-batch gradients across multiple forward/backward steps to emulate a larger batch.
- Because many PyTorch losses perform mean reduction over the micro-batch, scale the per-step loss by 1 / grad_accum_steps before backward so summed gradients equal a single large-batch gradient.
Correct pattern:
- loss = loss / grad_accum_steps
- loss.backward() each micro-step
- optimizer.step() only after accumulation is complete
Failure to scale yields gradients larger by grad_accum_steps and incorrect updates.
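The scaling argument can be verified with a one-parameter model y = w*x and a mean-squared-error loss (an illustrative assumption standing in for cross-entropy). Summing micro-batch gradients divided by the number of steps reproduces the full-batch gradient; summing them without the division gives a gradient that is grad_accum_steps times too large.

```python
def grad_mean_sq_error(w: float, xs: list[float], ys: list[float]) -> float:
    """d/dw of mean((w*x - y)^2) over one (micro-)batch."""
    return sum(2 * x * (w * x - y) for x, y in zip(xs, ys)) / len(xs)

w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 8.1, 9.8, 12.2, 13.9, 16.1]
steps, micro = 4, 2

full = grad_mean_sq_error(w, xs, ys)           # one big batch of 8
scaled = unscaled = 0.0
for k in range(steps):                          # 4 micro-batches of 2, gradients summed
    lo, hi = k * micro, (k + 1) * micro
    g = grad_mean_sq_error(w, xs[lo:hi], ys[lo:hi])
    scaled += g / steps                         # equivalent to (loss / steps).backward()
    unscaled += g                               # forgetting to divide
print(abs(scaled - full) < 1e-9, unscaled / full)  # True, and unscaled is ~4x full
```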
Distributed Data Parallel (DDP) fundamentals and synchronization control
Distributed Data Parallel (DDP) patterns:
- Run one process per GPU; each process computes local gradients on its data shard and then gradients are all-reduced (averaged) across processes before optimizer.step().
Gradient accumulation with DDP:
- Use no_sync context to avoid synchronizing on every micro-step during accumulation so only the final micro-step triggers communication.
Operational necessities:
- Initialize process ranks, world size, local rank, and device mapping correctly and destroy process groups on exit for robust DDP training.
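The data side of DDP and the effect of the gradient all-reduce can be simulated without any GPUs. In this sketch (hypothetical helper names), each rank starts at offset B*T*rank and strides by B*T*world_size so ranks never read the same tokens, and `all_reduce_mean` shows what every rank holds after DDP averages gradients. A real run would instead use `torch.distributed.init_process_group`, wrap the model in `DistributedDataParallel`, and call `destroy_process_group` on exit.

```python
class ShardedLoader:
    """Each DDP rank reads a disjoint B*T slice; all ranks together advance B*T*world_size."""

    def __init__(self, tokens: list[int], B: int, T: int, rank: int, world_size: int):
        self.tokens, self.B, self.T = tokens, B, T
        self.rank, self.world = rank, world_size
        self.pos = B * T * rank                        # rank-specific starting offset

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.pos : self.pos + B * T + 1]
        x = [buf[r * T : (r + 1) * T] for r in range(B)]
        y = [buf[r * T + 1 : (r + 1) * T + 1] for r in range(B)]
        self.pos += B * T * self.world                 # skip over the other ranks' slices
        if self.pos + B * T * self.world + 1 > len(self.tokens):
            self.pos = B * T * self.rank               # wrap around, keeping the offset
        return x, y

def all_reduce_mean(per_rank_grads: list[list[float]]) -> list[float]:
    """What DDP's all-reduce leaves on every rank: the element-wise average."""
    n = len(per_rank_grads)
    return [sum(vals) / n for vals in zip(*per_rank_grads)]

tokens = list(range(20))
r0 = ShardedLoader(tokens, B=1, T=2, rank=0, world_size=2)
r1 = ShardedLoader(tokens, B=1, T=2, rank=1, world_size=2)
print(r0.next_batch()[0], r1.next_batch()[0])   # [[0, 1]] [[2, 3]]
print(r0.next_batch()[0], r1.next_batch()[0])   # [[4, 5]] [[6, 7]]
print(all_reduce_mean([[1.0, 2.0], [3.0, 4.0]]))  # [2.0, 3.0]
```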
Training data choices and high-quality CommonCrawl subsets (FineWeb/edu)
Datasets and preprocessing for reproductions:
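- Public reproductions use curated web corpora such as OpenWebText (an open replica of GPT-2's WebText, built from Reddit-linked pages), RedPajama, and FineWeb (filtered CommonCrawl), plus sources such as Wikipedia, books, and GitHub.
- High-quality filtered subsets (e.g., FineWeb-Edu, which keeps pages classified as educational) provide dense, high-information text that has been reported to improve downstream benchmark scores per training token.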
Practical processing:
- Tokenize at scale and shard into fixed-size files (e.g., 100M-token shards) for streaming and parallel loading.
- Careful deduplication, language filtering, and per-epoch shuffling are crucial to avoid dataset bias and leakage.
Training runs, throughput tuning and producing checkpoints
Combine optimizations to maximize throughput:
- Use BF16/autocast, torch.compile, FlashAttention, fused AdamW, padded dimensions, DDP, and gradient accumulation to increase tokens-per-second.
Operational monitoring:
- Measure tokens/sec, wall time per step, and validation intervals to estimate training budget in hours for a target token count.
- Regularly checkpoint model and optimizer state and persist RNG seeds for exact resume behavior.
Outcome:
- Runs combining these optimizations converged within a modest compute budget and produced validation loss and downstream scores that can be compared directly against the released 124M checkpoint.
Downstream evaluation: validation loss, HellaSwag implementation and caveats
Evaluation strategy:
- Compute held-out validation next-token loss and run downstream multiple-choice-style benchmarks (e.g., HellaSwag) by converting each question into candidate continuations, scoring each candidate by its average per-token log-probability over the completion tokens, and picking the highest.
Distributed evaluation:
- Shard evaluation across DDP ranks, aggregate counts with all-reduce, and report global accuracy.
Interpretation caveats:
- Differences in training data distributions, possible test contamination from large scraped corpora, and limitations of older benchmarks (HellaSwag is largely solved by modern LMs) mean multiple held-out tasks should be used for robust comparison.
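The scoring rule reduces to an argmax over length-normalized log-probabilities. This sketch assumes the per-token log-probabilities of each candidate's completion tokens have already been computed by the model (context tokens excluded); the function names are hypothetical. The example shows why normalization matters: summing would favor the short candidate simply because it has fewer tokens. In a DDP run, each rank would count its correct and total examples, all-reduce both counts, and divide.

```python
def pick_completion(candidate_token_logprobs: list[list[float]]) -> int:
    """Index of the candidate with the highest mean per-token log-probability."""
    means = [sum(lp) / len(lp) for lp in candidate_token_logprobs]
    return max(range(len(means)), key=means.__getitem__)

def accuracy(examples) -> float:
    """examples: list of (candidate_token_logprobs, gold_index)."""
    correct = sum(pick_completion(c) == gold for c, gold in examples)
    return correct / len(examples)

cands = [[-0.6] * 5, [-1.0]]   # long but likely vs. short but less likely per token
print(pick_completion(cands))  # 0 (a plain sum would pick index 1)
```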
Production considerations: logging, checkpoints, alternative C/CUDA implementations and summary
Operational best practices:
- Structured logging of training and validation metrics.
- Periodic checkpointing (model + optimizer states) and reproducible seeds.
Advanced options:
- Lower-level implementations (e.g., a dedicated C/CUDA implementation such as llm.c) can improve startup and per-step throughput but require careful numerical parity checks versus PyTorch prototypes.
Conclusion:
- With modern tooling and hardware, a faithful GPT-2 124M reproduction (implementation, training, and evaluation) is feasible on modest compute budgets.
- Recommended extensions: epoch shuffling, improved dataset handling, and compilation fixes for generation to make the reproduction robust and production-ready.