
SimpleFSDP: Efficient FSDP in PyTorch

Updated 30 December 2025
  • SimpleFSDP is a PyTorch-native framework that refactors FSDP training by embedding both computation and communication in the FX graph for full compiler tracing.
  • It leverages native DTensor sharding, custom parametrizations, and selective activation checkpointing to optimize memory usage and expose the entire dependency graph.
  • Benchmark results on Llama models show up to 28.54% memory reduction and 68.67% throughput improvement compared to legacy FSDP approaches.

SimpleFSDP is a PyTorch-native, compiler-based framework for Fully Sharded Data Parallel (FSDP) training, distinguished by its succinct implementation, deep integration with the torch.compile ecosystem, and ability to expose the entire computation-communication dependency graph for aggressive compiler backend optimizations. Its design leverages native PyTorch primitives—DTensor for device placement and sharding, parametrizations for collective communication, and selective activation checkpointing—to enable full tracing and composable distributed training. SimpleFSDP unlocks new capabilities in computation-communication overlap and memory efficiency, especially for training ultra-large LLMs, by eliminating reliance on Python hooks and by exploiting advanced bucketing and reordering optimizations within the TorchInductor backend (Zhang et al., 1 Nov 2024).

1. High-Level Architecture and Workflow

SimpleFSDP refactors the conventional PyTorch FSDP2 eager approach, which relies on backward hooks, into a model where all communication and computation are embedded in the FX graph and are thus traceable by torch.compile. The core workflow is as follows:

  1. Parameter Sharding: Each nn.Parameter is wrapped as a DTensor, sharded across the data-parallel mesh:

\text{layout} = \text{Shard}(\text{dim}=0, \text{mesh}=\text{"data\_parallel"})

  2. Forward All-Gather: During forward computation, a parametrization module calls DTensor.redistribute, gathering the full parameter for local computation.
  3. Activation Cleanup: Activations are released immediately via checkpointing, reducing peak memory consumption.
  4. Backward Gradient Reduction: In the backward pass, DTensor.redistribute re-gathers shards as needed and reduce-scatters gradients.

All operators live inside the graph, enabling inspection, rewriting, and optimization at the IR level. DTensor metadata self-manages device, rank, and mesh placement, obviating explicit dist.get_rank() calls or manual NCCL communicator orchestration.
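
The sketch below illustrates step 1 on a 1-D data-parallel mesh using the public DTensor API (torch.distributed.tensor in recent PyTorch releases; older versions expose it as torch.distributed._tensor). It is a minimal illustration, not the SimpleFSDP source, and assumes an initialized process group:

    import torch
    import torch.nn as nn
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import distribute_tensor, Shard

    # 1-D mesh over all data-parallel ranks.
    mesh = init_device_mesh("cuda", (dist.get_world_size(),),
                            mesh_dim_names=("data_parallel",))

    linear = nn.Linear(4096, 4096, bias=False)
    # Each rank keeps only a 1/P slice of the weight along dim 0; the DTensor
    # carries the mesh and Shard(0) placement as metadata.
    linear.weight = nn.Parameter(
        distribute_tensor(linear.weight, mesh, placements=[Shard(0)])
    )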

2. Collective Communications via PyTorch Primitives

Collective operations are transparently mapped to autograd-aware PyTorch primitives:

  • Parametrization API: Each parameter uses a custom parametrization. Its forward logic is:

    def replicate_compute(self, x):
        # All-gather the sharded DTensor into a replicated tensor for compute;
        # the Partial("avg") grad placement makes the backward pass
        # reduce-scatter (average) the gradients back to shards.
        y = x.redistribute(
            placements=(Replicate(),),
            forward_dtype=self.param_dtype,
            backward_dtype=self.reduce_dtype,
        ).to_local(grad_placements=(Partial("avg"),))
        return y

    Replicate() denotes replication across the mesh, and Partial(reduce_op="avg") encodes the averaging reduce-scatter applied to gradients.
  • Selective Activation Checkpointing: Only the parametrization module is checkpointed, ensuring that (a) forward all-gather and local compute run once with immediate activation release, and (b) backward gathers and reduces gradients on demand.
  • DTensor Role: DTensor encapsulates placement and coordinates; on initialization, each parameter is sharded. This abstraction removes manual rank and communication management.

This approach guarantees full end-to-end tracing, with all communication and computation visible to and optimizable by torch.compile.
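
As an illustration of how such a parametrization can be attached, the hypothetical ReplicateParametrization module below wraps the replicate_compute logic with the standard torch.nn.utils.parametrize API (mixed-precision dtype handling omitted). It reuses the sharded linear module from the earlier sketch and is an assumption-laden sketch, not the SimpleFSDP implementation:

    import torch.nn as nn
    from torch.nn.utils import parametrize
    from torch.distributed.tensor import Replicate, Partial

    class ReplicateParametrization(nn.Module):
        # Hypothetical minimal parametrization: every read of the parameter
        # all-gathers the sharded DTensor; Partial("avg") turns the backward
        # pass into an averaging reduce-scatter of the gradients.
        def forward(self, x):
            return x.redistribute(placements=(Replicate(),)).to_local(
                grad_placements=(Partial("avg"),)
            )

    # After registration, any access to linear.weight in forward/backward goes
    # through the parametrization, so the collective is part of the traced graph.
    parametrize.register_parametrization(
        linear, "weight", ReplicateParametrization(), unsafe=True
    )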

3. TorchInductor Backend Optimizations

After FX→IR lowering, SimpleFSDP presents a long IR node sequence:

AG_1;\ C_1;\ Wa_1;\ AG_2;\ C_2;\ Wa_2;\ \ldots;\ RS_1;\ Wr_1;\ RS_2;\ Wr_2;\ \ldots

where AG_i denotes an all-gather, C_i a compute node, Wa_i the wait on AG_i, RS_i a reduce-scatter, and Wr_i the wait on RS_i.

Bucketing IR Nodes

  • Bucketing Objective: Communications over multiple small shards are grouped (“bucketed”) so that one NCCL collective is issued per bucket.
  • Implementation: Input shards in_1, in_2 are concatenated, collectively gathered, and the output split back per parameter (a runnable sketch follows this list). Pseudocode:

    # Pack two shards, issue a single collective, then unpack the result.
    flat_in = torch.cat([in1, in2], dim=0)
    flat_out = Mesh.redistribute(flat_in, Replicate())
    out1, out2 = split(flat_out, [n1, n2])
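
A runnable version of the same idea with plain c10d collectives (not the actual Inductor pass) is sketched below; it assumes an initialized NCCL process group, CUDA tensors, and dim-0 sharding, and returns flattened full parameters:

    import torch
    import torch.distributed as dist

    def bucketed_all_gather(shards):
        """All-gather a list of local shards with a single collective call."""
        world = dist.get_world_size()
        sizes = [s.numel() for s in shards]
        flat_in = torch.cat([s.reshape(-1) for s in shards])   # pack the bucket
        flat_out = torch.empty(world * flat_in.numel(),
                               dtype=flat_in.dtype, device=flat_in.device)
        dist.all_gather_into_tensor(flat_out, flat_in)          # one NCCL call
        # Unpack: per-rank chunk -> per-parameter pieces -> full parameters.
        pieces = [torch.split(chunk, sizes) for chunk in flat_out.chunk(world)]
        return [torch.cat([pieces[r][i] for r in range(world)])
                for i in range(len(shards))]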

Reordering for Overlap

  • Async Streams: All-gather/reduce-scatter are asynchronous, allowing movement of IR nodes for maximal overlap under dependency constraints.
  • Example Reorder: a sequence such as AG_1 \rightarrow Wa_1 \rightarrow C_1 \rightarrow AG_2 \rightarrow Wa_2 \rightarrow C_2 can be reordered to AG_1 \rightarrow AG_2 \rightarrow Wa_1 \rightarrow C_1 \rightarrow Wa_2 \rightarrow C_2, so that AG_2 is issued early and overlaps with C_1.

This exposes computation-communication concurrency and minimizes communication bubbles, subject to dependency analysis validity.
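
The hand-written sketch below mimics this reordering with asynchronous c10d collectives (in SimpleFSDP the reordering is performed by the Inductor pass, not by user code); layer1, layer2, and the shard arguments are illustrative assumptions:

    import torch
    import torch.distributed as dist

    def overlapped_forward(x, w1_shard, w2_shard, layer1, layer2):
        # Both all-gathers are issued before the first wait, so AG_2 proceeds
        # on the communication stream while C_1 runs.
        world = dist.get_world_size()
        w1_full = torch.empty(world * w1_shard.numel(),
                              dtype=w1_shard.dtype, device=w1_shard.device)
        w2_full = torch.empty(world * w2_shard.numel(),
                              dtype=w2_shard.dtype, device=w2_shard.device)
        h1 = dist.all_gather_into_tensor(w1_full, w1_shard, async_op=True)  # AG_1
        h2 = dist.all_gather_into_tensor(w2_full, w2_shard, async_op=True)  # AG_2
        h1.wait()                    # Wa_1
        y1 = layer1(x, w1_full)      # C_1 (overlaps with AG_2)
        h2.wait()                    # Wa_2
        return layer2(y1, w2_full)   # C_2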

4. Mathematical Formalisms

Memory Usage

For layer \ell with parameter tensor size S_\ell (after fp16/bf16 reduction):

  • Eager (no sharding): M_\ell^{\text{before}} = 2S_\ell (weights + gradients)
  • FSDP2: M_\ell^{\text{after}} \approx \frac{2S_\ell}{P} + O(\text{overlap\_temp})
  • SimpleFSDP: immediate freeing and delayed re-gathering yield

M_\ell^{\text{after}} = \max(S_\ell,\ S_\ell/P) + \text{overhead}_{\text{checkpoint}}

Usually S_\ell \gg S_\ell/P, so SimpleFSDP peaks at S_\ell, not 2S_\ell or S_\ell/P + S_\ell.
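
A small worked example of the memory formulas under assumed numbers (S_\ell = 2 GB of bf16 weights, P = 8 ranks); naive_regather is a hypothetical label for keeping the shard alongside the gathered copy:

    S_l, P = 2.0, 8                             # GB per layer, DP degree (assumed)
    eager_peak      = 2 * S_l                   # 4.0  GB (weights + gradients)
    naive_regather  = S_l / P + S_l             # 2.25 GB (shard kept with full copy)
    simplefsdp_peak = max(S_l, S_l / P)         # 2.0  GB (full parameter only while live)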

Communication Volume & Overlap Efficiency

  • Per-Shard All-Gather: V_\ell = \frac{P-1}{P} S_\ell bytes.
  • Bucketing: for a bucket of B layers, V_{\text{bucket}} = \frac{P-1}{P} \sum_{\ell=1}^{B} S_\ell.
  • Overlap Efficiency: \eta = 1 - \frac{\text{exposed\_comm\_time}}{\text{total\_comm\_time}}

These formulations quantify the memory and communication efficiency gains achievable by SimpleFSDP compared to legacy approaches.
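
Continuing the assumed numbers above (S_\ell = 2 GB, P = 8), a quick calculation of per-layer and bucketed all-gather volume and of overlap efficiency:

    P, S_l, B = 8, 2.0, 4                        # assumed DP degree, GB/layer, layers per bucket
    V_layer  = (P - 1) / P * S_l                 # 1.75 GB moved per layer
    V_bucket = (P - 1) / P * (B * S_l)           # 7.0  GB moved in one bucketed collective
    eta = 1 - 0.2 / 1.0                          # 0.8 if 0.2 s of 1.0 s comm time is exposed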

5. Benchmark Results on Llama 3

SimpleFSDP was evaluated on Llama 3 models (8B, 70B, 405B) using TorchTitan orchestration and PyTorch 2.x with the TorchInductor backend, on a 16-node cluster with 8 NVIDIA H100 GPUs per node (NVLink intra-node, InfiniBand inter-node). Each cell below reports training throughput (tokens/s) and peak memory (GB):

Model | FSDP2-eager        | FSDP2-compile      | SimpleFSDP
8B    | 47K tok/s, 86.8 GB | 52K tok/s, 79.3 GB | 52K tok/s, 62.8 GB
70B   | 8.5K tok/s, 224 GB | 9.3K tok/s, 208 GB | 10.3K tok/s, 187 GB
405B  | 1.1K tok/s, 650 GB | 1.2K tok/s, 610 GB | 1.8K tok/s, 544 GB
  • Memory reduction: Up to 28.54% vs FSDP2-eager; 8.37% vs FSDP2-compile.
  • Throughput improvement: Up to 68.67% vs FSDP2-eager; 6.06% vs FSDP2-compile.

A plausible implication is that, when composed with other parallelism techniques, SimpleFSDP enables scalable training of ultra-large models with reduced memory footprint and improved token throughput (Zhang et al., 1 Nov 2024).

6. Usage Patterns and Composability

Several usage recipes are described:

Automatic & Manual Wrapping

Configuration of automatic bucketing and reordering:

import torch

# Enable SimpleFSDP's bucketing and reordering passes in the Inductor backend.
torch._inductor.config.simplefsdp.bucket_mode = "auto"
torch._inductor.config.simplefsdp.enable_reorder = True

model = simple_fsdp(model)
model = torch.compile(model, fullgraph=True)
Manual wrapping of model components:

from simple_fsdp import ModuleWrapList

# Wrap only the listed submodules (names are illustrative).
wrap_list = ["TransformerBlock1", "TransformerBlock2", ...]
model = simple_fsdp(model, wrap_modules=wrap_list)
model = torch.compile(model, fullgraph=True)

Integration with Distributed Techniques

SimpleFSDP composes with tensor-parallel and pipeline-parallel setups, e.g., under TorchTitan orchestration:

from torch_titan import tensor_parallel, pipeline_parallel

# Meta-device init, half precision, and activation checkpointing, then compose
# tensor/pipeline parallelism with SimpleFSDP before compiling the full graph.
model = meta_init(model)
model = model.half()
model = activation_checkpoint(model)
model = tensor_parallel(model, tp_degree=8)
model = pipeline_parallel(model, pp_degree=4)
model = simple_fsdp(model)
model = torch.compile(model, fullgraph=True)

Composability with parallelism orchestrators is achieved with minimal API surface.

7. Limitations and Prospective Directions

  • Auto-wrapping employs an \alpha + \beta n NCCL time model, and greedy bucketing may over-select for memory-heavy buckets (see the toy sketch after this list). Future work could employ global ILP or topology-aware cost models.
  • Graph breaks, typically caused by Python data-dependent control flow, limit global comm/compute overlap. FX graph knitting is identified as a future research direction.
  • Extension of DTensor mesh abstraction for cross-data-center or heterogeneous device meshes is unexplored.
  • Potential integration with fully-automated planners (Alpa, nnScaler, Slapo) could exploit SimpleFSDP’s exposed computation-communication graph for joint data/tensor/pipeline planning and bubble minimization.
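
A toy sketch of the \alpha + \beta n cost model and a greedy size-capped bucketer, with assumed constants (not the actual auto-wrapping heuristic):

    def collective_time(n_bytes, alpha=15e-6, beta=1 / 50e9):
        """Estimated collective latency: fixed launch cost plus per-byte cost."""
        return alpha + beta * n_bytes

    def greedy_bucket(shard_bytes, max_bucket_bytes=256 * 2**20):
        """Merge consecutive shards into buckets no larger than max_bucket_bytes."""
        buckets, current, size = [], [], 0
        for s in shard_bytes:
            if current and size + s > max_bucket_bytes:
                buckets.append(current)
                current, size = [], 0
            current.append(s)
            size += s
        if current:
            buckets.append(current)
        return buckets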

In summary, SimpleFSDP is a PyTorch-native re-implementation of ZeRO-3 fully sharded data parallelism, purpose-built for end-to-end compiler tracing, unlocking compiler-level optimizations such as IR bucketing and reordering, and delivering significant improvements in memory and throughput for large-scale distributed model training (Zhang et al., 1 Nov 2024).
