
FlashAttention-2 JVP Kernel

Updated 6 November 2025
  • FlashAttention-2 JVP Kernel is a GPU-optimized algorithm that fuses Jacobian-vector product computation into attention layers, enabling efficient scaling of diffusion models.
  • It partitions queries, keys, values, and their tangents into blocks using a Triton-based streaming approach that maintains linear memory usage and high throughput.
  • The kernel supports full distributed training (FSDP/CP) with robust numerical stability, enabling production-scale T2I and T2V diffusion tasks with models exceeding 10B parameters.

A FlashAttention-2 JVP kernel is a parallelism-compatible, GPU-optimized algorithm and implementation for computing the Jacobian-vector product (JVP) through attention layers. It is crucial for scaling continuous-time consistency models (sCM) and score-regularized continuous-time consistency models (rCM) to large diffusion tasks with models exceeding 10 billion parameters and high-dimensional inputs such as video. The kernel integrates JVP computation directly into the FlashAttention-2 memory/compute pipeline. This makes it possible to train text-to-image (T2I) and text-to-video (T2V) models at previously impractical parameter counts and sequence lengths while maintaining throughput, memory efficiency, and numerical stability (Zheng et al., 9 Oct 2025).

1. Role and Motivation in Diffusion Distillation

Large-scale continuous-time distillation of diffusion models, specifically sCM and its rCM extension, requires the instantaneous time derivative of the student or teacher network's output along the ODE trajectory:

$$\frac{d}{dt} f_\theta(x_t, t) = \nabla_{x_t} f_\theta(x_t, t) \cdot \frac{dx_t}{dt} + \partial_t f_\theta(x_t, t)$$

This expression is a matrix-free Jacobian-vector product (JVP) with respect to both the spatial input $x_t$ and the time input $t$. Equivalently, it is the JVP of $f_\theta$ along the tangent $(dx_t/dt,\ 1)$. Materializing the full Jacobian is computationally infeasible for high-dimensional outputs such as image and video sequences. Generic autodiff utilities such as torch.func.jvp are not compatible with the distributed data-parallel setups and optimized attention kernels these models rely on, so a bespoke attention JVP kernel is necessary. Modern diffusion backbones are also attention-heavy, and video generation requires very long sequences. Native JVP support inside the attention kernel is therefore essential (Zheng et al., 9 Oct 2025).
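To make the JVP concrete, the total derivative along the trajectory is a single forward-mode product with the tangent $(dx_t/dt, 1)$. The sketch below is a minimal illustration under stated assumptions. The network is a hypothetical toy $f(x, t) = \tanh(Wx + t\,b)$, and the trajectory is an assumed TrigFlow-style interpolation $x_t = \cos(t)\,x_0 + \sin(t)\,\xi$. The code computes the JVP by hand and checks it against a central finite difference of $t \mapsto f(x_t, t)$.

```python
import math

# Hypothetical toy parameters (not from the paper).
W = [[0.5, -1.0], [0.3, 0.8]]
b = [0.2, -0.4]

def f(x, t):
    """Hypothetical toy network f(x, t) = tanh(W x + t b), elementwise tanh."""
    return [math.tanh(sum(W[i][j] * x[j] for j in range(2)) + t * b[i]) for i in range(2)]

def f_jvp(x, t, tx, tt):
    """Forward-mode JVP of f at (x, t) along the tangent (tx, tt)."""
    out, tout = [], []
    for i in range(2):
        z = sum(W[i][j] * x[j] for j in range(2)) + t * b[i]
        tz = sum(W[i][j] * tx[j] for j in range(2)) + tt * b[i]  # linearized pre-activation
        y = math.tanh(z)
        out.append(y)
        tout.append((1.0 - y * y) * tz)  # d tanh(z) = (1 - tanh(z)^2) dz
    return out, tout

x0, xi = [1.0, -0.5], [0.3, 0.7]  # hypothetical data point and noise sample

def x_of(t):
    """Assumed trajectory x_t = cos(t) x0 + sin(t) xi."""
    return [math.cos(t) * a + math.sin(t) * e for a, e in zip(x0, xi)]

def dx_dt(t):
    """Velocity of the assumed trajectory: -sin(t) x0 + cos(t) xi."""
    return [-math.sin(t) * a + math.cos(t) * e for a, e in zip(x0, xi)]

t = 0.6
# d/dt f(x_t, t) = J_x f . dx_t/dt + d_t f, i.e. one JVP with tangent (dx_t/dt, 1).
_, df_dt = f_jvp(x_of(t), t, dx_dt(t), 1.0)

# Independent check: central finite difference of t -> f(x_t, t).
h = 1e-6
f_plus, f_minus = f(x_of(t + h), t + h), f(x_of(t - h), t - h)
fd = [(p - m) / (2 * h) for p, m in zip(f_plus, f_minus)]
max_err = max(abs(a - c) for a, c in zip(df_dt, fd))
print(df_dt, max_err)
```

The Jacobian is never formed. Each output coordinate only needs its pre-activation and the linearized pre-activation `tz`, which is the matrix-free property the attention kernel preserves at scale.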

2. Core Mathematical Formulation

In standard attention with $Q \in \mathbb{R}^{N_1 \times d}$, $K \in \mathbb{R}^{N_2 \times d}$, $V \in \mathbb{R}^{N_2 \times d}$:

$$S = QK^\top, \qquad P = \mathrm{softmax}(S), \qquad O = PV$$

Let the input tangents $tQ$, $tK$, $tV$ represent the perturbations for the JVP computation. The goal is

$$tO = \frac{\partial O}{\partial Q}\,tQ + \frac{\partial O}{\partial K}\,tK + \frac{\partial O}{\partial V}\,tV$$

Concretely, with the softmax applied row-wise, the matrix derivatives are:

$$\begin{aligned}
tS &= tQ\,K^\top + Q\,tK^\top \\
tP &= P \odot tS - P \odot \left((P \odot tS)\,\mathbf{1}_{N_2} \mathbf{1}_{N_2}^\top\right) \\
tO &= P\,tV + (P \odot tS)\,V - \mathrm{diag}\big(\mathrm{rowsum}(P \odot tS)\big)\,O
\end{aligned}$$

The last line follows from $tO = tP\,V + P\,tV$, because $\big(P \odot ((P \odot tS)\,\mathbf{1}_{N_2}\mathbf{1}_{N_2}^\top)\big)V = \mathrm{diag}\big(\mathrm{rowsum}(P \odot tS)\big)\,PV = \mathrm{diag}\big(\mathrm{rowsum}(P \odot tS)\big)\,O$. This structure allows the tangent computation to be fused into the block-wise streaming forward pass of FlashAttention-2 (Zheng et al., 9 Oct 2025).
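A dense NumPy reference is useful for checking these formulas, and later for checking a blocked kernel. The sketch below evaluates $tO = P\,tV + (P \odot tS)V - \mathrm{diag}(\mathrm{rowsum}(P \odot tS))\,O$ directly and compares it against a central finite difference of the attention output. It assumes unscaled scores (no $1/\sqrt{d}$ factor, matching the formulas above), and the function names are illustrative.

```python
import numpy as np

def softmax(S):
    """Row-wise softmax with max subtraction for numerical stability."""
    E = np.exp(S - S.max(axis=1, keepdims=True))
    return E / E.sum(axis=1, keepdims=True)

def attention(Q, K, V):
    return softmax(Q @ K.T) @ V

def attention_jvp(Q, K, V, tQ, tK, tV):
    """Dense (O, tO) using the closed-form tangent formulas (no 1/sqrt(d) scaling)."""
    S = Q @ K.T
    P = softmax(S)
    O = P @ V
    tS = tQ @ K.T + Q @ tK.T
    PtS = P * tS                          # P ⊙ tS
    r = PtS.sum(axis=1, keepdims=True)    # rowsum(P ⊙ tS), shape (N1, 1)
    tO = P @ tV + PtS @ V - r * O         # diag(r) O via row-wise broadcasting
    return O, tO

rng = np.random.default_rng(0)
N1, N2, d = 5, 7, 4
shapes = [(N1, d), (N2, d), (N2, d), (N1, d), (N2, d), (N2, d)]
Q, K, V, tQ, tK, tV = (rng.standard_normal(s) for s in shapes)

O, tO = attention_jvp(Q, K, V, tQ, tK, tV)

# Central finite difference along the joint tangent (tQ, tK, tV).
eps = 1e-6
fd = (attention(Q + eps * tQ, K + eps * tK, V + eps * tV)
      - attention(Q - eps * tQ, K - eps * tK, V - eps * tV)) / (2 * eps)
max_err = np.abs(tO - fd).max()
print(max_err)
```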

3. Implementation and Kernel Design

Triton-based Kernel Construction

  • The kernel is implemented in Triton, extending FlashAttention-2 so that the primal ($O$) and tangent ($tO$) outputs are produced simultaneously during block-wise streaming.
  • Queries, keys, values, and their tangents are partitioned along sequence and batch/head axes.
  • For each tile/block the kernel:
    • Loads local blocks of $Q, tQ$ and $K, tK, V, tV$.
    • Computes $S = QK^\top$ and $tS = tQ\,K^\top + Q\,tK^\top$, applies numerically stable online softmax normalization to $S$, and propagates the tangents through the normalization.
    • Updates the running accumulators for $O$ and $tO$, fused within the same block iteration.
  • The streaming approach keeps memory usage linear in sequence length, which avoids quadratic scaling and enables extremely long input sequences (Zheng et al., 9 Oct 2025).
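One way to carry the tangent through FlashAttention-2's online softmax can be derived from the Section 2 formulas. This is a sketch in our own notation, and the paper's kernel may organize its accumulators differently. For query block $i$, keep the usual running row maximum $m$, softmax denominator $\ell$, and output numerator $A$, then add a tangent numerator $B$ and one extra per-row statistic $\mu$. For each key/value block $j$:

$$
\begin{aligned}
m^{(j)} &= \max\big(m^{(j-1)},\ \mathrm{rowmax}(S_{ij})\big), \qquad \alpha^{(j)} = \exp\big(m^{(j-1)} - m^{(j)}\big), \\
\tilde P_{ij} &= \exp\big(S_{ij} - m^{(j)}\big), \qquad tS_{ij} = tQ_i\,K_j^\top + Q_i\,tK_j^\top, \\
\ell^{(j)} &= \alpha^{(j)}\,\ell^{(j-1)} + \mathrm{rowsum}(\tilde P_{ij}), \\
\mu^{(j)} &= \alpha^{(j)}\,\mu^{(j-1)} + \mathrm{rowsum}(\tilde P_{ij} \odot tS_{ij}), \\
A^{(j)} &= \alpha^{(j)}\,A^{(j-1)} + \tilde P_{ij}\,V_j, \\
B^{(j)} &= \alpha^{(j)}\,B^{(j-1)} + \tilde P_{ij}\,tV_j + (\tilde P_{ij} \odot tS_{ij})\,V_j,
\end{aligned}
$$

starting from $m^{(0)} = -\infty$, $\ell^{(0)} = \mu^{(0)} = 0$, and $A^{(0)} = B^{(0)} = 0$. After the last block $j = T$:

$$O_i = A^{(T)} / \ell^{(T)}, \qquad tO_i = B^{(T)} / \ell^{(T)} - \big(\mu^{(T)} / \ell^{(T)}\big)\,O_i .$$

All products with $\alpha$, subtractions of $m$, and divisions by $\ell$ act row-wise. Every accumulator is linear in $\tilde P$, so the same $\alpha$ rescales all of them. Because $P = \tilde P / \ell$, the result equals $tO = P\,tV + (P \odot tS)V - \mathrm{diag}(\mathrm{rowsum}(P \odot tS))\,O$. The tangent therefore costs one extra $N_1 \times d$ accumulator and one per-row scalar, and no $N_1 \times N_2$ storage.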

Distributed Parallelism and Numerical Stability

  • Fully compatible with Fully Sharded Data Parallelism (FSDP) and Context Parallelism (CP): attention blocks internally manage JVP, preserving layer-wise sharding and distributed slices.
  • Tangent propagation pattern mirrors the block-wise partitioning of primals, ensuring all-gather/reduce operations remain valid.
  • Accumulators and softmax normalization use full-precision arithmetic to mitigate numerical drift; time-embedding layers can optionally operate in FP32 for high-precision time derivatives.
  • Provides a drop-in interface for self- and cross-attention, with both primal and tangent input/output, controlled via a withT option (Zheng et al., 9 Oct 2025).
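To see why tangents shard like primals, consider context parallelism that splits the key/value sequence across ranks. Each rank can reduce its slice to the same per-row statistics used in the online softmax, and the partial results then merge exactly, for the primal and the tangent alike. The sketch below simulates the ranks in one NumPy process. The helpers `partial_stats` and `merge`, and the ring-attention-style merge itself, are illustrative assumptions, not the paper's CP implementation.

```python
import numpy as np

def partial_stats(Q, K, V, tQ, tK, tV):
    """Hypothetical per-rank statistics for one slice of keys/values."""
    S = Q @ K.T
    tS = tQ @ K.T + Q @ tK.T
    m = S.max(axis=1, keepdims=True)           # local row max
    P = np.exp(S - m)                          # unnormalized local weights
    ell = P.sum(axis=1, keepdims=True)         # local softmax denominator
    mu = (P * tS).sum(axis=1, keepdims=True)   # local rowsum(P ⊙ tS)
    A = P @ V                                  # primal numerator
    B = P @ tV + (P * tS) @ V                  # tangent numerator
    return m, ell, mu, A, B

def merge(stats_list):
    """Combine per-rank statistics into (O, tO) by rescaling to the global row max."""
    m = np.max(np.concatenate([s[0] for s in stats_list], axis=1), axis=1, keepdims=True)
    ell = mu = A = B = 0.0
    for m_r, ell_r, mu_r, A_r, B_r in stats_list:
        a = np.exp(m_r - m)                    # same rescale factor for every statistic
        ell, mu = ell + a * ell_r, mu + a * mu_r
        A, B = A + a * A_r, B + a * B_r
    O = A / ell
    tO = B / ell - (mu / ell) * O
    return O, tO

rng = np.random.default_rng(1)
N1, N2, d, ranks = 6, 12, 4, 3
Q, tQ = rng.standard_normal((N1, d)), rng.standard_normal((N1, d))
K, V, tK, tV = (rng.standard_normal((N2, d)) for _ in range(4))

# Each simulated rank holds a contiguous slice of the key/value sequence.
shards = np.array_split(np.arange(N2), ranks)
O, tO = merge([partial_stats(Q, K[s], V[s], tQ, tK[s], tV[s]) for s in shards])

# Single-device dense reference.
S = Q @ K.T
P = np.exp(S - S.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)
tS = tQ @ K.T + Q @ tK.T
O_ref = P @ V
tO_ref = P @ tV + (P * tS) @ V - (P * tS).sum(axis=1, keepdims=True) * O_ref
print(np.abs(O - O_ref).max(), np.abs(tO - tO_ref).max())
```

The only tangent-specific additions are the numerator `B` and the row statistic `mu`. Both are rescaled by the same factor as the primal statistics, so a communication pattern designed for primals can carry the tangents along unchanged.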

Pseudocode Outline (as given):

```
for block_i in row_blocks:  # Stream over query blocks
    # Load Q, tQ block_i ...
    for block_j in col_blocks:  # Stream over key/value blocks
        # Load K, tK, V, tV block_j
        # Compute S = Q K^T, tS = tQ K^T + Q tK^T
        # Softmax normalization (numerically stable)
        # Compute O block, tO block updates
    # Normalize and finalize O, tO block
    # Write O, tO to output
```
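The outline can be made concrete as a NumPy reference that follows the same loop structure. It uses an outer loop over query blocks, an inner loop over key/value blocks, online softmax rescaling, and one extra running row statistic `mu` for the tangent. This is a sketch, not Triton and not the authors' kernel. The function name and the block sizes `Br` and `Bc` are hypothetical, and $1/\sqrt{d}$ score scaling is omitted to match the formulas in Section 2.

```python
import numpy as np

def flash_attn_jvp(Q, K, V, tQ, tK, tV, Br=4, Bc=3):
    """Block-wise (O, tO) without materializing any N1 x N2 matrix (NumPy sketch)."""
    N1, N2, dv = Q.shape[0], K.shape[0], V.shape[1]
    O, tO = np.empty((N1, dv)), np.empty((N1, dv))
    for i in range(0, N1, Br):                          # stream over query blocks
        q, tq = Q[i:i + Br], tQ[i:i + Br]
        rows = q.shape[0]
        m = np.full((rows, 1), -np.inf)                 # running row max
        ell = np.zeros((rows, 1))                       # running softmax denominator
        mu = np.zeros((rows, 1))                        # running rowsum(p ⊙ tS)
        A = np.zeros((rows, dv))                        # running p @ V
        B = np.zeros((rows, dv))                        # running p @ tV + (p ⊙ tS) @ V
        for j in range(0, N2, Bc):                      # stream over key/value blocks
            k, tk, v, tv = K[j:j + Bc], tK[j:j + Bc], V[j:j + Bc], tV[j:j + Bc]
            s = q @ k.T
            ts = tq @ k.T + q @ tk.T                    # tS block: tQ K^T + Q tK^T
            m_new = np.maximum(m, s.max(axis=1, keepdims=True))
            alpha = np.exp(m - m_new)                   # rescales all old accumulators
            p = np.exp(s - m_new)
            ell = alpha * ell + p.sum(axis=1, keepdims=True)
            mu = alpha * mu + (p * ts).sum(axis=1, keepdims=True)
            A = alpha * A + p @ v
            B = alpha * B + p @ tv + (p * ts) @ v
            m = m_new
        O[i:i + Br] = A / ell                               # finalize primal block
        tO[i:i + Br] = B / ell - (mu / ell) * O[i:i + Br]   # finalize tangent block
    return O, tO

rng = np.random.default_rng(2)
N1, N2, d = 10, 8, 4                                    # ragged last blocks on purpose
Q, tQ = rng.standard_normal((N1, d)), rng.standard_normal((N1, d))
K, V, tK, tV = (rng.standard_normal((N2, d)) for _ in range(4))
O, tO = flash_attn_jvp(Q, K, V, tQ, tK, tV)

# Dense reference: tO = P tV + (P ⊙ tS) V - diag(rowsum(P ⊙ tS)) O
S = Q @ K.T
P = np.exp(S - S.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)
tS = tQ @ K.T + Q @ tK.T
O_ref = P @ V
tO_ref = P @ tV + (P * tS) @ V - (P * tS).sum(axis=1, keepdims=True) * O_ref
print(np.abs(O - O_ref).max(), np.abs(tO - tO_ref).max())
```

Changing `Br` or `Bc` changes only the tiling, not the result. A GPU version would typically assign query blocks to parallel program instances and keep `m`, `ell`, `mu`, `A`, and `B` in on-chip memory, as FlashAttention-2 does for its primal accumulators.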

Table: FlashAttention-2 JVP Kernel Implementation Features

| Feature | Description |
|---|---|
| Kernel language | Triton |
| Parallelism support | Full FSDP and CP compatibility |
| Numerical handling | Fused, full-precision softmax and tangent arithmetic |
| Interface | Self-/cross-attention; withT flag for primal/tangent signaling |
| Memory scaling | Linear in sequence length, O(N); no materialized Jacobians |
| Block design | Streaming, tiling, and block-wise updates for all intermediates |

4. Integration in Neural Architectures and Distributed Training

  • Within the network, all layers accept and propagate primal and tangent pairs; for most layers, torch.func.jvp suffices, but for attention, the custom FlashAttention-2 JVP kernel is invoked.
  • All distributed primitives (sharding, all-gather/reduce) remain unchanged, because tangents are partitioned and streamed exactly like their primals.
  • Enables context parallel scaling and FSDP for T2I and T2V models at scales infeasible with standard JVP approaches.
  • Tangent error handling and normalization strategies mitigate the increased numerical instability of JVP, especially at low precision or extreme sequence length (Zheng et al., 9 Oct 2025).
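The (primal, tangent) convention can be illustrated with a toy model: every layer maps a pair (x, tx) to (y, ty), so layers compose without ever building a Jacobian. The layer classes below are hypothetical. The attention layer uses the dense formulas from Section 2 as a stand-in for the fused kernel. Passing `tx=None` for a forward-only call only loosely mirrors the paper's `withT` option, whose real signature is not given here.

```python
import numpy as np

class Linear:
    """y = x W + b; the tangent is linear in tx: ty = tx W."""
    def __init__(self, W, b):
        self.W, self.b = W, b

    def __call__(self, x, tx=None):
        y = x @ self.W + self.b
        return y if tx is None else (y, tx @ self.W)

class Tanh:
    """Elementwise nonlinearity: ty = (1 - tanh(x)^2) * tx."""
    def __call__(self, x, tx=None):
        y = np.tanh(x)
        return y if tx is None else (y, (1.0 - y ** 2) * tx)

class SelfAttention:
    """Single-head self-attention; the dense JVP stands in for a fused JVP kernel."""
    def __init__(self, Wq, Wk, Wv):
        self.Wq, self.Wk, self.Wv = Wq, Wk, Wv

    def __call__(self, x, tx=None):
        Q, K, V = x @ self.Wq, x @ self.Wk, x @ self.Wv
        S = Q @ K.T
        P = np.exp(S - S.max(axis=1, keepdims=True))
        P /= P.sum(axis=1, keepdims=True)
        O = P @ V
        if tx is None:
            return O                                   # forward-only path
        tQ, tK, tV = tx @ self.Wq, tx @ self.Wk, tx @ self.Wv  # projections are linear
        PtS = P * (tQ @ K.T + Q @ tK.T)                # P ⊙ tS
        tO = P @ tV + PtS @ V - PtS.sum(axis=1, keepdims=True) * O
        return O, tO

def run(layers, x, tx=None):
    """Thread either x alone or the pair (x, tx) through every layer."""
    for layer in layers:
        if tx is None:
            x = layer(x)
        else:
            x, tx = layer(x, tx)
    return x if tx is None else (x, tx)

rng = np.random.default_rng(3)
n, d = 6, 4

def rand_weight():
    """Hypothetical small random weight matrix."""
    return 0.5 * rng.standard_normal((d, d))

net = [Linear(rand_weight(), np.zeros(d)), Tanh(),
       SelfAttention(rand_weight(), rand_weight(), rand_weight()),
       Linear(rand_weight(), np.zeros(d))]
x, tx = rng.standard_normal((n, d)), rng.standard_normal((n, d))

y, ty = run(net, x, tx)                                # primal and tangent in one pass
eps = 1e-6
fd = (run(net, x + eps * tx) - run(net, x - eps * tx)) / (2 * eps)
print(np.abs(ty - fd).max())
```

Because the tangent rides alongside the primal through each layer, the same forward pass, sharding, and activation layout serve both. Only layers without a usable built-in JVP, attention in this setting, need a custom rule.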

5. Empirical Performance and Benchmarks

The kernel was validated as an enabler for large-scale, high-dimensional diffusion model distillation, specifically:

  • The FlashAttention-2 JVP kernel enabled sCM/rCM training on Cosmos-Predict2 up to 14B parameters (text-to-image, 1360×768 px) and on Wan2.1 up to 14B parameters (text-to-video, up to 832×480 px with 81 frames, about 5 s of video).
  • Achieved strong scaling with context parallelism and FSDP.
  • Supported large-batch, high-resolution, long-duration video inputs previously unattainable using prior JVP solutions.
  • Sampled images in 1 to 4 steps with up to 50× acceleration over the teacher models; videos were sampled in 2 to 4 steps at scale.
  • No effective increase in memory usage relative to the standard FlashAttention-2 forward pass; the tangent computation is fully fused.
  • Ablations demonstrated that alternate (non-fused, non-robust) JVP approaches led to numerical instability at scale (see experiments in Appendix and Figure JVP-errors) (Zheng et al., 9 Oct 2025).

6. Impact on Consistency Modeling and Diffusion Training

  • The kernel is a necessary prerequisite for practical scaling of sCM and rCM: without it, training at large scale is infeasible due to memory, numerical, and parallelism barriers.
  • The fused kernel allows the rCM objective, which interleaves consistency loss (JVP-requiring) and score-regularized loss (forward-only), to be optimized efficiently for very large models and high-dimensional tasks.
  • Enables throughput and hardware utilization parity with regular FlashAttention-2 (no additional tracing or activation storage overhead).
  • Used to distill state-of-the-art T2I and T2V models (Cosmos-Predict2, Wan2.1) on large clusters with minimal codebase and hardware changes; the authors describe this as the first production-grade extension of continuous-time consistency to these domains (Zheng et al., 9 Oct 2025).

7. Mathematical and Implementation Summary

Inside each attention layer's main loop, the kernel evaluates the critical JVP term of the sCM and rCM loss:

$$\mathcal{L}_{\text{sCM}}(\theta) = \mathbb{E}_{x_0, \xi, t}\left[\left\| f_\theta(x_t, t) - f_{\theta^-}(x_t, t) - \frac{w(t)\,\frac{d}{dt} f_{\theta^-}(x_t, t)}{\left\| w(t)\,\frac{d}{dt} f_{\theta^-}(x_t, t) \right\|_2^2 + c} \right\|_2^2 \right],$$

with

$$\frac{d}{dt} f_{\theta^-}(x_t, t) = \nabla_{x_t} f_{\theta^-}(x_t, t) \cdot \frac{dx_t}{dt} + \partial_t f_{\theta^-}(x_t, t)$$

The dominant JVP term is evaluated by the FlashAttention-2 JVP kernel, streamed through all attention-heavy subnetworks during distillation (Zheng et al., 9 Oct 2025).
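Once the JVP output is available, the loss itself is cheap. The NumPy sketch below computes the batch estimate of $\mathcal{L}_{\text{sCM}}$ as written above, taking precomputed network outputs as arrays. The function name is hypothetical, the constant `c` must be supplied by the caller (no value is given here), and autodiff and stop-gradient handling for $\theta^-$ are omitted.

```python
import numpy as np

def scm_loss(f_student, f_teacher, df_dt, w_t, c):
    """Batch estimate of L_sCM following the formula above (hypothetical helper).

    f_student: f_theta(x_t, t), shape (batch, dim)
    f_teacher: f_{theta^-}(x_t, t), shape (batch, dim), treated as a constant
    df_dt:     d/dt f_{theta^-}(x_t, t), shape (batch, dim), e.g. from the JVP pass
    w_t:       weighting w(t), shape (batch,)
    c:         positive stabilizing constant in the denominator
    """
    g = w_t[:, None] * df_dt                          # w(t) * d/dt f_{theta^-}
    denom = (g ** 2).sum(axis=1, keepdims=True) + c   # ||w(t) d/dt f||_2^2 + c
    residual = f_student - f_teacher - g / denom
    return (residual ** 2).sum(axis=1).mean()         # mean over batch of squared L2 norm

# Worked check: equal outputs, w(t) = 1, c = 0, d/dt f = (3, 4).
loss = scm_loss(np.zeros((1, 2)), np.zeros((1, 2)), np.array([[3.0, 4.0]]), np.array([1.0]), c=0.0)
print(loss)
```

When $f_\theta = f_{\theta^-}$, the residual reduces to the normalized tangent term. In the worked check this gives $(3/25)^2 + (4/25)^2 = 0.0144 + 0.0256 = 0.04$. The expensive part is producing `df_dt`, which is where the attention JVP kernel enters.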

8. Novelty and Significance

The parallelism-compatible FlashAttention-2 JVP kernel constitutes the first efficient, scalable, and numerically robust JVP implementation for attention applicable to production-scale T2I and T2V diffusion models. It supports full distributed training (FSDP/CP), fuses tangent propagation into the attention memory stream, and avoids both quadratic memory scaling and the efficiency bottlenecks of prior approaches. Its deployment extends sCM and rCM distillation to 10B+ models without architectural changes or loss of performance. The kernel is a foundational technology for large-scale generative model distillation grounded in continuous-time consistency theory (Zheng et al., 9 Oct 2025).
