
FlashAttention-2 JVP Kernel

Updated 6 November 2025
  • FlashAttention-2 JVP Kernel is a GPU-optimized algorithm that fuses Jacobian-vector product computation into attention layers, enabling efficient scaling of diffusion models.
  • It partitions queries, keys, values, and their tangents into blocks using a Triton-based streaming approach that maintains linear memory usage and high throughput.
  • The kernel supports full distributed training (FSDP/CP) with robust numerical stability, enabling production-scale T2I and T2V diffusion tasks with models exceeding 10B parameters.

A FlashAttention-2 JVP kernel is a GPU-optimized, parallelism-compatible algorithm and implementation for computing Jacobian-vector products (JVPs) through attention layers. It is central to scaling continuous-time consistency models (sCM) and score-regularized continuous-time consistency models (rCM) to large diffusion tasks, with models exceeding 10 billion parameters and high-dimensional inputs such as video. The kernel integrates JVP computation directly into the FlashAttention-2 memory/compute pipeline, which allows T2I and T2V models to be trained at large parameter counts and long sequence lengths while maintaining throughput, memory efficiency, and numerical stability (Zheng et al., 9 Oct 2025).

1. Role and Motivation in Diffusion Distillation

Large-scale distillation of diffusion models in continuous time, specifically sCM and its rCM extension, requires the instantaneous time derivative of the student or teacher network's output along the ODE trajectory:

$$\frac{d}{dt} f_\theta(x_t, t) = \nabla_{x_t} f_\theta(x_t, t) \cdot \frac{dx_t}{dt} + \partial_t f_\theta(x_t, t)$$

This expression is a matrix-free Jacobian-vector product (JVP) of $f_\theta$ with respect to both its spatial and temporal inputs, taken along the tangent $(dx_t/dt,\, 1)$. Materializing the full Jacobian is computationally infeasible for high-dimensional outputs (e.g., image/video sequences). Traditional autodiff utilities such as torch.func.jvp are not compatible with distributed data-parallel or attention-optimized architectures, which makes a bespoke attention JVP kernel necessary. Furthermore, modern diffusion backbones are attention-heavy, and video generation requires very long sequences, so direct JVP support within the attention kernel is essential (Zheng et al., 9 Oct 2025).
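As a minimal sketch of why the time derivative above is a JVP, the following NumPy example uses a toy function in place of $f_\theta$ and a toy trajectory in place of the ODE solution. The names `f`, `f_jvp`, and `trajectory` are hypothetical and chosen only for illustration; none of them come from the source. The JVP along the tangent $(dx_t/dt, 1)$ is compared against a central finite difference of $f$ along the trajectory.

```python
import numpy as np

def f(x, t):
    """Toy stand-in for f_theta(x_t, t); an assumption, not the source's model."""
    return np.tanh(x) * np.cos(t)

def f_jvp(x, t, tx, tt):
    """Hand-written JVP of f at (x, t) along the tangent (tx, tt)."""
    dfdx_tx = (1.0 - np.tanh(x) ** 2) * np.cos(t) * tx  # spatial part
    dfdt_tt = -np.tanh(x) * np.sin(t) * tt              # temporal part
    return dfdx_tx + dfdt_tt

def trajectory(t, x0):
    """Toy trajectory x_t = x0 * exp(-t), which satisfies dx_t/dt = -x_t."""
    return x0 * np.exp(-t)

def total_time_derivative(x0, t):
    """d/dt f(x_t, t), computed as a single JVP with tangent (dx_t/dt, 1)."""
    x_t = trajectory(t, x0)
    dxdt = -x_t
    return f_jvp(x_t, t, dxdt, 1.0)

def finite_difference(x0, t, eps=1e-6):
    """Central difference of t -> f(x_t, t), used only as a check."""
    hi = f(trajectory(t + eps, x0), t + eps)
    lo = f(trajectory(t - eps, x0), t - eps)
    return (hi - lo) / (2.0 * eps)

x0 = np.array([0.3, -1.2, 2.0])
jvp_value = total_time_derivative(x0, 0.7)
fd_value = finite_difference(x0, 0.7)
print(np.abs(jvp_value - fd_value).max())
```

No Jacobian matrix is formed at any point: only the product of the Jacobian with one tangent vector is evaluated, which is what keeps the computation feasible for high-dimensional outputs.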

2. Core Mathematical Formulation

In standard attention with $Q \in \mathbb{R}^{N_1 \times d}$, $K \in \mathbb{R}^{N_2 \times d}$, $V \in \mathbb{R}^{N_2 \times d}$:

$$S = QK^\top, \qquad P = \mathrm{softmax}(S), \qquad O = PV$$

Let the input tangents $tQ$, $tK$, $tV$ represent perturbations for JVP computation. The aim is:

$$tO = \frac{dO}{dQ}tQ + \frac{dO}{dK}tK + \frac{dO}{dV}tV$$

Concretely, the matrix derivatives are:

$$\begin{align*}
tS &= tQ\, K^\top + Q\, tK^\top \\
tP &= P \odot tS - P \odot \left((P \odot tS)\mathbf{1}_{N_2} \mathbf{1}_{N_2}^\top\right) \\
tO &= P\, tV + (P \odot tS)V - \mathrm{diag}(\mathrm{rowsum}(P \odot tS))\, O
\end{align*}$$

This structure allows the tangent computation to be fused into the streaming block-wise FlashAttention-2 forward pass (Zheng et al., 9 Oct 2025).
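The formulas for $tS$, $tP$, and $tO$ can be checked numerically. The following is a dense NumPy reference, a sketch for verification only and not the fused kernel: it materializes the full $N_1 \times N_2$ matrices, which the streaming kernel avoids. The function names and the random test shapes are assumptions made for this illustration. As in the text, no $1/\sqrt{d}$ scaling is applied.

```python
import numpy as np

def softmax_rows(S):
    """Row-wise softmax with max subtraction for numerical stability."""
    S = S - S.max(axis=-1, keepdims=True)
    E = np.exp(S)
    return E / E.sum(axis=-1, keepdims=True)

def attention(Q, K, V):
    """Primal attention O = softmax(Q K^T) V."""
    return softmax_rows(Q @ K.T) @ V

def attention_jvp_dense(Q, K, V, tQ, tK, tV):
    """Dense reference JVP following the tS, tP, tO formulas term by term."""
    P = softmax_rows(Q @ K.T)
    O = P @ V
    tS = tQ @ K.T + Q @ tK.T
    PtS = P * tS                           # P ⊙ tS
    row = PtS.sum(axis=-1, keepdims=True)  # rowsum(P ⊙ tS), shape (N1, 1)
    tP = PtS - P * row
    tO = P @ tV + PtS @ V - row * O
    return O, tO, tP

rng = np.random.default_rng(0)
N1, N2, d = 5, 7, 4
Q, tQ = rng.normal(size=(2, N1, d))
K, tK = rng.normal(size=(2, N2, d))
V, tV = rng.normal(size=(2, N2, d))

O, tO, tP = attention_jvp_dense(Q, K, V, tQ, tK, tV)

# Central finite difference of the primal along the same tangent direction.
eps = 1e-5
tO_fd = (attention(Q + eps * tQ, K + eps * tK, V + eps * tV)
         - attention(Q - eps * tQ, K - eps * tK, V - eps * tV)) / (2.0 * eps)
max_err = np.abs(tO - tO_fd).max()
print(max_err)
```

The last term of $tO$ only needs the per-row scalar $\mathrm{rowsum}(P \odot tS)$ and the output $O$, both of which can be accumulated block by block. This is the property that makes fusion into a streaming forward pass possible.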

3. Implementation and Kernel Design

Triton-based Kernel Construction

  • The kernel is implemented in Triton, extending FlashAttention-2 such that primal ($O$) and tangent ($tO$) outputs are produced simultaneously during block-wise streaming.
  • Queries, keys, values, and their tangents are partitioned along sequence and batch/head axes.
  • For each tile/block the kernel:
    • Loads local blocks of $Q, K, V$ and $tQ, tK, tV$.
    • Computes $S$, $tS$, applies numerically stable online softmax normalization for $P$, and propagates the tangents through the normalization.
    • Updates $O$ and $tO$ with local fusions.
  • The streaming approach preserves linear memory usage, avoiding quadratic scaling and enabling extremely long input sequences (Zheng et al., 9 Oct 2025).
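The block-wise steps above can be sketched in NumPy. This is an illustrative reconstruction under stated assumptions, not the Triton kernel from the source: the block sizes, accumulator names (`m`, `l`, `r`, `acc_o`, `acc_t`), and loop structure are hypothetical. It keeps a running row maximum and rescales all accumulators when the maximum changes, so that only one `block_q x block_k` tile of scores exists at a time. A dense reference is included only to check the result.

```python
import numpy as np

def attention_jvp_streaming(Q, K, V, tQ, tK, tV, block_q=4, block_k=4):
    """Block-wise attention JVP with online softmax (illustrative sketch)."""
    N1, N2, dv = Q.shape[0], K.shape[0], V.shape[1]
    O = np.empty((N1, dv))
    tO = np.empty((N1, dv))
    for qs in range(0, N1, block_q):
        q, tq = Q[qs:qs + block_q], tQ[qs:qs + block_q]
        bq = q.shape[0]
        m = np.full((bq, 1), -np.inf)  # running row maximum of S
        l = np.zeros((bq, 1))          # running sum of exp(S - m)
        r = np.zeros((bq, 1))          # running sum of exp(S - m) * tS
        acc_o = np.zeros((bq, dv))     # unnormalized P V
        acc_t = np.zeros((bq, dv))     # unnormalized P tV + (P ⊙ tS) V
        for ks in range(0, N2, block_k):
            k, tk = K[ks:ks + block_k], tK[ks:ks + block_k]
            v, tv = V[ks:ks + block_k], tV[ks:ks + block_k]
            s = q @ k.T                   # score tile
            ts = tq @ k.T + q @ tk.T      # tangent score tile
            m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
            scale = np.exp(m - m_new)     # rescale factor; 0 on the first tile
            e = np.exp(s - m_new)
            ets = e * ts
            l = l * scale + e.sum(axis=-1, keepdims=True)
            r = r * scale + ets.sum(axis=-1, keepdims=True)
            acc_o = acc_o * scale + e @ v
            acc_t = acc_t * scale + e @ tv + ets @ v
            m = m_new
        o = acc_o / l
        O[qs:qs + bq] = o
        tO[qs:qs + bq] = acc_t / l - (r / l) * o  # subtract rowsum(P ⊙ tS) * O
    return O, tO

def attention_jvp_dense(Q, K, V, tQ, tK, tV):
    """Dense reference used only to validate the streaming version."""
    S = Q @ K.T
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P /= P.sum(axis=-1, keepdims=True)
    O = P @ V
    PtS = P * (tQ @ K.T + Q @ tK.T)
    return O, P @ tV + PtS @ V - PtS.sum(axis=-1, keepdims=True) * O

rng = np.random.default_rng(1)
N1, N2, d = 10, 13, 8  # deliberately not multiples of the block sizes
Q, tQ = rng.normal(size=(2, N1, d))
K, tK = rng.normal(size=(2, N2, d))
V, tV = rng.normal(size=(2, N2, d))

O_s, tO_s = attention_jvp_streaming(Q, K, V, tQ, tK, tV)
O_d, tO_d = attention_jvp_dense(Q, K, V, tQ, tK, tV)
print(np.abs(O_s - O_d).max(), np.abs(tO_s - tO_d).max())
```

Relative to a primal-only streaming pass, the tangent path adds one extra per-row scalar (`r`) and one extra output-sized accumulator (`acc_t`), so the working memory per query block stays independent of the key sequence length.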

Distributed Parallelism and Numerical Stability

  • Fully compatible with Fully Sharded Data Parallelism (FSDP) and Context Parallelism (CP): attention blocks internally manage JVP, preserving layer-wise sharding and distributed slices.
  • Tangent propagation pattern mirrors the block-wise partitioning of primals, ensuring all-gather/reduce operations remain valid.
  • Accumulators and softmax normalization use full-precision arithmetic to mitigate numerical drift; time-embedding layers can optionally operate in FP32 for high-precision time derivatives.
  • Provides a drop-in interface for self- and cross-attention, with both primal and tangent input/output, controlled via a withT option (Zheng et al., 9 Oct 2025).
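The motivation for full-precision accumulators can be illustrated with a generic experiment. The following NumPy sketch is not taken from the kernel; the `accumulate` helper and the chosen values are assumptions for demonstration. It sums the same half-precision values once with a half-precision accumulator and once with a single-precision accumulator.

```python
import numpy as np

def accumulate(values, acc_dtype):
    """Sequentially sum `values`, rounding the running sum to `acc_dtype`."""
    acc = acc_dtype(0.0)
    for v in values:
        acc = acc_dtype(acc + acc_dtype(v))
    return float(acc)

# 10,000 small terms stored in half precision; their exact sum is close to 1.
values = np.full(10000, 1e-4, dtype=np.float16)

sum_fp16 = accumulate(values, np.float16)  # low-precision accumulator
sum_fp32 = accumulate(values, np.float32)  # full-precision accumulator
print(sum_fp16, sum_fp32)
```

Once the half-precision running sum grows large enough that a single term falls below half of its spacing between representable values, further additions are rounded away and the sum stops growing. The single-precision accumulator does not show this behaviour at this scale. Long-sequence softmax and tangent accumulations have the same many-small-terms structure.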

Pseudocode Outline (as given):

[Pseudocode listing not preserved in this copy of the source.]

Table: FlashAttention-2 JVP Kernel Implementation Features

| Feature | Description |
| --- | --- |
| Kernel language | Triton |
| Parallelism support | Full compatibility with FSDP and CP |
| Numerical handling | Fused, full-precision softmax and tangent arithmetic |
| Interface | Self-/cross-attention, withT flag for primal/tangent signaling |
| Memory scaling | Linear, $O(N)$; no materialized Jacobians |
| Block design | Streaming, tiling, block-wise updates for all intermediates |

4. Integration in Neural Architectures and Distributed Training

  • Within the network, all layers accept and propagate primal and tangent pairs; for most layers, torch.func.jvp suffices, but for attention, the custom FlashAttention-2 JVP kernel is invoked.
  • All distributed primitives (sharding, all-gather/reduce) are maintained unchanged, as the tangent and primal streaming/partitioning are isomorphic.
  • Enables context parallel scaling and FSDP for T2I and T2V models at scales infeasible with standard JVP approaches.
  • Tangent error handling and normalization strategies mitigate the increased numerical instability of JVP, especially at low precision or extreme sequence length (Zheng et al., 9 Oct 2025).
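The idea that every layer accepts and returns a primal/tangent pair can be sketched as follows. This is a hedged NumPy illustration, not the source's code: the toy block layout (projections, attention, tanh, output projection), the helper names, and the weights are all hypothetical. A dense attention JVP stands in for the fused kernel, and weights are treated as constants so only the input carries a tangent.

```python
import numpy as np

def linear_jvp(x, tx, W):
    """Linear layer with fixed weights: the tangent passes through the same W."""
    return x @ W, tx @ W

def tanh_jvp(x, tx):
    """Elementwise activation: tangent is scaled by the local derivative."""
    y = np.tanh(x)
    return y, (1.0 - y ** 2) * tx

def attention_jvp(Q, K, V, tQ, tK, tV):
    """Dense stand-in for the custom attention JVP kernel."""
    S = Q @ K.T
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P /= P.sum(axis=-1, keepdims=True)
    O = P @ V
    PtS = P * (tQ @ K.T + Q @ tK.T)
    return O, P @ tV + PtS @ V - PtS.sum(axis=-1, keepdims=True) * O

def block_jvp(x, tx, params):
    """Toy self-attention block mapping (primal, tangent) to (primal, tangent)."""
    Wq, Wk, Wv, Wo = params
    q, tq = linear_jvp(x, tx, Wq)
    k, tk = linear_jvp(x, tx, Wk)
    v, tv = linear_jvp(x, tx, Wv)
    o, to = attention_jvp(q, k, v, tq, tk, tv)  # attention uses the custom path
    h, th = tanh_jvp(o, to)
    return linear_jvp(h, th, Wo)

rng = np.random.default_rng(2)
N, d = 6, 4
params = tuple(0.5 * rng.normal(size=(d, d)) for _ in range(4))
x, tx = rng.normal(size=(2, N, d))

y, ty = block_jvp(x, tx, params)

# Finite-difference check of the whole block along the same input tangent.
eps = 1e-5
zero = np.zeros_like(x)
y_hi, _ = block_jvp(x + eps * tx, zero, params)
y_lo, _ = block_jvp(x - eps * tx, zero, params)
ty_fd = (y_hi - y_lo) / (2.0 * eps)
print(np.abs(ty - ty_fd).max())
```

Because each layer consumes and produces tensors of the same shapes for primal and tangent, any partitioning applied to the primal can be applied identically to the tangent, which is consistent with the isomorphism described above.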

5. Empirical Performance and Benchmarks

The kernel was validated as an enabler for large-scale, high-dimensional diffusion model distillation, specifically:

  • rCM/FlashAttention-2 JVP enabled sCM/rCM training on Cosmos-Predict2 up to 14B parameters (text-to-image, […] px) and Wan2.1 up to 14B parameters (text-to-video, up to […] frames; 5s HD video).
  • Achieved strong scaling with context parallelism and FSDP.
  • Supported large-batch, high-resolution, long-duration video inputs previously unattainable using prior JVP solutions.
  • Images are sampled in […] steps with up to […] acceleration over teacher models; videos are sampled in […] steps at scale.
  • No effective increase in memory usage relative to FlashAttention-2 standard forward; tangent computation is fully fused.
  • Ablations demonstrated that alternate (non-fused, non-robust) JVP approaches led to numerical instability at scale (see experiments in Appendix and Figure JVP-errors) (Zheng et al., 9 Oct 2025).

6. Impact on Consistency Modeling and Diffusion Training

  • The kernel is a prerequisite for practical scaling of sCM and rCM: without it, large-scale training is infeasible because of memory, numerical, and parallelism barriers.
  • The fused kernel allows the rCM objective, which interleaves consistency loss (JVP-requiring) and score-regularized loss (forward-only), to be optimized efficiently for very large models and high-dimensional tasks.
  • Enables throughput and hardware utilization parity with regular FlashAttention-2 (no additional tracing or activation storage overhead).
  • Used to distill state-of-the-art T2I and T2V models (Cosmos-Predict2, Wan2.1) in large clusters with minimal codebase and hardware changes, constituting, as stated, the first production-grade extension of continuous-time consistency to these domains (Zheng et al., 9 Oct 2025).

7. Mathematical and Implementation Summary

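The kernel implements, inside the attention layer's main loop, the critical JVP term in the sCM and rCM loss. The loss expression and its accompanying definition are not recoverable in this copy ([…]); the JVP in question is the total time derivative $\frac{d}{dt} f_\theta(x_t, t)$ given in Section 1.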

The dominant JVP term is evaluated by the FlashAttention-2 JVP kernel, streamed through all attention-heavy subnetworks during distillation (Zheng et al., 9 Oct 2025).

8. Novelty and Significance

The parallelism-compatible FlashAttention-2 JVP kernel is presented as the first efficient, scalable, numerically robust JVP implementation for attention applicable to production-scale T2I and T2V diffusion models. It supports full distributed training (FSDP/CP), fuses tangent propagation into the attention memory stream, and avoids both quadratic memory scaling and the efficiency bottlenecks of prior approaches. Its deployment enables sCM and rCM distillation of 10B+ models without architectural changes and without loss of performance. The kernel is a foundational technology for large-scale generative model distillation grounded in continuous-time consistency theory (Zheng et al., 9 Oct 2025).

References (1)
