Gradient Checkpointing Techniques

Updated 21 April 2026
  • Gradient checkpointing is a memory optimization technique that reduces peak memory usage by selectively storing checkpoints and recomputing missing intermediates.
  • It trades additional computation time for significant memory savings, enabling the training of larger deep learning models and complex simulations on constrained hardware.
  • Various methods, including Revolve, double, and multi-axis checkpointing, balance runtime efficiency and memory usage, with high-level integrations in modern frameworks.

Gradient checkpointing is a memory optimization technique that enables the training and adjoint computation of large-scale models, such as deep neural networks or PDE simulations, under limited hardware memory. The method operates by selectively storing a subset of intermediate activations or simulation states during the forward computation, then recomputing missing intermediates as needed during the backward (gradient) pass. This trades additional compute time for a significantly reduced peak memory footprint, enabling the solution of problems with higher resolution, longer sequences, or larger model size on constrained hardware. Gradient checkpointing is central to modern deep learning frameworks and adjoint-based optimization in computational science, encompassing a family of algorithms including standard segment-based approaches, binomial (Revolve) checkpointing, two-level (double) checkpointing, call-tree-aware profiling strategies, and recent innovations for distributed and structured state-space models.

1. Foundational Principles and Mathematical Formulation

Gradient checkpointing exploits the fact that during reverse-mode automatic differentiation (AD)—essential for training or gradient-based optimization—each backwards step requires access to the forward pass state at one or more time points or layers. Naïvely, this requires storing the full computation history, which is often intractable in memory. Instead, checkpointing algorithms retain only a small set of "checkpoints" (snapshots of model state) and reconstruct the missing intermediates by replaying segments of the forward computation during the backward sweep.
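
The store-and-replay pattern can be made concrete with a minimal sketch (plain Python; the `step` and `backward_step` functions below are toy stand-ins for one forward step and its adjoint, not any framework's API): the forward sweep retains only every k-th state, and the backward sweep rebuilds each missing state by replaying from the nearest stored checkpoint.

```python
# Minimal sketch of the store-and-replay pattern behind gradient
# checkpointing. `step` and `backward_step` are toy stand-ins for one
# forward step and its adjoint, not any particular framework's API.

def step(state):
    # Toy forward map x_{t+1} = f(x_t).
    return state + 0.1 * state * state

def backward_step(state, adjoint):
    # Adjoint of the toy map: dL/dx_t = (dL/dx_{t+1}) * f'(x_t).
    return adjoint * (1.0 + 0.2 * state)

def checkpointed_adjoint(x0, n_steps, every=4):
    # Forward sweep: keep only every `every`-th state (plus the initial one).
    checkpoints = {0: x0}
    x = x0
    for t in range(n_steps):
        x = step(x)
        if (t + 1) % every == 0:
            checkpoints[t + 1] = x

    # Backward sweep: rebuild each missing state by replaying forward
    # from the nearest stored checkpoint, then apply the adjoint step.
    adjoint = 1.0  # seed dL/dx_N
    for t in reversed(range(n_steps)):
        base = max(c for c in checkpoints if c <= t)
        x_t = checkpoints[base]
        for _ in range(t - base):
            x_t = step(x_t)
        adjoint = backward_step(x_t, adjoint)
    return adjoint

print(checkpointed_adjoint(x0=1.0, n_steps=12, every=4))
```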

The generic problem is to minimize the total cost $T_{\rm total}$ subject to a memory constraint:

$$\min_{\rm policy}\; T_{\rm total} = T_{\rm fwd} + T_{\rm bwd} + \alpha\, N_{\rm refwd} \quad \text{subject to} \quad M_{\rm peak} \leq M_{\max}$$

with $T_{\rm fwd}$ and $T_{\rm bwd}$ the forward and backward compute times, $N_{\rm refwd}$ the number of extra forward replays, and $\alpha$ a hardware- or algorithm-specific factor for recomputation cost (Bencheikh et al., 2024).
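
As a worked illustration of this constrained minimization, the following sketch scores candidate policies that store every k-th state, under a simplified cost model (one replay per unstored step; peak memory of the stored states plus one active segment) that is an assumption of this example rather than a formula from the cited work, and keeps the cheapest policy satisfying the memory bound.

```python
# Toy evaluation of the checkpointing objective above for policies that
# store one state every k steps. The cost model (one replay per unstored
# step, peak memory of N/k stored states plus one active segment of k
# states) is an illustrative assumption, not taken from the cited papers.

def best_segment_policy(n_steps, t_fwd, t_bwd, alpha, mem_per_state, mem_max):
    best = None
    for k in range(1, n_steps + 1):                 # candidate segment length
        n_ckpt = -(-n_steps // k)                   # ceil(N / k) stored states
        m_peak = (n_ckpt + k) * mem_per_state       # checkpoints + active segment
        if m_peak > mem_max:
            continue                                # violates M_peak <= M_max
        n_refwd = n_steps - n_ckpt                  # unstored steps replayed once
        t_total = t_fwd + t_bwd + alpha * n_refwd
        if best is None or t_total < best[0]:
            best = (t_total, k, m_peak)
    return best  # (T_total, segment length, M_peak), or None if infeasible

print(best_segment_policy(n_steps=1000, t_fwd=1000, t_bwd=2000,
                          alpha=1.0, mem_per_state=1.0, mem_max=100.0))
```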

The "Revolve" (binomial checkpointing) algorithm formalizes the trade-off with a recurrence: T(N,S)=min1kN[k+T(Nk,S)+T(k,S1)]T(N, S) = \min_{1 \leq k \leq N} \left[ k + T(N-k, S) + T(k, S-1) \right] where NN is the number of steps/layers, SS is the checkpoint limit, and T(N,S)T(N, S) the minimal total forward steps needed to support both forward and backward runs (Kukreja et al., 2018).

2. Checkpoint Scheduling and Key Algorithms

The simplest form is segment-based checkpointing, which partitions the computation into contiguous blocks. Checkpoints are stored at block boundaries, and during the backward pass, missing intermediates within each block are recomputed from the latest checkpoint.
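
In deep-learning frameworks this pattern is available off the shelf; for example, PyTorch's `torch.utils.checkpoint.checkpoint_sequential` splits a sequential stack into segments and recomputes within each segment during the backward pass. The sketch below uses arbitrary layer sizes and segment counts purely for illustration and assumes a recent PyTorch version (for the `use_reentrant=False` argument).

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# A toy deep stack; depth and widths are arbitrary illustrative choices.
layers = [nn.Sequential(nn.Linear(512, 512), nn.ReLU()) for _ in range(16)]
model = nn.Sequential(*layers)

x = torch.randn(8, 512, requires_grad=True)

# Split the 16 blocks into 4 segments: only segment-boundary activations
# are kept, and interior activations are recomputed during backward.
out = checkpoint_sequential(model, 4, x, use_reentrant=False)
out.sum().backward()
print(x.grad.shape)
```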

Revolve (binomial checkpointing) computes, via dynamic programming, the optimal placement of checkpoints for a given number of steps $N$ and checkpoint budget $S$, minimizing recomputation. The resulting schedule is piecewise-constant and follows a binomial-coefficient pattern. Its optimality holds for uniform step cost and a step count known in advance, and the algorithm can be abstracted in high-level frameworks via APIs such as pyRevolve (Kukreja et al., 2018).

For adjoint or multi-level computational workflows (e.g., call-tree structures), recent work applies profiling and heuristics to identify which function calls or subroutines should be checkpointed given a Pareto trade-off between run time and stack storage. Profiling-guided greedy algorithms use collected statistics (e.g., incremental cost/benefit for each candidate checkpoint) to evolve the checkpoint configuration toward the Pareto frontier (Hascoët et al., 2024).
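
A schematic version of such a greedy loop is sketched below: given profiled per-call statistics (stack memory freed if the call is checkpointed and re-executed, versus the extra run time that re-execution costs), it repeatedly selects the candidate with the best memory-saved-per-extra-second ratio until a stack budget is met. The data structure and numbers are invented for illustration and do not reflect Tapenade's actual profiling output.

```python
from dataclasses import dataclass

# Illustrative greedy selection of which calls to checkpoint (re-execute)
# given profiled costs. Field names and values are assumptions made for
# this sketch, not Tapenade's actual profiling data.

@dataclass
class Candidate:
    name: str
    stack_saved: float   # stack memory freed if this call is checkpointed (MB)
    extra_time: float    # extra run time from re-executing the call (s)

def greedy_checkpoint_selection(candidates, current_stack, stack_budget):
    chosen = []
    # Best marginal memory saving per unit of added run time first.
    ranked = sorted(candidates,
                    key=lambda c: c.stack_saved / c.extra_time,
                    reverse=True)
    for cand in ranked:
        if current_stack <= stack_budget:
            break
        chosen.append(cand.name)
        current_stack -= cand.stack_saved
    return chosen, current_stack

calls = [Candidate("solve_pressure", 800.0, 2.5),
         Candidate("advect_tracers", 300.0, 0.4),
         Candidate("io_diagnostics", 50.0, 1.2)]
print(greedy_checkpoint_selection(calls, current_stack=1500.0, stack_budget=600.0))
```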

For distributed memory and accelerator-rich hardware, such as the Graphcore IPU, "Double Checkpointing" introduces two tiers: "remote" checkpoints (off-chip, large, slow) and "local" (on-tile, small, fast), balancing memory hierarchy and minimizing recomputation overhead (Bencheikh et al., 2024).
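
The two-tier idea can be sketched schematically as a small, fast local store that spills its oldest checkpoints to a larger, slower remote store, with restores consulting the local tier first. This is a structural illustration only (Python dictionaries standing in for on-tile and off-chip memory), not the cited IPU implementation.

```python
from collections import OrderedDict

# Schematic two-tier checkpoint store: a small fast "local" tier spills
# its oldest entries to a larger, slower "remote" tier. Structural
# illustration of double checkpointing only, not the IPU code.

class TwoTierCheckpointStore:
    def __init__(self, local_capacity):
        self.local_capacity = local_capacity
        self.local = OrderedDict()   # fast, size-limited (e.g., on-tile SRAM)
        self.remote = {}             # large, slow (e.g., off-chip DRAM)

    def store(self, step, state):
        self.local[step] = state
        if len(self.local) > self.local_capacity:
            old_step, old_state = self.local.popitem(last=False)
            self.remote[old_step] = old_state      # spill oldest checkpoint

    def restore(self, step):
        if step in self.local:                     # cheap local hit
            return self.local[step]
        return self.remote[step]                   # expensive remote fetch

store = TwoTierCheckpointStore(local_capacity=2)
for t in range(6):
    store.store(t, state=f"x_{t}")
print(store.restore(5), store.restore(0))   # local hit, then remote fetch
```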

3. Multi-Axis and Structured Checkpointing in Deep Models

With the rise of deep models for long-sequence processing (e.g., video), Multi-Axis Gradient Checkpointing (MA-GC) extends single-axis checkpointing to multiple axes simultaneously. This is critical for architectures like state-space models (SSMs), which are structured such that each layer in depth and each step in the temporal sequence can be treated as separate recomputation axes.

The MA-GC algorithm generates a two-dimensional grid of checkpoints (across both layer and sequence axes), storing only the grid points and their cell boundaries. Peak memory is therefore governed by the number of layers $L$, the sequence length $S$, and the checkpoint intervals chosen along the layer and sequence axes: only the stored grid states and the activations of the cell currently being rebuilt must reside in memory at once, rather than all $L \times S$ intermediates. For $S \gg L$, the regime typical of video models, appropriately chosen intervals permit substantially longer sequences within the same memory budget (Lee et al., 2024).
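
One way to reason about interval selection is a toy cost model, sketched below, in which peak memory is approximated by the stored layer-axis checkpoints, the stored sequence-axis checkpoints, and one active grid cell; both this model and its brute-force interval search are assumptions of the example, not the expression from the cited paper.

```python
from math import ceil

# Toy cost model for a two-dimensional (layer x sequence) checkpoint grid.
# Assumption for illustration only: peak memory ~ stored layer-axis rows
# + stored sequence-axis columns + one active l-by-s cell, in units of
# per-activation memory. This is NOT the formula from the cited paper.

def grid_memory(num_layers, seq_len, l, s):
    rows = ceil(num_layers / l) * seq_len   # one full-length row kept every l layers
    cols = ceil(seq_len / s) * num_layers   # one full-depth column kept every s steps
    cell = l * s                            # activations rebuilt inside one cell
    return rows + cols + cell

def best_intervals(num_layers, seq_len):
    best = None
    for l in range(1, num_layers + 1):
        for s in range(1, seq_len + 1):
            m = grid_memory(num_layers, seq_len, l, s)
            if best is None or m < best[0]:
                best = (m, l, s)
    return best

# Compare against full activation storage, which would cost
# num_layers * seq_len = 64 * 4096 units in this toy model.
print(best_intervals(num_layers=64, seq_len=4096))
```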

In practice, MA-GC is realized within models such as Video-Ma²mba, where SSMs replace traditional attention for scalability, and the checkpoint scheme is aligned with kernel fusion for high hardware efficiency.

4. Trade-Offs: Memory, Computation, and Algorithmic Complexity

The fundamental trade-off in gradient checkpointing is between reduced memory usage and increased compute time (due to redundant forward computations). Key results for various checkpointing schemes are as follows:

Scheme | Peak memory | Time overhead | Applicability
Full Storage (no checkpointing) | all intermediates retained | none (1×) | All models; infeasible at scale
Single-Axis Checkpointing | checkpoints at segment boundaries plus one active segment | ~1.5–2× | RNNs, simple deep nets
Revolve (Optimal Binomial) | fixed budget of S checkpoints | DP-minimal | Uniform-cost, serial steps
Double Checkpointing | small local tier plus larger remote tier | ~1.1× | Sparse/recurrent nets, hierarchical memory
Multi-Axis (MA-GC) | 2-D grid of layer/sequence checkpoints | ~1.35× | Large SSMs, video models (Lee et al., 2024)

In practical deployments:

  • Double Checkpointing on IPUs for sparse RNNs or SNNs enables markedly longer sequences and larger models than standard BPTT, at only a small additional time overhead (Bencheikh et al., 2024).
  • For MA-GC, training on far longer token contexts is achieved with memory comparable to a much shorter un-checkpointed context, at a modest throughput penalty (Lee et al., 2024).
  • Greedy profiling-guided call-tree checkpointing yields a run-time speedup at comparable stack size, or a larger time reduction at the cost of some extra memory, relative to naive adjoint code (Hascoët et al., 2024).

5. Implementation Paradigms and High-Level Interfaces

Gradient checkpointing algorithms are accessible via high-level abstractions integrated into scientific and machine learning frameworks.

  • The Revolve algorithm is wrapped via pyRevolve, which abstracts checkpoint logic from the user, enabling seamless integration into Python-based domain-specific languages (DSLs) for PDEs (e.g., Devito) or deep network frameworks (Kukreja et al., 2018). The user provides forward and reverse operators along with a checkpoint object, and the schedule of checkpoint actions is automatically determined.
  • For profiling-guided checkpointing (e.g., in Tapenade), the AD tool is instrumented to record execution costs and memory footprints at each checkpointing candidate, informing the greedy selection algorithm for checkpoint inhibition or activation (Hascoët et al., 2024).
  • Custom kernel fusion and block-aligned checkpoint scheduling are necessary when deploying MA-GC on hardware accelerators, synchronizing checkpoint intervals with hardware-friendly batch sizes (Lee et al., 2024).

6. Empirical Performance and Application Domains

Gradient checkpointing has been empirically validated across a spectrum of application settings:

  • In seismic inversion (Devito), optimal checkpointing reduced memory from 80 GB to 10 GB at the cost of a longer runtime, with identical gradients and stable optimization (Kukreja et al., 2018).
  • For climate and ocean models (MITgcm), profiling-guided checkpointing achieved substantial runtime speedups and time reductions, with the performance curve lying close to optimal relative to 250 random checkpointing configurations (Hascoët et al., 2024).
  • For Mamba-2 SSMs in long-form video understanding, MA-GC enabled end-to-end training and inference on million-token sequences on a single modern GPU, outperforming baseline large multimodal models on Video-MME and LongVideoBench tasks (Lee et al., 2024).
  • Double Checkpointing on IPU hardware allowed training of SNNs with substantially longer sequences and larger model sizes than the BPTT baseline, at limited extra runtime (Bencheikh et al., 2024).

Performance results demonstrate that checkpointing is a flexible and essential tool for scaling high-performance scientific computing and deep learning workflows under fixed or distributed memory constraints.

7. Limitations, Challenges, and Future Directions

Most optimal checkpointing algorithms, such as Revolve, require advance knowledge of step/layer count and assume uniform computation cost per step. Run-time variability, multi-level memory (e.g., RAM+SSD+NVRAM), and parallel execution present challenges for current methods (Kukreja et al., 2018, Hascoët et al., 2024). Advanced AD optimizations (e.g., dead code elimination) can also obscure the memory model during profiling, complicating accurate prediction (Hascoët et al., 2024).

Furthermore, non-uniform architectures (dynamic neural networks, adaptive time-stepping) necessitate online or adaptive checkpoint scheduling, an active area of research. No unified model currently exists to trade off between loop-based (temporal) and call-tree checkpointing under a global memory budget (Hascoët et al., 2024). Efficient integration with modern accelerator kernels requires further co-design with system-level schedulers and compiler frameworks (Lee et al., 2024).

A plausible implication is that the increasing heterogeneity and scale of modern computing hardware will drive further innovation in hierarchical, distributed, and dynamically adaptive checkpointing strategies, with particular emphasis on minimizing both recomputation and memory transfers in real-world, large-scale models.
