ZeRO: A Data-Parallelism Method for Training Large Models

GPT-3 has 175 billion parameters. An NVIDIA A100 GPU, one of the most powerful accelerators of its generation, has at most 80 GB of memory. GPT-3's weights alone take 350 GB in fp16, so you would need at least five A100s just to store them before training even begins.

The obvious fix is data parallelism: buy more GPUs, split the training data across them, and average the gradients at the end of each step. More GPUs, more throughput. But there is a catch.

More GPUs Don't Automatically Mean Less Memory

In standard data-parallel training (for example, PyTorch's default DistributedDataParallel module, or DDP), every GPU holds a complete copy of the model: weights, gradients, and optimizer states. Adding GPUs therefore leaves the memory needed per GPU unchanged.

[Figure: per-GPU memory vs. number of GPUs under DDP. The dashed line shows ideal 1/N scaling; the red DDP line stays flat because every GPU still holds the full model.]

In traditional data parallelism, adding GPUs scales your compute, but the largest model you can train is still capped by what a single GPU can hold.

What's Actually Inside That Memory?

Mixed-precision training with the Adam optimizer stores three things per parameter, and one of them is enormous:

  • Optimizer states: 12 bytes/param. Adam keeps a full-precision (fp32) copy of every weight (4 bytes) plus fp32 momentum and variance tensors (4 bytes each). That is 75% of the model-state memory, and none of it needs to be replicated.
  • Gradients: 2 bytes/param. Computed in fp16 during backprop and averaged across GPUs before each optimizer step.
  • Parameters: 2 bytes/param. The model weights in fp16.

Together: 16 bytes per parameter on every GPU, before counting activations. For a 7B-parameter model that is 112 GB, already more than one A100 can hold.
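The 16-byte total can be tallied directly. A small sketch, where the dictionary keys are illustrative labels and 1 GB is taken as 10^9 bytes; only model states are counted, not activations:

```python
# Bytes of model state per parameter under mixed-precision Adam,
# following the accounting above (activations are not included).
BYTES_PER_PARAM = {
    "fp16 parameters": 2,
    "fp16 gradients": 2,
    "fp32 master weights": 4,
    "fp32 Adam momentum": 4,
    "fp32 Adam variance": 4,
}


def model_state_gb(num_params: float) -> float:
    """Per-GPU model-state memory in GB (1 GB = 1e9 bytes) under plain DDP."""
    return num_params * sum(BYTES_PER_PARAM.values()) / 1e9


if __name__ == "__main__":
    for name, size in [("7B", 7e9), ("GPT-3 175B", 175e9)]:
        print(f"{name}: {model_state_gb(size):,.0f} GB per GPU")
```

The three fp32 entries (12 of the 16 bytes) are what the ZeRO paper groups together as optimizer state, which is where the 75% figure comes from.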

ZeRO Eliminates Redundancy in Three Stages

ZeRO (Zero Redundancy Optimizer) partitions these memory buckets across all N GPUs so that each GPU owns only 1/N of each partitioned bucket. Its three stages partition progressively more of them: optimizer states, then gradients, then the parameters themselves.

[Figure: per-GPU memory bar split into optimizer states (12 bytes/param), gradients (2 bytes/param), and parameters (2 bytes/param), shown for each ZeRO stage.]
Stage 0 (plain DDP) is the baseline: every GPU holds a full copy of the weights, gradients, and optimizer states. With 64 GPUs you have 64 identical copies of the same data. Pure redundancy.

Stage 1: Split the Biggest Boxes

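The optimizer states are the biggest offender, and each slice of them is only ever used by one GPU. After the gradient allreduce, GPU 0 uses the averaged gradient to update its 1/N slice of the optimizer states and of the fp32 master weights; GPU 1 updates its own slice, and so on. No GPU ever needs another GPU's optimizer slice, so why replicate it? The only addition is an allgather of the updated fp16 parameters at the end of the step, so every GPU starts the next step with the full model.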
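A minimal sketch of the ZeRO-1 update, assuming a toy flat parameter list split evenly across ranks, a Python loop standing in for GPUs that would run in parallel, and mixed precision omitted. `adam_update` and `zero1_step` are hypothetical helpers, not DeepSpeed's API. Each rank stores Adam state only for its own slice; concatenating the updated slices plays the role of the allgather.

```python
import math

# Standard Adam hyperparameters (illustrative values).
LR, BETA1, BETA2, EPS = 1e-3, 0.9, 0.999, 1e-8


def adam_update(params, grads, m, v, step):
    """Plain Adam on Python lists; returns new (params, m, v)."""
    new_p, new_m, new_v = [], [], []
    for p, g, mi, vi in zip(params, grads, m, v):
        mi = BETA1 * mi + (1 - BETA1) * g
        vi = BETA2 * vi + (1 - BETA2) * g * g
        m_hat = mi / (1 - BETA1 ** step)
        v_hat = vi / (1 - BETA2 ** step)
        new_p.append(p - LR * m_hat / (math.sqrt(v_hat) + EPS))
        new_m.append(mi)
        new_v.append(vi)
    return new_p, new_m, new_v


def zero1_step(params, avg_grads, shard_states, world_size, step):
    """One toy ZeRO-1 step: each rank updates only its slice, then allgather.

    shard_states[r] holds (m, v) for rank r's slice only, so each rank
    stores 1/world_size of the optimizer state. Mutates shard_states.
    """
    shard = len(params) // world_size  # assumes an even split
    updated_slices = []
    for r in range(world_size):  # one iteration per (simulated) GPU
        lo, hi = r * shard, (r + 1) * shard
        m, v = shard_states[r]
        p_r, m_r, v_r = adam_update(params[lo:hi], avg_grads[lo:hi], m, v, step)
        shard_states[r] = (m_r, v_r)
        updated_slices.append(p_r)
    # Allgather: every rank receives every updated slice -> full parameter list.
    return [x for s in updated_slices for x in s]


if __name__ == "__main__":
    world_size = 4
    params = [0.1 * i for i in range(8)]
    grads = [0.01 * (i - 4) for i in range(8)]  # already averaged across GPUs
    shard = len(params) // world_size
    states = [([0.0] * shard, [0.0] * shard) for _ in range(world_size)]
    sharded = zero1_step(params, grads, states, world_size, step=1)
    full, _, _ = adam_update(params, grads, [0.0] * 8, [0.0] * 8, step=1)
    print(sharded == full)  # sharding changes where state lives, not the result
```

Because Adam is elementwise, splitting the state by index gives bit-identical results to the unsharded update; that is the property that makes optimizer-state sharding free in terms of correctness.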

With 64 GPUs, ZeRO-1 cuts optimizer memory from 12P to 12P/64 bytes, where P is the parameter count. Total per-GPU footprint drops from 16P to roughly 4.2P bytes.
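The 4.2P figure follows from keeping the fp16 parameters and gradients replicated while sharding only the 12 bytes/param of optimizer state (P = parameter count, N = number of GPUs):

```latex
M_{\text{ZeRO-1}}
  = \underbrace{2P}_{\text{fp16 params}}
  + \underbrace{2P}_{\text{fp16 grads}}
  + \underbrace{\frac{12P}{N}}_{\text{optimizer states}}
  \;\overset{N=64}{=}\; 4P + 0.1875P = 4.1875P \approx 4.2P \ \text{bytes}
```

For a 7B-parameter model that is about 29.3 GB per GPU, down from 112 GB.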

Stage 2: Reduce-Scatter Gradients, Allgather Updated Params

ZeRO-2 replaces the gradient allreduce with a reduce-scatter: each GPU partitions its local gradient g_i into N chunks and sends each chunk to its designated owner. GPU i receives the i-th chunk from every other GPU, sums them, and keeps only that shard, cutting gradient memory from 2P to 2P/N bytes.
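The reduce-scatter itself can be simulated in a few lines. This sketch moves no data between real devices; it assumes the gradient length divides evenly by N and just computes what each rank ends up owning. `reduce_scatter` is a hypothetical helper name:

```python
def reduce_scatter(local_grads):
    """Simulated reduce-scatter over N ranks (no real communication).

    local_grads[i] is rank i's full local gradient; its length must be
    divisible by N. Returns out[i] = elementwise sum over all ranks of
    chunk i, which is the only gradient shard rank i keeps.
    """
    n = len(local_grads)
    chunk = len(local_grads[0]) // n
    shards = []
    for owner in range(n):
        lo, hi = owner * chunk, (owner + 1) * chunk
        # Every rank sends its `owner`-th chunk to `owner`, which sums them.
        shards.append([sum(g[k] for g in local_grads) for k in range(lo, hi)])
    return shards


if __name__ == "__main__":
    grads = [[1, 2, 3, 4], [10, 20, 30, 40]]  # 2 ranks, 4 gradient elements each
    print(reduce_scatter(grads))  # [[11, 22], [33, 44]]
```

Each rank holds 2 of the 4 summed elements instead of all 4, which is the 2P to 2P/N saving. Dividing by N afterwards would turn the sums into averages.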

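After the optimizer step, each GPU has updated only the parameters in its own shard. Before the next forward pass, an allgather restores the full updated model on every GPU. With Ψ denoting the parameter count, the reduce-scatter and the allgather each move about Ψ elements per GPU, for a total of 2Ψ. That is identical to DDP, whose allreduce is itself a reduce-scatter followed by an allgather. ZeRO-2 saves memory without costing extra bandwidth.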
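The 2Ψ figure can be checked with the standard ring-algorithm cost model. That model is an assumption here, since real collective libraries may choose other algorithms: in a ring reduce-scatter or allgather over N GPUs, each GPU sends (N-1)/N · Ψ elements, which approaches Ψ as N grows. The function names are hypothetical:

```python
from fractions import Fraction


def ring_reduce_scatter(psi: int, n: int) -> Fraction:
    """Elements each GPU sends in a ring reduce-scatter of psi elements."""
    return Fraction(n - 1, n) * psi


def ring_allgather(psi: int, n: int) -> Fraction:
    """Elements each GPU sends in a ring allgather of psi elements."""
    return Fraction(n - 1, n) * psi


def ring_allreduce(psi: int, n: int) -> Fraction:
    """DDP's allreduce, implemented as a reduce-scatter then an allgather."""
    return ring_reduce_scatter(psi, n) + ring_allgather(psi, n)


if __name__ == "__main__":
    psi, n = 7_000_000_000, 64
    ddp = ring_allreduce(psi, n)
    # ZeRO-2: reduce-scatter the gradients, later allgather the updated params.
    zero2 = ring_reduce_scatter(psi, n) + ring_allgather(psi, n)
    print(zero2 == ddp, float(ddp / psi))  # True 1.96875
```

At 64 GPUs both come to 2 × 63/64 ≈ 1.97Ψ per GPU; ZeRO-2 simply splits the allreduce's two halves around the optimizer step.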

Stage 3: Borrow Parameters, Use Them, Return Them

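Stage 3 shards the parameters themselves. Before each transformer layer's forward pass, an allgather reconstructs the full layer from all GPU shards, like borrowing a book from a neighbor, reading it, and returning it before the next chapter. After the layer runs, each GPU discards the borrowed parameters and keeps only its own shard. The backward pass repeats the borrow-and-return for each layer, and that layer's gradients are then reduce-scattered back to their owners.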
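A toy sketch of the forward-pass borrow-and-return loop from one rank's point of view. Assumptions: each "layer" is a flat list of weights split evenly, its "computation" is just adding the sum of its weights, the gathered copy is a separate buffer, and the backward pass is omitted. `shard_layer` and `zero3_forward` are hypothetical names. The function also tracks how many weights the rank holds at its peak:

```python
def shard_layer(weights, n):
    """Split one layer's flat weight list into n equal shards (even split assumed)."""
    size = len(weights) // n
    return [weights[r * size:(r + 1) * size] for r in range(n)]


def zero3_forward(x, sharded_layers, rank=0):
    """Toy layer-by-layer ZeRO-3 forward pass seen from a single rank.

    Each "layer" is a stand-in that just adds the sum of its weights to x.
    Returns the output and the peak number of weights this rank held at once.
    """
    # Between layers, the rank stores only its own shard of every layer.
    resident = sum(len(shards[rank]) for shards in sharded_layers)
    peak = resident
    for shards in sharded_layers:
        full = [w for s in shards for w in s]  # allgather: borrow the full layer
        peak = max(peak, resident + len(full))  # gathered copy is a separate buffer
        x = x + sum(full)  # "compute" with the full layer
        del full  # return the book: drop the borrowed copy, keep the own shard
    return x, peak


if __name__ == "__main__":
    layers = [[1.0] * 8 for _ in range(3)]  # 3 layers x 8 weights each
    sharded = [shard_layer(w, 4) for w in layers]  # 4 ranks
    out, peak = zero3_forward(0.0, sharded)
    print(out, peak)  # 24.0 14 (vs. 24 weights if every rank held the full model)
```

The peak is the rank's own shards plus one full layer, so in practice the per-GPU cost is 16P/N plus a transient buffer sized to the largest gathered layer.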

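Model-state memory scales as 16P/N, a linear improvement apart from a transient buffer for whichever layer is currently gathered. The cost is extra communication: an allgather per layer in both the forward and the backward pass.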

The Communication Pattern Changes Too

ZeRO never reduces the total bytes transferred (stage 3 actually increases them); what it changes is when and how data moves between GPUs.

[Figure: gradient chunks 0 to 3 assigned to GPUs 0 to 3 (chunk i → GPU i), with cells marked "not stored here" and a global AllReduce for DDP.]

Under DDP, every GPU holds the full optimizer states, gradients, and parameters (16P bytes per GPU) and synchronizes gradients with one AllReduce costing 2Ψ. Because every GPU already holds the full optimizer state, no AllGather is needed.

The Core Tradeoff: Memory vs. Communication

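ZeRO converts memory redundancy into communication spread over the training step. Stages 1 and 2 move the same volume as DDP (2Ψ per GPU per step, where Ψ is the parameter count). Stage 3 moves about 1.5× as much (3Ψ): parameters are allgathered layer by layer in both the forward and backward passes, and gradients are reduce-scattered, so the traffic arrives as many smaller transfers with more round-trips instead of one large allreduce.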
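Counting per-GPU volume in units of Ψ, and approximating each reduce-scatter or allgather as moving about Ψ elements per GPU, the stage 3 overhead works out as follows; this matches the 1.5× figure reported in the ZeRO paper:

```latex
\begin{aligned}
V_{\text{DDP}} = V_{\text{ZeRO-1}} = V_{\text{ZeRO-2}}
  &= \underbrace{\Psi}_{\text{reduce-scatter}} + \underbrace{\Psi}_{\text{allgather}} = 2\Psi \\
V_{\text{ZeRO-3}}
  &= \underbrace{\Psi}_{\text{forward allgather}}
   + \underbrace{\Psi}_{\text{backward allgather}}
   + \underbrace{\Psi}_{\text{gradient reduce-scatter}} = 3\Psi = 1.5\,V_{\text{DDP}}
\end{aligned}
```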

[Figure: per-GPU memory in GB (left axis) and communication volume relative to the DDP baseline of 1.0× (right axis) for each ZeRO stage.]

How Much Does Each Stage Actually Save?

[Figure: per-GPU memory for each of the four configurations (DDP and ZeRO stages 1 to 3) as a function of GPU count and model size, with a breakdown of each into optimizer states, gradients, and parameters (bar width in GB).]
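The four per-stage totals fit in one function. A sketch assuming the ZeRO paper's accounting (K = 12 bytes/param of optimizer state) and ignoring activations and temporary buffers; `per_gpu_bytes` is a hypothetical helper name:

```python
def per_gpu_bytes(num_params: float, n_gpus: int, stage: int) -> float:
    """Model-state bytes per GPU for ZeRO stages 0-3 (activations excluded).

    Uses 2 bytes/param for fp16 params, 2 for fp16 grads, and K = 12 for
    Adam optimizer states, sharding progressively more of them by n_gpus.
    """
    p, n, k = num_params, n_gpus, 12
    if stage == 0:  # plain DDP: everything replicated
        return (2 + 2 + k) * p
    if stage == 1:  # shard optimizer states
        return (2 + 2) * p + k * p / n
    if stage == 2:  # also shard gradients
        return 2 * p + (2 + k) * p / n
    if stage == 3:  # also shard parameters
        return (2 + 2 + k) * p / n
    raise ValueError(f"unknown ZeRO stage: {stage}")


if __name__ == "__main__":
    params, gpus = 7e9, 64
    for s in range(4):
        print(f"stage {s}: {per_gpu_bytes(params, gpus, s) / 1e9:6.2f} GB")
```

For a 7B model on 64 GPUs this gives 112 GB for stage 0, about 29.3 GB for stage 1, about 15.5 GB for stage 2, and 1.75 GB for stage 3.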

Can Your Hardware Actually Fit the Model?

Given a model size and a per-GPU memory budget, you can search for the lowest ZeRO stage, and then the fewest GPUs, whose model states fit. Activations need extra headroom on top of these numbers.

To train a 7B-parameter model on 80 GB GPUs:

  • Minimum ZeRO stage: ZeRO-1 (shard optimizer states)
  • Minimum GPU count: 2 GPUs
  • Model-state memory per GPU at these settings: 70.0 GB (88% utilization)
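A brute-force version of that search, assuming it tries stages in order and then GPU counts from 1 upward, and counting model states only. `per_gpu_gb` and `find_min_config` are hypothetical names, not part of DeepSpeed:

```python
def per_gpu_gb(num_params, n_gpus, stage):
    """Model-state GB per GPU for ZeRO stages 0-3 (activations excluded)."""
    # Bytes/param that get split across GPUs at each stage (of 16 total).
    sharded = {0: 0, 1: 12, 2: 12 + 2, 3: 12 + 2 + 2}[stage]
    return ((16 - sharded) * num_params + sharded * num_params / n_gpus) / 1e9


def find_min_config(num_params, gpu_mem_gb, max_gpus=1024):
    """Lowest ZeRO stage first, then fewest GPUs, whose model states fit."""
    for stage in range(4):
        for n in range(1, max_gpus + 1):
            mem = per_gpu_gb(num_params, n, stage)
            if mem <= gpu_mem_gb:
                return stage, n, mem
    return None  # does not fit even with ZeRO-3 on max_gpus GPUs


if __name__ == "__main__":
    stage, n, mem = find_min_config(7e9, 80)
    print(f"ZeRO-{stage}, {n} GPUs, {mem:.1f} GB ({mem / 80:.0%} utilization)")
```

With a 7B model and an 80 GB budget it returns stage 1 on 2 GPUs at 70.0 GB, matching the example above. A real plan should leave room for activations and communication buffers rather than filling memory to 100%.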

The Bottom Line

With ZeRO-3 across 512 A100s, GPT-3's model states (175B params at 16 bytes each, 2,800 GB in total) come to roughly 5.5 GB per GPU, versus the full 2,800 GB each GPU would need to hold a redundant copy under plain DDP. Activations still need room on top of that, but the hardware budget is the same; the memory is just no longer wasted.
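The arithmetic behind the 5.5 GB figure, using the 16 bytes/param accounting and 1 GB = 10^9 bytes:

```latex
175 \times 10^{9}\ \text{params} \times 16\ \tfrac{\text{bytes}}{\text{param}}
  = 2.8 \times 10^{12}\ \text{bytes} = 2{,}800\ \text{GB},
\qquad
\frac{2{,}800\ \text{GB}}{512\ \text{GPUs}} \approx 5.47\ \text{GB per GPU}
```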

This is why DeepSpeed and its ZeRO optimizer became a standard part of the training stack for large language models after 2020. Not magic: the obvious fix once you see the redundancy.

Based on: Rajbhandari et al., ZeRO: Memory Optimizations Toward Training Trillion Parameter Models, SC'20. · arxiv:1910.02054