FlashAttention-2 JVP Kernel
- FlashAttention-2 JVP Kernel is a GPU-optimized algorithm that fuses Jacobian-vector product computation into attention layers, enabling efficient scaling of diffusion models.
- It partitions queries, keys, values, and their tangents into blocks using a Triton-based streaming approach that maintains linear memory usage and high throughput.
- The kernel supports full distributed training (FSDP/CP) with robust numerical stability, enabling production-scale T2I and T2V diffusion tasks with models exceeding 10B parameters.
A FlashAttention-2 JVP kernel is a parallelism-compatible, GPU-optimized algorithm and implementation for computing the Jacobian-vector product (JVP) through attention layers. It is crucial for scaling continuous-time consistency models (sCM) and score-regularized continuous-time consistency models (rCM) to large-scale diffusion tasks involving models exceeding 10 billion parameters and high-dimensional inputs such as video. The kernel integrates JVP computation directly into the highly efficient FlashAttention-2 memory/compute pipeline, enabling training of T2I and T2V models at unprecedented parameter counts and sequence lengths while maintaining throughput, memory efficiency, and numerical stability (Zheng et al., 9 Oct 2025).
1. Role and Motivation in Diffusion Distillation
Large-scale distillation of diffusion models in continuous time, specifically sCM and its rCM extension, requires the instantaneous time derivative of the student or teacher network’s output along the ODE trajectory:

$$\frac{\mathrm{d}F_\theta(x_t, t)}{\mathrm{d}t} = \nabla_{x_t} F_\theta(x_t, t) \cdot \frac{\mathrm{d}x_t}{\mathrm{d}t} + \partial_t F_\theta(x_t, t).$$
This expression defines a matrix-free Jacobian-vector product (JVP) with respect to both spatial and temporal inputs. Materializing the full Jacobian is computationally infeasible for high-dimensional outputs (e.g., image/video sequences). Traditional autodiff utilities such as torch.func.jvp are not compatible with distributed data-parallel or attention-optimized architectures, making a bespoke attention JVP kernel necessary. Furthermore, modern diffusion backbones are attention-heavy, and video generation requires very long sequences, making direct JVP support within the attention kernel mission-critical (Zheng et al., 9 Oct 2025).
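To make the matrix-free form concrete, the sketch below (an illustration in plain PyTorch, not the paper's code) evaluates this total time derivative for a toy network by feeding the tangent pair $(\mathrm{d}x_t/\mathrm{d}t,\, 1)$ to torch.func.jvp; `f` is a hypothetical stand-in for the denoiser.

```python
# Illustrative only: matrix-free total time derivative of a toy network along an
# ODE trajectory via forward-mode autodiff (torch.func.jvp), with tangent (dx_t/dt, 1).
import torch
from torch.func import jvp

def f(x, t):
    # Hypothetical stand-in for the network F_theta(x_t, t); any differentiable map works.
    return torch.tanh(x * t)

x_t  = torch.randn(4, 8)        # current ODE state x_t
dxdt = torch.randn(4, 8)        # trajectory velocity dx_t/dt
t    = torch.tensor(0.5)

# jvp returns (F(x_t, t), dF/dt) without ever materializing the Jacobian.
_, dF_dt = jvp(f, (x_t, t), (dxdt, torch.ones_like(t)))
```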
2. Core Mathematical Formulation
In standard attention with $S = QK^\top/\sqrt{d}$, $P = \operatorname{softmax}(S)$ (applied row-wise), and $O = PV$: let input tangents $t_Q$, $t_K$, $t_V$ represent perturbations of $Q$, $K$, $V$ for JVP computation. The aim is to produce the output tangent $t_O$ alongside the primal $O$. Concretely, the matrix derivatives are

$$t_S = \frac{t_Q K^\top + Q\, t_K^\top}{\sqrt{d}}, \qquad t_P = P \odot \bigl(t_S - \operatorname{rowsum}(P \odot t_S)\bigr), \qquad t_O = t_P V + P\, t_V,$$

where $\operatorname{rowsum}$ sums over the key dimension. This structure enables the tangent computation to be fused into the streaming block-wise FlashAttention-2 forward pass (Zheng et al., 9 Oct 2025).
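For reference, the following unfused PyTorch sketch (an illustration of the formulas above, not the streaming Triton kernel) computes the primal and tangent outputs in one pass; its result can be cross-checked against torch.func.jvp applied to a plain softmax-attention function.

```python
# Unfused reference for the attention JVP formulas above (illustrative; the actual
# kernel streams these computations block-wise without materializing S or P).
import math
import torch

def attention_jvp_reference(Q, K, V, tQ, tK, tV):
    d = Q.shape[-1]
    S  = Q @ K.transpose(-1, -2) / math.sqrt(d)
    tS = (tQ @ K.transpose(-1, -2) + Q @ tK.transpose(-1, -2)) / math.sqrt(d)
    P  = torch.softmax(S, dim=-1)
    # Softmax JVP: tP_ij = P_ij * (tS_ij - sum_k P_ik * tS_ik)
    tP = P * (tS - (P * tS).sum(dim=-1, keepdim=True))
    O  = P @ V
    tO = tP @ V + P @ tV
    return O, tO
```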
3. Implementation and Kernel Design
Triton-based Kernel Construction
- The kernel is implemented in Triton, extending FlashAttention-2 such that primal ($O$) and tangent ($t_O$) outputs are produced simultaneously during block-wise streaming.
- Queries, keys, values, and their tangents are partitioned along sequence and batch/head axes.
- For each tile/block the kernel:
- Loads the local blocks of $Q$, $K$, $V$ and their tangents $t_Q$, $t_K$, $t_V$.
- Computes $S = QK^\top/\sqrt{d}$ and $t_S = (t_Q K^\top + Q\, t_K^\top)/\sqrt{d}$, applies numerically stable online softmax normalization to obtain $P$, and propagates the tangents through the normalization.
- Updates the accumulators for $O$ and $t_O$ with locally fused operations.
- The streaming approach preserves linear memory usage, avoiding quadratic scaling and enabling extremely long input sequences (Zheng et al., 9 Oct 2025).
Distributed Parallelism and Numerical Stability
- Fully compatible with Fully Sharded Data Parallel (FSDP) and Context Parallelism (CP): attention blocks internally manage the JVP, preserving layer-wise sharding and distributed slices.
- Tangent propagation pattern mirrors the block-wise partitioning of primals, ensuring all-gather/reduce operations remain valid.
- Accumulators and softmax normalization use full-precision arithmetic to mitigate numerical drift; time-embedding layers can optionally operate in FP32 for high-precision time derivatives.
- Provides a drop-in interface for self- and cross-attention, with both primal and tangent inputs/outputs, controlled via a withT option (Zheng et al., 9 Oct 2025).
Pseudocode Outline (as given):
```
for block_i in row_blocks:            # Stream over query blocks
    # Load Q, tQ for block_i
    for block_j in col_blocks:        # Stream over key/value blocks
        # Load K, tK, V, tV for block_j
        # Compute S = Q K^T, tS = tQ K^T + Q tK^T
        # Softmax normalization (numerically stable, online)
        # Update O block and tO block accumulators
    # Normalize and finalize O, tO for block_i
    # Write O, tO to output
```
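Expanding the outline, the following plain-PyTorch sketch (an illustration assuming a single head and 2-D tensors, not the Triton implementation) shows how the online-softmax statistics and the corresponding tangent accumulators can be carried across key/value blocks in full precision.

```python
# Illustrative streaming JVP accumulation in plain PyTorch (not the Triton kernel):
# a running row-max m, and fp32 accumulators for the softmax denominator l, its
# tangent tl, the output numerator acc_o, and its tangent acc_to.
import math
import torch

def streaming_attention_jvp(Q, K, V, tQ, tK, tV, block=64):
    n, d = Q.shape
    dev, scale = Q.device, 1.0 / math.sqrt(d)
    O, tO = torch.empty(n, d, device=dev), torch.empty(n, d, device=dev)
    for i in range(0, n, block):                       # stream over query blocks
        q, tq = Q[i:i+block].float(), tQ[i:i+block].float()
        bm = q.shape[0]
        m      = torch.full((bm, 1), -float("inf"), device=dev)
        l      = torch.zeros(bm, 1, device=dev)
        tl     = torch.zeros(bm, 1, device=dev)
        acc_o  = torch.zeros(bm, d, device=dev)
        acc_to = torch.zeros(bm, d, device=dev)
        for j in range(0, n, block):                   # stream over key/value blocks
            k, v   = K[j:j+block].float(), V[j:j+block].float()
            tk, tv = tK[j:j+block].float(), tV[j:j+block].float()
            s  = (q @ k.T) * scale
            ts = (tq @ k.T + q @ tk.T) * scale
            m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
            alpha = torch.exp(m - m_new)               # rescale previous accumulators
            p = torch.exp(s - m_new)                   # unnormalized probabilities
            l      = l * alpha + p.sum(dim=-1, keepdim=True)
            tl     = tl * alpha + (p * ts).sum(dim=-1, keepdim=True)
            acc_o  = acc_o * alpha + p @ v
            acc_to = acc_to * alpha + (p * ts) @ v + p @ tv
            m = m_new
        # Finalize: O = softmax(S) V, and its tangent via the quotient rule.
        O[i:i+block]  = acc_o / l
        tO[i:i+block] = acc_to / l - (acc_o / l) * (tl / l)
    return O, tO
```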
Table: FlashAttention-2 JVP Kernel Implementation Features
| Feature | Description |
|---|---|
| Kernel language | Triton |
| Parallelism support | Full compatibility with FSDP and CP |
| Numerical handling | Fused, full-precision softmax and tangent arithmetic |
| Interface | Self-/cross-attention, withT flag for primal/tangent signaling |
| Memory scaling | Linear (O(N)), no materialized Jacobians |
| Block design | Streaming, tiling, blockwise updates for all intermediates |
4. Integration in Neural Architectures and Distributed Training
- Within the network, all layers accept and propagate primal and tangent pairs; for most layers, torch.func.jvp suffices, but for attention the custom FlashAttention-2 JVP kernel is invoked (see the sketch after this list).
- All distributed primitives (sharding, all-gather/reduce) are maintained unchanged, as the tangent and primal streaming/partitioning are isomorphic.
- Enables context parallel scaling and FSDP for T2I and T2V models at scales infeasible with standard JVP approaches.
- Tangent error handling and normalization strategies mitigate the increased numerical instability of JVP, especially at low precision or extreme sequence length (Zheng et al., 9 Oct 2025).
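A minimal sketch of this propagation pattern follows, assuming a hypothetical fused attention JVP callable (named `attn_jvp` here purely for illustration) standing in for the Triton kernel; all other layers are pushed through torch.func.jvp as described above.

```python
# Sketch only: propagating (primal, tangent) pairs through a simplified transformer
# block. `attn_jvp` is a placeholder for the fused FlashAttention-2 JVP kernel;
# the MLP path uses torch.func.jvp, as stated above.
import torch
from torch.func import jvp

class BlockWithJVP(torch.nn.Module):
    def __init__(self, dim, attn_jvp):
        super().__init__()
        self.attn_jvp = attn_jvp                      # fused kernel (placeholder)
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(dim, 4 * dim), torch.nn.GELU(), torch.nn.Linear(4 * dim, dim)
        )

    def forward(self, x, tx):
        # Attention: primal and tangent outputs come from one fused pass.
        o, to = self.attn_jvp(x, x, x, tx, tx, tx)    # self-attention: Q = K = V = x
        x, tx = x + o, tx + to                        # residual on primal and tangent
        # Everything else goes through forward-mode autodiff.
        y, ty = jvp(self.mlp, (x,), (tx,))
        return x + y, tx + ty
```

In the described system the placeholder would be the Triton kernel invoked through its withT interface; the surrounding FSDP/CP wrappers see only ordinary tensor inputs and outputs, which is why the distributed primitives remain unchanged.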
5. Empirical Performance and Benchmarks
The kernel was validated as an enabler for large-scale, high-dimensional diffusion model distillation, specifically:
- The FlashAttention-2 JVP kernel enabled sCM/rCM training on Cosmos-Predict2 up to 14B parameters (text-to-image at high resolution) and Wan2.1 up to 14B parameters (text-to-video with long frame sequences; 5s HD video).
- Achieved strong scaling with context parallelism and FSDP.
- Supported large-batch, high-resolution, long-duration video inputs previously unattainable using prior JVP solutions.
- Enabled sampling images in only a few steps with substantial acceleration over teacher models; videos were likewise sampled in a small number of steps at scale.
- No effective increase in memory usage relative to the standard FlashAttention-2 forward pass; tangent computation is fully fused.
- Ablations demonstrated that alternate (non-fused, non-robust) JVP approaches led to numerical instability at scale (see the appendix experiments on JVP errors) (Zheng et al., 9 Oct 2025).
6. Impact on Consistency Modeling and Diffusion Training
- The kernel is a necessary prerequisite for practical scaling of sCM and rCM: without it, training at large scale is infeasible due to memory, numerical, and parallelism barriers.
- The fused kernel allows the rCM objective, which interleaves consistency loss (JVP-requiring) and score-regularized loss (forward-only), to be optimized efficiently for very large models and high-dimensional tasks.
- Enables throughput and hardware utilization parity with regular FlashAttention-2 (no additional tracing or activation storage overhead).
- Used to distill state-of-the-art T2I and T2V models (Cosmos-Predict2, Wan2.1) in large clusters with minimal codebase and hardware changes, constituting, as stated, the first production-grade extension of continuous-time consistency to these domains (Zheng et al., 9 Oct 2025).
7. Mathematical and Implementation Summary
The kernel implements, inside the attention layer's main loop, the critical JVP term in the sCM and rCM loss,

$$\frac{\mathrm{d}F_\theta(x_t, t)}{\mathrm{d}t} = \nabla_{x_t} F_\theta(x_t, t) \cdot \frac{\mathrm{d}x_t}{\mathrm{d}t} + \partial_t F_\theta(x_t, t),$$

with the attention contribution to this derivative obtained from the input tangents $t_Q$, $t_K$, $t_V$ via the fused formulas of Section 2.
The dominant JVP term is evaluated by the FlashAttention-2 JVP kernel, streamed through all attention-heavy subnetworks during distillation (Zheng et al., 9 Oct 2025).
8. Novelty and Significance
The parallelism-compatible FlashAttention-2 JVP kernel constitutes the first efficient, scalable, numerically robust JVP implementation for attention applicable to production-scale T2I and T2V diffusion models. It supports full distributed training (FSDP/CP), fuses tangent propagation into the attention memory stream, and avoids both quadratic memory scaling and the efficiency bottlenecks of prior approaches. Its deployment enables sCM and rCM distillation to 10B+ models, with no need for architectural changes or diminished performance. This kernel is a foundational technology advancing large-scale generative model distillation grounded in continuous-time consistency theory (Zheng et al., 9 Oct 2025).