| # Muon Optimizer: Implementation Guide | |
| This document explains the internal architecture of the Muon optimizer for reviewers and new contributors. It covers the execution paths, the parallel pipeline design, and the distributed sharding utilities. | |
| ## Table of Contents | |
| 1. [Overview](#overview) | |
| 2. [Entry Point and Parameter Routing](#entry-point-and-parameter-routing) | |
| 3. [Execution Paths](#execution-paths) | |
| 4. [Parallel Pipeline (the core feature)](#parallel-pipeline) | |
| 5. [MoE Expert Weight Support](#moe-expert-weight-support-expert_keys) | |
| 6. [Distributed Utilities](#distributed-utilities) | |
| 7. [Newton-Schulz Orthogonalization](#newton-schulz-orthogonalization) | |
| 8. [QK Clipping](#qk-clipping) | |
| 9. [AdamW for Non-Muon Parameters](#adamw-for-non-muon-parameters) | |
| 10. [Source File Map](#source-file-map) | |
| --- | |
| ## Overview | |
| Muon (MomentUm Orthogonalized by Newton-Schulz) applies standard SGD momentum and then replaces each 2D parameter's update with an approximation of the nearest semi-orthogonal matrix, computed by a Newton-Schulz iteration. The iteration runs stably in bfloat16 on GPU. | |
| The optimizer supports arbitrary N-D sharding configurations: FSDP2, TP, or hybrid setups like `2 TP x 2 DP-Replicate x 2 DP-Shard`. This generality is what drives most of the code complexity. | |
| ## Entry Point and Parameter Routing | |
| **File:** `muon.py` — `Muon.step()` / `Muon._step_muon()` | |
| Users must provide parameter groups with `use_muon=True/False` flags (via `get_default_muon_param_groups()`). At each step: | |
| 1. **Non-Muon groups** → `step_adamw()` (fused AdamW). | |
| 2. **Muon groups** → `_step_muon()`, which further classifies each parameter: | |
| ``` | |
| _step_muon(group) | |
| | | |
| +-- momentum update (batched _foreach_* ops) | |
| +-- _expand_expert_params() -- 3D expert params → per-expert 2D views (cached) | |
| | | |
| +-- DTensor, all Replicate placements --> base() (no sharding) | |
| +-- DTensor, sharded --> parallel() (pipelined all-to-all) | |
| +-- plain Tensor --> base() (single device) | |
| ``` | |
| Parameters are classified by their DTensor placements: | |
| - **Fully replicated** DTensors and plain tensors use `base()` — standard single-device Muon. | |
| - **Sharded** DTensors use `parallel()` — the pipelined all-to-all approach described below. | |
| - `distributed_muon()` exists as a **test-only reference implementation** for correctness verification. | |
| ## Execution Paths | |
| ### base() — Single Device | |
| Straightforward per-parameter loop: momentum update → Newton-Schulz orthogonalization → parameter update → optional QK clipping. | |
| ### distributed_muon() — Full Gather (test-only) | |
| Reference implementation for correctness verification. Uses batched all-gather to reconstruct full tensors, computes Newton-Schulz on the full grad, then slices back to local shards. Simple but communication-heavy — not used in production. | |
| ### parallel() — Pipelined All-to-All | |
| This is the path used for sharded parameters and the most involved part of the optimizer. Instead of all-gathering every full parameter on every rank, it uses **all-to-all** to distribute work: each rank "owns" a subset of parameters and is responsible for their Newton-Schulz computation. | |
| ## Parallel Pipeline | |
| ### Design Motivation | |
| Newton-Schulz is compute-intensive. The key insight is that each rank only needs to orthogonalize the parameters it "owns" — not all parameters. So the flow is: | |
| 1. **Gather**: Each rank sends its local gradient shard to the owning rank via all-to-all. | |
| 2. **Compute**: The owning rank runs Newton-Schulz on the full (gathered) gradient. | |
| 3. **Scatter**: The owning rank sends the orthogonalized update back to all ranks via all-to-all. | |
| 4. **Update**: Each rank applies weight decay and the update to its local shard. | |
| To overlap communication and computation, parameters are split into **chunks**, and multiple chunks are processed concurrently. | |
| ### Architecture | |
| ``` | |
| muon.py: parallel() | |
| | | |
| +-- init_state_and_assign_params() -- assigns ownership, precomputes indices | |
| | | |
| +-- pipelines() generator -- yields muon_chunk_pipeline() per chunk | |
| | | |
| +-- run_pipeline(pipelines, max_concurrent=warmup_step+1) | |
| | | |
| +-- interleaves chunks at yield boundaries | |
| ``` | |
| ### The Chunk Pipeline Generator | |
| **File:** `pipeline.py` — `muon_chunk_pipeline()` | |
| Each chunk is a generator that yields **twice**, creating three stages separated by async communication: | |
| ``` | |
| [Build bufs + async gather a2a] --(YIELD 1)--> [wait + NS compute + async scatter a2a] --(YIELD 2)--> [wait + Update params] | |
| ``` | |
| - **Async communication**: `dist.all_to_all_single(..., async_op=True)` launches non-blocking communication. The generator yields immediately after, allowing other chunks to run. `work.wait()` completes the operation after the yield. | |
| - **Chunk-level overlap**: `run_pipeline()` interleaves multiple chunks at yield boundaries, so while chunk N waits for its communication, chunk N+1 can launch its own. | |
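The two-yield structure can be illustrated with a toy generator. This is a minimal sketch, not the code in `pipeline.py`: `FakeWork` and `toy_chunk_pipeline` are hypothetical names, and `FakeWork` only records calls where a real handle from `dist.all_to_all_single(..., async_op=True)` would wait on communication.

```python
from typing import Generator, List


class FakeWork:
    """Hypothetical stand-in for the handle returned by an async collective."""

    def __init__(self, label: str, log: List[str]) -> None:
        self.label = label
        self.log = log
        log.append(f"launch {label}")  # launching returns immediately

    def wait(self) -> None:
        self.log.append(f"wait {self.label}")


def toy_chunk_pipeline(log: List[str]) -> Generator[None, None, None]:
    work = FakeWork("gather", log)   # stage 1: build buffers, launch gather
    yield                            # YIELD 1: other chunks may run here
    work.wait()                      # stage 2: finish gather ...
    log.append("newton-schulz")      # ... compute on the owning rank ...
    work = FakeWork("scatter", log)  # ... and launch scatter
    yield                            # YIELD 2: other chunks may run here
    work.wait()                      # stage 3: finish scatter ...
    log.append("update")             # ... and update the local shard


log: List[str] = []
yields = sum(1 for _ in toy_chunk_pipeline(log))  # drive the generator to completion
```

Each `wait` is recorded only after the generator has been resumed, which is the property the scheduler relies on to overlap chunks.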
| ### The Pipeline Scheduler | |
| **File:** `async_utils.py` — `run_pipeline()` | |
| A simple round-robin scheduler: | |
| ```python | |
| while have_new or previous_tasks: | |
|     # Admit one new pipeline if below concurrency limit | |
|     if have_new and len(previous_tasks) < max_concurrent: | |
|         task = next(pipelines)  # runs to first yield | |
|     # Advance all existing tasks by one yield | |
|     for task in previous_tasks: | |
|         task.step()  # runs to next yield | |
| ``` | |
| `max_concurrent = warmup_step + 1` controls how many chunks can be in-flight simultaneously. Higher values increase memory usage but improve communication/computation overlap. | |
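A runnable version of this loop shows how chunks interleave under a concurrency limit. It is a sketch under assumptions, not the contents of `async_utils.py`: `run_pipeline_sketch`, `_advance`, and `toy_chunk` are hypothetical names, and the bookkeeping (dropping finished tasks, admitting new ones) is one plausible reading of the simplified loop above.

```python
from typing import Generator, Iterable, List

Task = Generator[None, None, None]


def _advance(task: Task) -> bool:
    """Run a task to its next yield; return False once it has finished."""
    try:
        next(task)
        return True
    except StopIteration:
        return False


def run_pipeline_sketch(pipelines: Iterable[Task], max_concurrent: int) -> None:
    pending = iter(pipelines)
    in_flight: List[Task] = []
    have_new = True
    while have_new or in_flight:
        admitted = None
        # Admit at most one new task per round if below the limit.
        if have_new and len(in_flight) < max_concurrent:
            admitted = next(pending, None)
            if admitted is None:
                have_new = False
            elif not _advance(admitted):  # run the new task to its first yield
                admitted = None
        # Advance every previously admitted task by one yield.
        in_flight = [t for t in in_flight if _advance(t)]
        if admitted is not None:
            in_flight.append(admitted)


def toy_chunk(i: int, log: List[str]) -> Task:
    log.append(f"c{i}:gather")
    yield
    log.append(f"c{i}:compute")
    yield
    log.append(f"c{i}:update")


events: List[str] = []
run_pipeline_sketch((toy_chunk(i, events) for i in range(3)), max_concurrent=2)
```

With three chunks and `max_concurrent=2`, the second chunk launches its gather before the first chunk computes, and the third chunk is admitted only after the first has finished.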
| ### Ownership Assignment | |
| **File:** `muon.py` — `init_state_and_assign_params()` | |
| Parameters are sorted by FLOP cost (descending) and assigned to ranks in round-robin order across the shard mesh. This balances compute load across ranks. | |
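The assignment policy can be sketched in a few lines. The cost model here (`m * n * min(m, n)`) and the function names are assumptions for illustration; the source only states that parameters are sorted by FLOP cost and assigned round-robin.

```python
from typing import Dict, Tuple


def estimated_flops(shape: Tuple[int, int]) -> int:
    # Assumed cost model: Newton-Schulz work grows with m * n * min(m, n).
    m, n = shape
    return m * n * min(m, n)


def assign_owners(shapes: Dict[str, Tuple[int, int]], world_size: int) -> Dict[str, int]:
    # Most expensive parameters first, then deal them out round-robin.
    ordered = sorted(shapes, key=lambda name: estimated_flops(shapes[name]), reverse=True)
    return {name: i % world_size for i, name in enumerate(ordered)}


owners = assign_owners(
    {"a": (4, 4), "b": (8, 8), "c": (2, 16), "d": (16, 16)},
    world_size=2,
)
```

Sorting before dealing keeps the largest matrices from landing on the same rank, which is what balances the compute load.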
| ### Precomputed Shard Indices | |
| Instead of computing per-rank shard indices on every step, they are precomputed once during `init_state_and_assign_params()` and stored in `_muon_state`: | |
| ```python | |
| @dataclass | |
| class _muon_state: | |
|     worker_rank: int  # which rank owns this param's computation | |
|     process_group: ProcessGroup  # the all-to-all communication group | |
|     rank_indices: dict[int, tuple]  # rank -> per-dim indices into full tensor | |
|     rank_numels: dict[int, int]  # rank -> number of elements in shard | |
|     name: str | |
|     qk_clip_state: QKClipInfo | None | |
| ``` | |
| `rank_indices[r]` is a tuple of `slice` or `torch.Tensor` per dimension, describing which elements of the full tensor rank `r` owns. `rank_numels[r]` is the total number of elements in that shard. These are used directly in the pipeline's gather and scatter stages. | |
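To make the relationship between the two fields concrete, the sketch below derives `rank_numels` from `rank_indices` for a small tensor. Plain Python lists stand in for `torch.Tensor` index tensors, and `resolve` and `shard_numel` are hypothetical helpers, not functions from the code base.

```python
from typing import Dict, List, Tuple, Union

Index = Union[slice, List[int]]  # a list stands in for a torch index tensor


def resolve(index: Index, size: int) -> List[int]:
    """Expand a per-dimension index into the explicit positions it selects."""
    if isinstance(index, slice):
        return list(range(size))[index]
    return list(index)


def shard_numel(indices: Tuple[Index, ...], shape: Tuple[int, ...]) -> int:
    numel = 1
    for index, size in zip(indices, shape):
        numel *= len(resolve(index, size))
    return numel


shape = (8, 4)
rank_indices: Dict[int, Tuple[Index, ...]] = {
    0: (slice(0, 4), slice(None)),  # contiguous rows 0..3, all columns
    1: ([4, 6], slice(None)),       # strided rows, all columns
    2: ([5, 7], slice(None)),
}
rank_numels = {r: shard_numel(idx, shape) for r, idx in rank_indices.items()}
```

The shard sizes sum to the size of the full tensor, which is what allows them to be used as split sizes for a flat communication buffer.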
| ### Pipeline Stages in Detail | |
| #### Stages 1-2: Gather | |
| 1. **Allocate** receive buffers for gathered gradients (only on owning ranks). | |
| 2. **Build send buffer**: Each rank flattens its local gradient shard for each destination rank. | |
| 3. **Async all-to-all**: `dist.all_to_all_single(..., async_op=True)` launches gather. | |
| 4. **Yield 1**: Other chunks can launch their gather while this one waits. | |
| 5. **`work.wait()`**: Complete the gather. | |
| 6. **Reconstruct**: The owning rank places received shards into the full gradient using `rank_indices`. | |
| #### Stage 3: Compute | |
| The owning rank runs `_zeropower_via_newtonschulz5()` on the full gathered gradient. This is the most compute-intensive stage. It runs inline (no yield) because it is synchronous GPU work rather than an asynchronous collective. | |
| #### Stages 4-5: Scatter | |
| Inverse of gather: | |
| 1. **Allocate** receive buffers for the orthogonalized update `U`. | |
| 2. **Build send buffer**: The owning rank slices `U` using `rank_indices` for each destination rank. | |
| 3. **Async all-to-all**: `dist.all_to_all_single(..., async_op=True)` launches scatter. | |
| 4. **Yield 2**: Other chunks can launch their scatter while this one waits. | |
| 5. **`work.wait()`**: Complete the scatter. | |
| 6. **Copy** received shards into local update buffers. | |
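Because scatter mirrors gather, both can be checked with a single-process round trip. The sketch below is an assumption-laden simulation: dictionaries replace the all-to-all, a sign flip replaces Newton-Schulz, and `rank_rows` plays the role of the dim-0 entry of `rank_indices`.

```python
from typing import Dict, List

Matrix = List[List[float]]

rank_rows: Dict[int, List[int]] = {0: [0, 2], 1: [1, 3]}  # strided row ownership
full_grad: Matrix = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
n_cols = 2

# Gather: every rank flattens its local shard into a send buffer.
send = {r: [x for i in rows for x in full_grad[i]] for r, rows in rank_rows.items()}

# Owner: place each received flat buffer back into the full gradient.
gathered: Matrix = [[0.0] * n_cols for _ in full_grad]
for r, rows in rank_rows.items():
    for k, i in enumerate(rows):
        gathered[i] = send[r][k * n_cols:(k + 1) * n_cols]

# Compute: stand-in for Newton-Schulz on the full matrix.
update: Matrix = [[-x for x in row] for row in gathered]

# Scatter: the owner slices the update with the same indices, per destination rank.
recv = {r: [update[i] for i in rows] for r, rows in rank_rows.items()}
```

Using the same indices in both directions is what guarantees each rank receives exactly the rows it contributed.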
| #### Stage 6: Update | |
| Each rank applies weight decay and the Muon update to its local parameter shard. Also applies QK clipping if configured. | |
| ## MoE Expert Weight Support (`expert_keys`) | |
| **File:** `muon.py` — `_expand_expert_params()` | |
| MoE models have 3D expert weights with shape `(num_experts, out_dim, in_dim)`. Since Muon operates on 2D matrices, expert params need special handling. | |
| ### Configuration | |
| Pass `expert_keys` to both `get_default_muon_param_groups()` and `Muon()`: | |
| ```python | |
| params = get_default_muon_param_groups(model, expert_keys=["experts"]) | |
| optim = Muon(params, expert_keys=["experts"], ...) | |
| ``` | |
| Any parameter whose name contains a string in `expert_keys` is treated as an expert-parallel parameter. Non-matching 3D+ parameters raise `AssertionError` to catch misconfiguration. | |
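The matching rule is a plain substring test. The helper names below are hypothetical and the parameter names are made up for illustration; only the rule itself (substring match, assertion on unmatched 3D+ parameters) comes from the description above.

```python
from typing import Sequence


def is_expert_param(name: str, expert_keys: Sequence[str]) -> bool:
    # A parameter is an expert parameter if any key occurs in its name.
    return any(key in name for key in expert_keys)


def check_param(name: str, ndim: int, expert_keys: Sequence[str]) -> bool:
    expert = is_expert_param(name, expert_keys)
    # 3D+ parameters must be declared as experts, otherwise fail loudly.
    assert ndim <= 2 or expert, f"{name}: {ndim}D parameter not matched by expert_keys"
    return expert


matched = check_param("layers.0.moe.experts.w1", 3, ["experts"])
plain = check_param("layers.0.attn.wq", 2, ["experts"])
```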
| ### How It Works | |
| `_expand_expert_params()` runs after momentum and before routing to `base()`/`parallel()`/`distributed_muon()`: | |
| 1. **Split on dim 0**: A 3D `(E, out, in)` tensor becomes `E` separate 2D `(out, in)` `nn.Parameter` views. Views share storage with the original, so in-place updates propagate back. | |
| 2. **Placement remapping**: When the original is a DTensor, `Shard(k)` on dim `k > 0` becomes `Shard(k-1)` on the 2D slice (since dim 0 is consumed by the split). | |
| 3. **Submesh wrapping**: Non-dim-0 shard placements are preserved by wrapping each 2D slice as a DTensor on the corresponding submesh. This is **placement-agnostic** — the same logic handles TP `Shard(1/2)`, EFSDP `Shard(1)`, or any other non-dim-0 sharding. | |
| ### Placement-Agnostic Design | |
| The expansion logic does not care *why* a dimension is sharded — only whether it's on dim 0 (consumed by split) or not (preserved on submesh): | |
| | Original Placement | After Expansion | | |
| |-------------------|-----------------| | |
| | `Shard(0)` (EP) | Consumed by split → plain tensor | | |
| | `Shard(1)` (TP or EFSDP) | `Shard(0)` on submesh → 2D DTensor | | |
| | `Shard(2)` (TP row-wise) | `Shard(1)` on submesh → 2D DTensor | | |
| | `Replicate` | Ignored (not a shard) | | |
| | `_StridedShard(0)` (EFSDP) | Consumed by split → plain tensor | | |
| After expansion, the 2D params flow through the standard routing: DTensors with shard placements go to `parallel()`, plain tensors go to `base()`. | |
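The remapping rule in the table reduces to a small function. This sketch uses stand-in dataclasses rather than the real `torch.distributed.tensor` placement types, omits `_StridedShard`, and the function name is hypothetical.

```python
from dataclasses import dataclass
from typing import List, Sequence, Union


@dataclass(frozen=True)
class Shard:  # stand-in for the DTensor Shard placement
    dim: int


@dataclass(frozen=True)
class Replicate:  # stand-in for the DTensor Replicate placement
    pass


Placement = Union[Shard, Replicate]


def remap_for_expert_slice(placements: Sequence[Placement]) -> List[Shard]:
    """Placements of one 2D expert slice, given those of the 3D parameter."""
    remapped = []
    for p in placements:
        if isinstance(p, Shard) and p.dim > 0:
            remapped.append(Shard(p.dim - 1))  # dim 0 was consumed by the split
        # Shard(0) is consumed by the split; Replicate is not a shard.
    return remapped


ep_plus_tp = remap_for_expert_slice([Shard(0), Shard(1)])
ep_only = remap_for_expert_slice([Shard(0), Replicate()])
```

An empty result corresponds to the "plain tensor" rows of the table, so that slice is routed to `base()`.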
| For EP/EFSDP background and torchtitan integration details, see [`docs/expert_parallel.md`](expert_parallel.md). | |
| ## Distributed Utilities | |
| **File:** `distributed/utils.py` | |
| These utilities solve the problem of mapping from a DTensor's arbitrary sharding configuration to the concrete indices each rank owns. | |
| ### `construct_shard_mesh(placements, mesh)` | |
| Given a DTensor's placements and device mesh, this function: | |
| 1. **Sorts** placements: Replicate dims first, then Shard dims by dimension (with `_StridedShard` before regular `Shard` on the same dim, so the outer sharding is applied first). | |
| 2. **Permutes** the mesh accordingly. | |
| 3. **Separates** replicate dims from shard dims — each replicate group gets its own shard sub-mesh. | |
| 4. **Creates** a ProcessGroup for the current rank's shard mesh. | |
| Returns `(shard_mesh, process_group, shard_placements)` — used for all-to-all communication. | |
| **Why this is needed:** A model might use `[Replicate, Shard(0), _StridedShard(0)]` across a 3D mesh. The optimizer needs to identify which ranks participate in the same shard group (share the same data) and create a ProcessGroup for them. | |
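The sorting step can be sketched with stand-in placement classes. The class and function names are hypothetical, and the sketch covers only the ordering and the resulting mesh-dimension permutation, not sub-mesh or ProcessGroup creation.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple, Union


@dataclass(frozen=True)
class Replicate:  # stand-in placement types, not the torch classes
    pass


@dataclass(frozen=True)
class Shard:
    dim: int


@dataclass(frozen=True)
class StridedShard:
    dim: int


Placement = Union[Replicate, Shard, StridedShard]


def sort_key(p: Placement) -> Tuple[int, int, int]:
    if isinstance(p, Replicate):
        return (0, 0, 0)  # replicate dims come first
    # Shards are ordered by tensor dim; strided precedes regular on the same dim.
    return (1, p.dim, 0 if isinstance(p, StridedShard) else 1)


def sorted_mesh_dims(placements: Sequence[Placement]) -> List[int]:
    """Permutation of mesh dimensions that puts the placements in sorted order."""
    return sorted(range(len(placements)), key=lambda i: sort_key(placements[i]))


perm = sorted_mesh_dims([Replicate(), Shard(0), StridedShard(0)])
```

For the 3D-mesh example above, the strided placement on mesh dimension 2 moves ahead of the regular shard on mesh dimension 1.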
| ### `get_slices_of_dtensor(target, local_rank, shard_mesh, shard_placements)` | |
| Computes the exact indices that a given rank owns in the full tensor. Handles both contiguous (`Shard`) and strided (`_StridedShard`) sharding, including composed multi-level sharding on the same dimension. | |
| Returns a tuple of `slice` (contiguous) or `torch.LongTensor` (strided) per dimension. | |
| **Example:** With `[Shard(0), _StridedShard(0)]` on a (16, 2048) tensor across 4 ranks: | |
| - Rank 0 might own rows `[0, 4, 8, 12]` (strided) | |
| - Rank 1 might own rows `[1, 5, 9, 13]` | |
| - etc. | |
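The difference between the two return types can be seen with single-level index arithmetic. This sketch does not reproduce `get_slices_of_dtensor()` or the composed multi-level case; it assumes the row count divides evenly by the number of ranks, and both helper names are hypothetical.

```python
from typing import List


def contiguous_rows(rank: int, n_rows: int, world: int) -> slice:
    # Contiguous sharding: each rank owns one block, expressible as a slice.
    per_rank = n_rows // world
    return slice(rank * per_rank, (rank + 1) * per_rank)


def strided_rows(rank: int, n_rows: int, world: int) -> List[int]:
    # Interleaved ownership: rows are not adjacent, so explicit indices are needed.
    return list(range(rank, n_rows, world))


rank0 = strided_rows(0, 16, 4)
rank1 = strided_rows(1, 16, 4)
block1 = contiguous_rows(1, 16, 4)
```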
| ### PyTorch 2.10 Compatibility | |
| In PyTorch 2.10, `_StridedShard` no longer inherits from `Shard`. The helper `_is_shard()` handles both old and new hierarchies: | |
| ```python | |
| def _is_shard(placement): | |
| return isinstance(placement, (Shard, _StridedShard)) | |
| ``` | |
| ## Newton-Schulz Orthogonalization | |
| **File:** `newton_schulz.py` | |
| `_zeropower_via_newtonschulz5()` computes the polar factor of a matrix using the Polar Express method: quintic Newton-Schulz iterations with analytically optimal (minimax/Remez) coefficients precomputed by `_optimal_composition()`. The default configuration uses 10 iterations with `l=1e-3`, driving all singular values to 1 (up to numerical precision) so that the result is the polar factor `UV^T`. It is wrapped by `zeropower_via_newtonschulz5()`, which adds per-shape `torch.compile` caching with CUDA graph support. | |
| Each iteration uses `matmul_transpose_assign()` (a Triton kernel for `X @ X^T`) for efficiency. | |
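For readers new to the iteration, here is a pure-Python sketch of a quintic Newton-Schulz update `X <- aX + (bA + cA^2)X` with `A = X X^T`. It deliberately swaps in the textbook fixed coefficients `(15/8, -10/8, 3/8)` rather than the Polar Express coefficients used by `newton_schulz.py`, and it assumes a small, full-rank matrix with no more rows than columns.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def newton_schulz_quintic(g: Matrix, steps: int = 12) -> Matrix:
    a, b, c = 15 / 8, -10 / 8, 3 / 8  # textbook coefficients, not Polar Express
    # Normalize by the Frobenius norm so every singular value is at most 1.
    norm = math.sqrt(sum(v * v for row in g for v in row))
    x = [[v / norm for v in row] for row in g]
    for _ in range(steps):
        gram = matmul(x, transpose(x))  # A = X X^T
        gram2 = matmul(gram, gram)      # A^2
        poly = [[b * p + c * q for p, q in zip(r1, r2)] for r1, r2 in zip(gram, gram2)]
        px = matmul(poly, x)            # (bA + cA^2) X
        x = [[a * v + w for v, w in zip(r1, r2)] for r1, r2 in zip(x, px)]
    return x


g = [[2.0, 1.0, 0.0], [0.0, 1.0, 3.0]]
u = newton_schulz_quintic(g)
uut = matmul(u, transpose(u))  # should be close to the 2x2 identity
```

Each step maps a singular value `s` to `a*s + b*s^3 + c*s^5`, which has a fixed point at 1; tuned coefficients aim to reach that fixed point in fewer steps.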
| **File:** `matmul_transpose_triton.py` | |
| The `matmul_transpose_assign(d_in, d_out)` kernel computes `d_in @ d_in^T` and writes the result into the preallocated `d_out`. It exploits symmetry by computing only upper-triangle blocks and mirroring. | |
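The symmetry trick is easy to see at the element level. The sketch below is plain Python rather than Triton and works on single entries instead of blocks; `gram_upper_mirror` is a hypothetical name.

```python
from typing import List

Matrix = List[List[float]]


def gram_upper_mirror(x: Matrix) -> Matrix:
    """Compute x @ x^T by evaluating only entries with j >= i and mirroring."""
    m = len(x)
    out = [[0.0] * m for _ in range(m)]
    for i in range(m):
        for j in range(i, m):  # upper triangle, including the diagonal
            value = sum(p * q for p, q in zip(x[i], x[j]))
            out[i][j] = value
            out[j][i] = value  # mirror into the lower triangle
    return out


gram = gram_upper_mirror([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
```

Only `m * (m + 1) / 2` dot products are evaluated instead of `m * m`, roughly halving the arithmetic for large `m`.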
| ## QK Clipping | |
| **File:** `qk_clip.py` | |
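| Optional dynamic clipping for attention head projections (Q and K weight matrices). When the maximum QK logit for a head exceeds a threshold, the corresponding rows of the weight matrix are scaled down by `sqrt(threshold / logit)`. When this factor is applied to both the Q and K projections of a head, their product (the logit) shrinks by `threshold / logit`, which brings it back to the threshold. | |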
| **In the parallel pipeline:** QK clipping is applied per-row using each row's global head index. This correctly handles strided sharding where local rows may be interleaved across multiple heads: | |
| ```python | |
| # pipeline.py: _update_params() | |
| ratio = p.shape[0] // scales_full.shape[0] # rows per head | |
| idx0 = state.rank_indices[rank][0] # which global rows this rank owns | |
| row_scales = scales_full[idx0 // ratio] # map each row to its head's scale | |
| p._local_tensor.mul_(row_scales.view(-1, 1)) | |
| ``` | |
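The index arithmetic in that snippet can be checked without tensors. In this sketch, Python lists replace tensors, the function names are hypothetical, and heads at or below the threshold are given a scale of 1.0 since clipping applies only when the threshold is exceeded.

```python
import math
from typing import List


def head_scales(max_logits: List[float], threshold: float) -> List[float]:
    # Scale only heads whose maximum logit exceeds the threshold.
    return [math.sqrt(threshold / l) if l > threshold else 1.0 for l in max_logits]


def row_scales_for_rank(scales_full: List[float], owned_rows: List[int], total_rows: int) -> List[float]:
    ratio = total_rows // len(scales_full)  # rows per head
    # Map each owned global row to the scale of the head it belongs to.
    return [scales_full[row // ratio] for row in owned_rows]


scales_full = head_scales([400.0, 50.0], threshold=100.0)
row_scales = row_scales_for_rank(scales_full, owned_rows=[1, 3, 5, 7], total_rows=8)
```

The rank owns interleaved rows from both heads, yet each row picks up the correct scale because the lookup uses global row indices rather than local positions.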
| ## AdamW for Non-Muon Parameters | |
| **File:** `adamw.py` | |
| Parameters not eligible for Muon (1D parameters, embeddings, LM head) are optimized with fused AdamW via `torch._fused_adamw_`. Parameters are grouped by device/dtype and DTensor placement before the fused call. | |
| ## Source File Map | |
| | File | Lines | Purpose | | |
| |------|-------|---------| | |
| | `muon.py` | ~815 | Optimizer class, parameter routing, 3 execution paths, MoE expert expansion + caching | | |
| | `pipeline.py` | ~400 | Generator-based parallel pipeline (gather/compute/scatter/update) | | |
| | `async_utils.py` | ~75 | Pipeline scheduler with bounded concurrency | | |
| | `core.py` | ~175 | `_muon_state` dataclass, batched momentum/update helpers, param grouping | | |
| | `distributed/utils.py` | ~230 | Shard mesh construction, DTensor index computation | | |
| | `newton_schulz.py` | ~190 | Polar Express coefficients, Newton-Schulz iteration + compile/CUDA graph | | |
| | `matmul_transpose_triton.py` | ~130 | Triton kernel for symmetric matmul | | |
| | `qk_clip.py` | ~135 | QK logit clipping | | |
| | `adamw.py` | ~170 | Fused AdamW for non-Muon params | | |
| ### Dependency Graph | |
| ``` | |
| matmul_transpose_triton.py (leaf) | |
| | | |
| newton_schulz.py (leaf + triton) | |
| | | |
| core.py ---- qk_clip.py (leaf, distributed/utils) | |
| | | | | |
| | pipeline.py --- async_utils.py | |
| | | | |
| | adamw.py | |
| | | | |
| muon.py (all above) | |
| | | |
| __init__.py | |
| ``` | |