# Grokking Modular Addition Transformer
A 1-layer transformer trained on modular addition (a + b) mod 113 that exhibits grokking -- the phenomenon where the model first memorizes the training data, then suddenly generalizes to the test set after continued training.
This model is a reproduction of the setup from Progress Measures for Grokking via Mechanistic Interpretability (Nanda et al., 2023), built with TransformerLens.
## Model Description
The model learns a Fourier-based algorithm to perform modular addition:
- Embed inputs `a` and `b` into Fourier components (sin/cos at key frequencies)
- Attend from the `=` position to `a` and `b`, computing `sin(ka)`, `cos(ka)`, `sin(kb)`, `cos(kb)`
- MLP neurons compute `cos(k(a+b))` and `sin(k(a+b))` via trigonometric identities
- The unembedding maps these to logits approximating `cos(k(a+b-c))` for each candidate output `c`
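The effect of this algorithm can be sketched numerically: a sum of `cos(w_k * (a + b - c))` terms over a few frequencies is maximized exactly at `c = (a + b) mod p`. A minimal NumPy illustration (the frequencies below are arbitrary stand-ins, not the model's learned ones):

```python
import numpy as np

p = 113
freqs = [9, 33, 36]  # illustrative; the trained model selects its own key frequencies
a, b = 37, 58

# For each candidate output c, sum cos(2*pi*k*(a+b-c)/p) over the frequencies.
# Every term equals 1 exactly when c == (a + b) mod p, so the sum peaks there.
c = np.arange(p)
logits = sum(np.cos(2 * np.pi * k * (a + b - c) / p) for k in freqs)

prediction = int(logits.argmax())
assert prediction == (a + b) % p  # 95
```

Because `p` is prime, no other `c` can make all cosine terms equal 1 simultaneously, so the argmax is unique.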
### Architecture
| Parameter | Value |
|---|---|
| Layers | 1 |
| Attention heads | 4 |
| d_model | 128 |
| d_head | 32 |
| d_mlp | 512 |
| Activation | ReLU |
| Normalization | None |
| Vocabulary (input) | 114 (0-112 for numbers, 113 for =) |
| Vocabulary (output) | 113 |
| Context length | 3 tokens: [a, b, =] |
| Parameters | ~226K |
Design choices (no LayerNorm, ReLU, no biases) were made to simplify mechanistic interpretability analysis.
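For reference, the table above corresponds roughly to the following `HookedTransformerConfig` (field names per TransformerLens). Treat this as a sketch: the authoritative config ships inside the checkpoint as `cached_data["config"]`.

```python
from transformer_lens import HookedTransformer, HookedTransformerConfig

# Approximate reconstruction of the architecture table; prefer the config
# stored in the checkpoint (cached_data["config"]) for exact reproduction.
cfg = HookedTransformerConfig(
    n_layers=1,
    n_heads=4,
    d_model=128,
    d_head=32,
    d_mlp=512,
    act_fn="relu",
    normalization_type=None,  # no LayerNorm
    d_vocab=114,              # 0-112 plus the '=' token
    d_vocab_out=113,
    n_ctx=3,                  # [a, b, =]
)
model = HookedTransformer(cfg)
```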
## Usage

### Loading the checkpoint
```python
import torch
from transformer_lens import HookedTransformer

# Download and load the checkpoint
cached_data = torch.load("grokking_demo.pth", weights_only=False)
model = HookedTransformer(cached_data["config"])
model.load_state_dict(cached_data["model"])

# Training history is also included
model_checkpoints = cached_data["checkpoints"]        # 250 intermediate checkpoints
checkpoint_epochs = cached_data["checkpoint_epochs"]  # saved every 100 epochs
train_losses = cached_data["train_losses"]
test_losses = cached_data["test_losses"]
train_indices = cached_data["train_indices"]
test_indices = cached_data["test_indices"]
```
### Running inference
```python
import torch

p = 113
a, b = 37, 58
input_tokens = torch.tensor([[a, b, p]])  # [a, b, =] -- token 113 is '='
logits = model(input_tokens)
prediction = logits[0, -1].argmax().item()
print(f"{a} + {b} mod {p} = {prediction}")  # should print 95
```
## Installation

```bash
pip install torch transformer-lens
```
## Training Details
| Setting | Value |
|---|---|
| Task | (a + b) mod 113 |
| Total data | 113^2 = 12,769 pairs |
| Train split | 30% (3,830 examples) |
| Test split | 70% (8,939 examples) |
| Optimizer | AdamW |
| Learning rate | 1e-3 |
| Weight decay | 1.0 |
| Betas | (0.9, 0.98) |
| Epochs | 25,000 |
| Batch size | Full batch |
| Checkpoints | Every 100 epochs (250 total) |
| Seed | 999 (model), 598 (data split) |
| Training time | ~2 minutes on GPU |
The high weight decay (1.0) is critical for grokking -- it gradually erodes memorization weights in favor of the compact generalizing Fourier circuit.
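Concretely, the training settings in the table translate into an AdamW setup like the skeleton below. The tiny linear model and random data are stand-ins so the code runs; all names are illustrative rather than the repo's actual ones.

```python
import torch
import torch.nn.functional as F

# Stand-in model and data so the skeleton executes; in practice, substitute
# the real HookedTransformer and the [a, b, =] token / label tensors.
model = torch.nn.Linear(2, 113)
train_x = torch.randn(8, 2)
train_y = torch.randint(0, 113, (8,))

# Hyperparameters from the table: AdamW, lr 1e-3, weight decay 1.0, betas (0.9, 0.98).
optimizer = torch.optim.AdamW(
    model.parameters(), lr=1e-3, weight_decay=1.0, betas=(0.9, 0.98)
)

for epoch in range(10):  # the real run uses 25,000 epochs
    loss = F.cross_entropy(model(train_x), train_y)  # full batch: no DataLoader
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```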
## Grokking Phases

Training exhibits three distinct phases:

1. **Memorization** (~epoch 0-1,500): Train loss drops to ~0 while test loss stays at ~4.73 (ln 113, i.e. random guessing over 113 classes). The model memorizes the training examples.
2. **Circuit formation** (~epoch 1,500-13,300): The Fourier-based generalizing circuit gradually forms in the weights, but memorization still dominates, so test loss appears unchanged.
3. **Cleanup** (~epoch 13,300-16,600): Weight decay erodes the memorization weights faster than the compact Fourier circuit. Test loss suddenly drops -- this is the grokking moment.
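Given the saved `test_losses`, the grokking moment can be located programmatically as the first epoch where test loss falls well below the ln(113) ≈ 4.73 random-guess baseline. A sketch using a synthetic stand-in curve (with the real checkpoint, pass `cached_data["test_losses"]` instead):

```python
import math

def grokking_epoch(test_losses, threshold=1.0):
    """Return the index of the first epoch with test loss below `threshold`,
    or None if the model never groks."""
    for epoch, loss in enumerate(test_losses):
        if loss < threshold:
            return epoch
    return None

# Synthetic stand-in: flat at the random-guess baseline, then a sudden drop.
baseline = math.log(113)  # ~4.73
losses = [baseline] * 14_000 + [0.05] * 2_000
print(grokking_epoch(losses))  # 14000 on this synthetic curve
```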
## Mechanistic Interpretability Findings
Analysis of the trained model reveals:
- **Fourier-sparse embeddings**: The model learns embeddings concentrated on key frequencies (k = 9, 33, 36, 38, 55)
- **Neuron clustering**: ~85% of MLP neurons are well explained by a single Fourier frequency
- **Logit periodicity**: Output logits approximate `cos(freq * 2*pi/p * (a + b - c))` for the key frequencies
- **Progress measures**: Restricted loss and excluded loss track the formation and cleanup of circuits independently, revealing that grokking is not a sudden phase transition but the delayed visibility of a gradually forming algorithm
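The Fourier-sparsity claim can be checked with a discrete Fourier transform over the 113 input tokens. The self-contained sketch below builds a synthetic embedding from two frequencies and shows the DFT power concentrating on exactly those bins; on the real model, the same transform would be applied to the rows of `model.W_E`.

```python
import numpy as np

p = 113
rng = np.random.default_rng(0)

# Synthetic "embedding": columns are cos/sin at two chosen frequencies,
# standing in for model.W_E (the real model uses ~5 key frequencies).
ks = [9, 33]
x = np.arange(p)
cols = []
for k in ks:
    cols += [np.cos(2 * np.pi * k * x / p), np.sin(2 * np.pi * k * x / p)]
W_E = np.stack(cols, axis=1) @ rng.standard_normal((4, 4))  # mix the directions

# DFT along the token axis; power should concentrate on the constructed bins.
power = np.abs(np.fft.fft(W_E, axis=0)) ** 2
per_freq = power.sum(axis=1)[: p // 2 + 1]  # fold to non-negative frequencies

top = set(np.argsort(per_freq)[-2:])
assert top == {9, 33}  # the DFT recovers exactly the planted frequencies
```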
## Source Code
Full analysis notebook and training code: GitHub repository
## References
- Progress Measures for Grokking via Mechanistic Interpretability (Nanda et al., 2023)
- Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets (Power et al., 2022)
- TransformerLens