Grokking Modular Addition Transformer

A 1-layer transformer trained on modular addition, (a + b) mod 113, that exhibits grokking: the model first memorizes the training data, then, after continued training, suddenly generalizes to the test set.

This model is a reproduction of the setup from Progress Measures for Grokking via Mechanistic Interpretability (Nanda et al., 2023), built with TransformerLens.

Model Description

The model learns a Fourier-based algorithm for modular addition, where w_k = 2πk/p for a key frequency k and p = 113:

  1. The embedding maps inputs a and b to Fourier components (sin/cos at a few key frequencies)
  2. Attention from the = position to a and b brings sin(w_k a), cos(w_k a), sin(w_k b), cos(w_k b) into the = position
  3. MLP neurons compute cos(w_k(a + b)) and sin(w_k(a + b)) via trigonometric identities
  4. The unembedding maps these to logits approximating cos(w_k(a + b - c)) for each candidate output c, which peaks at c = (a + b) mod p
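The four steps can be traced with a hand-written, idealized version of the circuit. This is a minimal plain-Python sketch, not code extracted from the trained weights. It assumes equal weight on every frequency and borrows the key frequencies reported under Mechanistic Interpretability Findings. KEY_FREQS and idealized_logits are hypothetical names.

```python
import math

p = 113
# Hypothetical example values: the key frequencies reported for this checkpoint.
# A trained model weights each frequency differently; here they are equal.
KEY_FREQS = [9, 33, 36, 38, 55]


def idealized_logits(a: int, b: int, freqs=KEY_FREQS) -> list[float]:
    """Idealized logits sum_k cos(w_k (a + b - c)) for every candidate c, with w_k = 2*pi*k/p."""
    features = []
    for k in freqs:
        w = 2 * math.pi * k / p
        # Steps 1-2: Fourier features of each input
        cos_a, sin_a = math.cos(w * a), math.sin(w * a)
        cos_b, sin_b = math.cos(w * b), math.sin(w * b)
        # Step 3: angle-addition identities give cos/sin of w(a + b)
        cos_ab = cos_a * cos_b - sin_a * sin_b
        sin_ab = sin_a * cos_b + cos_a * sin_b
        features.append((w, cos_ab, sin_ab))

    logits = []
    for c in range(p):
        # Step 4: cos(w(a + b - c)) = cos(w(a + b))cos(wc) + sin(w(a + b))sin(wc)
        logits.append(sum(cab * math.cos(w * c) + sab * math.sin(w * c) for w, cab, sab in features))
    return logits


logits = idealized_logits(37, 58)
print(max(range(p), key=logits.__getitem__))  # 95
```

Because p is prime, k(a + b - c) is a multiple of p only when c = (a + b) mod p. Every cosine term therefore peaks at the correct answer, so summing several frequencies sharpens the peak rather than cancelling it.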

Architecture

Parameter            Value
Layers               1
Attention heads      4
d_model              128
d_head               32
d_mlp                512
Activation           ReLU
Normalization        None
Vocabulary (input)   114 (0-112 for numbers, 113 for =)
Vocabulary (output)  113
Context length       3 tokens: [a, b, =]
Parameters           ~226K (weight matrices only)

Design choices (no LayerNorm, ReLU, no biases) were made to simplify mechanistic interpretability analysis.
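The weight count implied by the architecture table can be checked by hand. This sketch counts weight matrices only, assuming, as the note above says, that biases are not used. TransformerLens may still allocate zero-valued bias tensors, so `sum(t.numel() for t in model.parameters())` on a loaded model can come out slightly higher.

```python
# Values taken from the architecture table
d_vocab_in, d_vocab_out, n_ctx = 114, 113, 3
d_model, n_heads, d_head, d_mlp = 128, 4, 32, 512

embed = d_vocab_in * d_model            # W_E: token embedding
pos_embed = n_ctx * d_model             # W_pos: positional embedding
attn = 4 * n_heads * d_model * d_head   # W_Q, W_K, W_V, W_O
mlp = 2 * d_model * d_mlp               # W_in, W_out
unembed = d_model * d_vocab_out         # W_U: unembedding

total = embed + pos_embed + attn + mlp + unembed
print(total)  # 226048
```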

Usage

Loading the checkpoint

import torch
from transformer_lens import HookedTransformer

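# Load a local copy of grokking_demo.pth (download it from the model repository first).
# weights_only=False is required because the file also stores a pickled config; only load trusted files.
cached_data = torch.load("grokking_demo.pth", map_location="cpu", weights_only=False)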

model = HookedTransformer(cached_data["config"])
model.load_state_dict(cached_data["model"])

# Training history is also included
model_checkpoints = cached_data["checkpoints"]       # 250 intermediate checkpoints
checkpoint_epochs = cached_data["checkpoint_epochs"]  # Every 100 epochs
train_losses = cached_data["train_losses"]
test_losses = cached_data["test_losses"]
train_indices = cached_data["train_indices"]
test_indices = cached_data["test_indices"]

Running inference

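import torch

p = 113
a, b = 37, 58
# Token p (113) is the "=" token; create the input on the model's device
input_tokens = torch.tensor([[a, b, p]], device=model.cfg.device)  # [a, b, =]
with torch.no_grad():
    logits = model(input_tokens)  # shape [batch, position, d_vocab_out]
prediction = logits[0, -1].argmax().item()  # prediction at the "=" position
print(f"{a} + {b} mod {p} = {prediction}")  # Should print 95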

Installation

pip install torch transformer-lens

Training Details

Setting        Value
Task           (a + b) mod 113
Total data     113^2 = 12,769 pairs
Train split    30% (3,830 examples)
Test split     70% (8,939 examples)
Optimizer      AdamW
Learning rate  1e-3
Weight decay   1.0
Betas          (0.9, 0.98)
Epochs         25,000
Batch size     Full batch
Checkpoints    Every 100 epochs (250 total)
Seed           999 (model), 598 (data split)
Training time  ~2 minutes on GPU
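The dataset and split sizes in the table follow from enumerating every (a, b) pair. This is a minimal sketch assuming the [a, b, =] sequence format with = encoded as token 113. Python's random module is seeded with 598 only for illustration: it is not the RNG used in training, so it will not reproduce the checkpoint's train_indices and test_indices.

```python
import random

p = 113
frac_train = 0.3

# Every (a, b) pair as a token sequence [a, b, =], with "=" encoded as token p
dataset = [(a, b, p) for a in range(p) for b in range(p)]
labels = [(a + b) % p for a, b, _ in dataset]

# Illustrative shuffle; does NOT match the original training code's split
rng = random.Random(598)
indices = list(range(len(dataset)))
rng.shuffle(indices)

cutoff = int(frac_train * len(dataset))  # int(0.3 * 12769) = 3830
train_indices, test_indices = indices[:cutoff], indices[cutoff:]
print(len(train_indices), len(test_indices))  # 3830 8939
```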

The high weight decay (1.0) is critical for grokking: it gradually erodes the memorizing weights in favor of the compact, generalizing Fourier circuit.
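To see how strongly this setting pulls on weights that the loss does not actively support, consider PyTorch's AdamW. Its decoupled decay multiplies every parameter by (1 - lr * weight_decay) on each step, and full-batch training means one step per epoch. This sketch isolates that decay term for a weight whose gradient is zero. It deliberately ignores the gradient-driven part of the update, and decay_only_factor is a hypothetical helper name.

```python
lr, weight_decay = 1e-3, 1.0  # optimizer settings from the training table


def decay_only_factor(steps: int) -> float:
    """Fraction of a weight remaining after `steps` AdamW updates if its gradient were zero."""
    # AdamW's decoupled weight decay: param <- param * (1 - lr * weight_decay) each step
    return (1 - lr * weight_decay) ** steps


for steps in (1_000, 10_000, 25_000):
    print(steps, f"{decay_only_factor(steps):.3g}")
```

The factor is roughly e^(-1) after 1,000 steps, so an unsupported weight shrinks by many orders of magnitude over a 25,000-epoch run.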

Grokking Phases

The training exhibits three distinct phases:

  1. Memorization (~epoch 0-1,500): Train loss drops to ~0, test loss stays at ~4.73 (random guessing over 113 classes). The model memorizes all training examples.
  2. Circuit Formation (~epoch 1,500-13,300): The Fourier-based generalizing circuit gradually forms in the weights, but memorization still dominates. Test loss appears unchanged.
  3. Cleanup (~epoch 13,300-16,600): Weight decay erodes the memorizing weights faster than the compact Fourier circuit. Test loss suddenly drops: this is the grokking moment.
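The ~4.73 figure in phase 1 is the cross-entropy of a uniform guess over 113 classes, ln 113. The sketch also includes a small hypothetical helper, first_epoch_below, for locating the grokking moment in the saved loss history. It assumes test_losses holds one value per epoch, and any threshold you pass is an arbitrary choice.

```python
import math

p = 113
# Cross-entropy of a uniform prediction over p classes is -ln(1/p) = ln(p)
random_guess_loss = math.log(p)
print(round(random_guess_loss, 2))  # 4.73


def first_epoch_below(losses, threshold):
    """Return the index (epoch) of the first loss below `threshold`, or None if none is."""
    for epoch, loss in enumerate(losses):
        if loss < threshold:
            return epoch
    return None
```

For example, first_epoch_below(test_losses, 0.1) would return the first epoch at which test loss falls below 0.1, if the history is stored per epoch.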

Mechanistic Interpretability Findings

Analysis of the trained model reveals:

  • Fourier-sparse embeddings: The model learns embeddings concentrated on key frequencies (k = 9, 33, 36, 38, 55)
  • Neuron clustering: ~85% of MLP neurons are well-explained by a single Fourier frequency
  • Logit periodicity: Output logits approximate cos(2πk(a + b - c)/p) for each key frequency k
  • Progress measures: Restricted loss (keeping only the key-frequency components of the logits) and excluded loss (removing only those components, measured on the training set) track circuit formation and cleanup separately. Together they show that grokking is not a sudden phase transition but the delayed visibility of a gradually forming algorithm
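The Fourier-sparsity claim can be checked by projecting the number-token embeddings onto a real Fourier basis over Z_p and measuring how much squared norm falls in each frequency. This NumPy sketch shows one way to do that; it is not the analysis code from the repository. fourier_basis and frequency_norms are hypothetical helper names, and NumPy is assumed to be installed (TransformerLens depends on it).

```python
import numpy as np


def fourier_basis(p: int) -> np.ndarray:
    """Orthonormal real Fourier basis over Z_p (p odd): a constant row, then cos/sin rows for k = 1..(p-1)//2."""
    x = np.arange(p)
    rows = [np.ones(p)]
    for k in range(1, (p - 1) // 2 + 1):
        rows.append(np.cos(2 * np.pi * k * x / p))
        rows.append(np.sin(2 * np.pi * k * x / p))
    basis = np.stack(rows)
    return basis / np.linalg.norm(basis, axis=1, keepdims=True)


def frequency_norms(W_E: np.ndarray) -> np.ndarray:
    """Squared embedding norm per frequency; entry k-1 corresponds to frequency k."""
    p = W_E.shape[0]
    coeffs = fourier_basis(p) @ W_E     # (p, d_model): embedding expressed in the Fourier basis
    sq = (coeffs[1:] ** 2).sum(axis=1)  # drop the constant row
    return sq[0::2] + sq[1::2]          # combine the cos and sin rows of each frequency
```

With the loaded model, frequency_norms(model.W_E[:113].detach().cpu().numpy()) excludes the = token's row. If the reported findings hold, the largest entries should sit at indices k - 1 for the key frequencies.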

Source Code

Full analysis notebook and training code: GitHub repository

References

Papers for BurnyCoder/grokking-modular-addition-transformer