FSDP vs Tensor Parallel: Which Sharding Strategy
2025-06-10
When to reach for FSDP, when to reach for tensor parallelism, and why the answer depends entirely on your interconnect.
Modern large-model training requires sharding the model across multiple GPUs. There are two dominant strategies: FSDP (Fully Sharded Data Parallel) and tensor parallelism. They aren't competitors — they're tools for different bottlenecks.
FSDP
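FSDP shards parameters, gradients, and optimizer state across data-parallel ranks. In the forward pass, each rank all-gathers the parameters for the layer (or group of layers) it is about to run, computes, frees the gathered copy, and moves on. The backward pass all-gathers the parameters again to compute gradients, then reduce-scatters those gradients so each rank keeps only its own shard.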
- Communication is all-gather + reduce-scatter, both bandwidth-heavy.
- Each rank still computes on its own chunk of the batch.
- Works well when your interconnect can keep up with the all-gathers (InfiniBand, NVLink across all ranks).
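To make the all-gather / reduce-scatter pattern concrete, here is a toy single-process simulation in plain Python. It is a sketch of the data flow only: the helper names `all_gather` and `reduce_scatter` are illustrative stand-ins rather than calls into any real library, the "model" is a dot product, and gradients are summed rather than averaged for simplicity.

```python
def all_gather(shards):
    """Every rank receives the concatenation of all ranks' shards."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]


def reduce_scatter(per_rank_full, world_size):
    """Sum full-size vectors elementwise, then hand rank r the r-th slice."""
    summed = [sum(vals) for vals in zip(*per_rank_full)]
    n = len(summed) // world_size
    return [summed[r * n:(r + 1) * n] for r in range(world_size)]


world_size = 2
# Four parameters in total; each rank permanently stores only two of them.
param_shards = [[1.0, 2.0], [3.0, 4.0]]
# Each rank sees a different input (its own chunk of the batch).
batches = [[1.0, 0.0, 1.0, 0.0], [0.0, 2.0, 0.0, 2.0]]

# Forward: gather the full parameters, compute loss = w . x locally.
full_params = all_gather(param_shards)
losses = [
    sum(w * x for w, x in zip(full_params[r], batches[r]))
    for r in range(world_size)
]

# Backward: d(loss)/dw = x, so each rank holds a full-size local gradient.
local_grads = batches
# Reduce-scatter leaves each rank with the summed gradient for its shard only.
grad_shards = reduce_scatter(local_grads, world_size)

print(losses)       # [4.0, 12.0]
print(grad_shards)  # [[1.0, 2.0], [1.0, 2.0]]
```

After the reduce-scatter, each rank can update its own parameter shard with its own gradient shard, so the full parameter vector never needs to live permanently on any one rank.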
Tensor parallel
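Tensor parallelism slices each linear layer's weight matrix across ranks, so every rank multiplies by its own slice and produces a partial result. The forward pass requires an all-reduce at the end of each layer to combine those partial results, and the backward pass needs a matching all-reduce for the input gradients.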
- Communication is per-layer all-reduce — many small messages.
- Latency-bound: each layer's all-reduce blocks the next layer.
- Only practical inside a single high-bandwidth domain (one DGX node, NVLink-connected); across slower links the per-layer all-reduce dominates step time.
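The "partial results plus all-reduce" idea can be checked with a small example. The sketch below assumes a row-wise split of the weight matrix (one common way to slice a linear layer) and uses plain nested lists in place of tensors. `matmul` and `all_reduce_sum` are illustrative helpers, not library functions.

```python
def matmul(a, b):
    """Plain nested-list matrix multiply: (m x k) @ (k x n)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]


def all_reduce_sum(partials):
    """Elementwise sum of same-shaped matrices, as an all-reduce would give."""
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]


x = [[1, 2, 3, 4]]                     # activations, shape 1 x 4
w = [[1, 0], [0, 1], [2, 1], [1, 2]]   # weights, shape 4 x 2

world_size = 2
k = len(w) // world_size               # rows of w owned by each rank

partials = []
for rank in range(world_size):
    # Each rank holds k rows of w and the matching k columns of x.
    x_slice = [row[rank * k:(rank + 1) * k] for row in x]
    w_slice = w[rank * k:(rank + 1) * k]
    partials.append(matmul(x_slice, w_slice))

y_tp = all_reduce_sum(partials)        # combine the partial outputs
y_ref = matmul(x, w)                   # unsharded reference

print(partials)  # [[[1, 2]], [[10, 11]]]
print(y_tp)      # [[11, 13]]
print(y_ref)     # [[11, 13]]
```

No rank can hand a correct output to the next layer until the sum completes, which is why this communication sits on the critical path.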
The decision tree
- Model fits on one GPU? Use plain DDP. Don't shard.
- Model doesn't fit on one GPU but fits on one node? Tensor parallel inside the node, DDP across nodes.
- Model doesn't fit on one node? FSDP across the cluster, optionally with tensor parallel inside each node (2D parallelism; adding pipeline parallelism on top makes it 3D).
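The tree above can be written down as a small function. This is a rough sketch, not a sizing tool: `choose_strategy` and its parameters are hypothetical names, and "fits" is reduced to comparing a single memory estimate (which in practice would need to cover parameters, gradients, optimizer state, and activations) against available GPU memory.

```python
def choose_strategy(model_state_gb: float, gpu_mem_gb: float,
                    gpus_per_node: int) -> str:
    """Map a memory estimate onto the three branches of the decision tree."""
    if model_state_gb <= gpu_mem_gb:
        # Fits on one GPU: replicate the model, shard only the data.
        return "DDP"
    if model_state_gb <= gpu_mem_gb * gpus_per_node:
        # Fits on one node: slice layers over the fast intra-node links.
        return "tensor parallel inside the node, DDP across nodes"
    # Too big for a node: shard everything across the cluster.
    return "FSDP across the cluster, optional tensor parallel inside each node"


print(choose_strategy(40, 80, 8))     # DDP
print(choose_strategy(300, 80, 8))    # tensor parallel inside the node, ...
print(choose_strategy(2000, 80, 8))   # FSDP across the cluster, ...
```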
What I used for SpatialDINO
The 3D ViT for SpatialDINO fit comfortably on one DGX A100 (8× A100 80GB). I used FSDP with FULL_SHARD and bf16 MixedPrecision, plus activation checkpointing on the encoder. This gave me room to scale batch size without running out of memory.
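For readers who want to try a similar setup, a configuration along these lines might look roughly like the sketch below. It is not the actual SpatialDINO training code. `wrap_with_fsdp` and `block_cls` are hypothetical names, the per-block auto-wrap policy is an assumption, and the activation-checkpointing helpers live under a private PyTorch module path that can change between releases. It also assumes `torch.distributed` has already been initialized (for example via `torchrun`) and that each rank has selected its CUDA device.

```python
import functools

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy


def wrap_with_fsdp(model: nn.Module, block_cls: type) -> FSDP:
    """Shard `model` with FSDP, then checkpoint every `block_cls` submodule.

    `block_cls` is the (hypothetical) transformer block class of the encoder.
    """
    # bf16 for parameters, gradient reduction, and buffers.
    bf16 = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    )
    sharded = FSDP(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        mixed_precision=bf16,
        # Assumption: one FSDP unit per block, so only one block's
        # parameters are gathered at a time.
        auto_wrap_policy=ModuleWrapPolicy({block_cls}),
        device_id=torch.cuda.current_device(),
    )
    # Recompute block activations in backward instead of storing them.
    apply_activation_checkpointing(
        sharded,
        checkpoint_wrapper_fn=functools.partial(
            checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
        ),
        check_fn=lambda m: isinstance(m, block_cls),
    )
    return sharded
```

Activation checkpointing is applied after FSDP wrapping here, which follows the ordering used in PyTorch's own FSDP examples.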
Tensor parallel would have been overkill — and the per-layer all-reduce would have hurt step time on a model that already trained fast. The lesson: pick the simplest strategy that fits your model, not the most sophisticated one available.