Score-Regularized Continuous-Time Consistency Models (rCM)
- The paper introduces a dual-objective rCM framework that combines forward-KL for mode coverage and reverse-KL for mode seeking, balancing diversity and detail.
- It employs a custom FlashAttention-2 JVP kernel and parallelism strategies such as FSDP and Ulysses context parallelism, making distillation tractable for models exceeding 10 billion parameters and yielding efficient few-step samplers.
- Experiments report a 15× to 50× sampling speedup with maintained sample fidelity, with GenEval (T2I) and VBench (T2V) scores above DMD2 in the reported comparison.
Score-Regularized Continuous-Time Consistency Models (rCM) are a framework for scalable, high-quality distillation of large-scale diffusion models—especially in the application domains of text-to-image (T2I) and text-to-video (T2V) generation. rCM unifies forward-KL (mode-covering) and reverse-KL (mode-seeking) objectives into a practical distillation methodology that preserves sample diversity and high-fidelity detail, while enabling few-step sampling for models exceeding 10 billion parameters (Zheng et al., 9 Oct 2025).
1. Mathematical Foundations
Continuous-Time Consistency Models (sCM)
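Given a pre-trained diffusion "teacher" model defined by the probability-flow ODE (PF-ODE)

$$\frac{d\mathbf{x}_t}{dt} = \mathbf{v}_{\text{teacher}}(\mathbf{x}_t, t),$$

whose marginals match the noising process $\mathbf{x}_t = \alpha_t \mathbf{x}_0 + \sigma_t \boldsymbol{\epsilon}$ with $\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ and which maps data at $t = 0$ to terminal noise at $t = T$, a student consistency model $f_\theta$ is trained to predict the original $\mathbf{x}_0$ directly from $\mathbf{x}_t$ at arbitrary time $t$. Discrete-time consistency minimizes the discrepancy between predictions at adjacent intermediate states, while the continuous-time limit yields the sCM objective $\mathcal{L}_{\text{sCM}}(\theta)$, whose gradient is

$$\nabla_\theta\, \mathbb{E}_{t,\mathbf{x}_t}\!\left[ w(t)\, f_\theta(\mathbf{x}_t, t)^\top \frac{d f_{\theta^-}(\mathbf{x}_t, t)}{dt} \right],$$

with $\theta^- = \mathrm{sg}(\theta)$ a stop-gradient copy of the student parameters and $w(t) > 0$ a time weighting. The total derivative

$$\frac{d f_{\theta^-}(\mathbf{x}_t, t)}{dt} = \nabla_{\mathbf{x}_t} f_{\theta^-}(\mathbf{x}_t, t)\, \frac{d\mathbf{x}_t}{dt} + \partial_t f_{\theta^-}(\mathbf{x}_t, t)$$

is a Jacobian-vector product (JVP) of $f_{\theta^-}$ with tangent $(d\mathbf{x}_t/dt,\ 1)$, so efficient computation requires forward-mode JVPs rather than materialized Jacobians.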
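The JVP in the total derivative can be illustrated with forward-mode differentiation via dual numbers. The following pure-Python sketch uses a scalar toy problem; `teacher_velocity` and `student_f` are hypothetical stand-ins rather than the paper's networks. In PyTorch the same quantity would typically come from `torch.func.jvp(func, primals, tangents)`, which returns the output and its JVP together.

```python
import math
from dataclasses import dataclass


@dataclass
class Dual:
    """Dual number val + tan * eps with eps**2 = 0; `tan` carries the tangent."""
    val: float
    tan: float = 0.0

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.tan + other.tan)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val * other.val, self.val * other.tan + self.tan * other.val)

    __rmul__ = __mul__


def d_sin(x: Dual) -> Dual:
    return Dual(math.sin(x.val), math.cos(x.val) * x.tan)


def d_cos(x: Dual) -> Dual:
    return Dual(math.cos(x.val), -math.sin(x.val) * x.tan)


def d_tanh(x: Dual) -> Dual:
    y = math.tanh(x.val)
    return Dual(y, (1.0 - y * y) * x.tan)


def teacher_velocity(x: float, t: float) -> float:
    # Hypothetical PF-ODE velocity dx/dt = v(x, t) for a 1-D toy problem.
    return -0.5 * x + math.sin(t)


def student_f(x: Dual, t: Dual, w: float = 0.7) -> Dual:
    # Hypothetical student f_theta(x_t, t); any differentiable map works here.
    return d_cos(t) * x + d_sin(t) * d_tanh(w * x)


def total_derivative(x: float, t: float) -> float:
    """df/dt = (df/dx) * dx/dt + (partial f / partial t): a JVP with tangent (v, 1)."""
    out = student_f(Dual(x, teacher_velocity(x, t)), Dual(t, 1.0))
    return out.tan
```

One forward pass yields both $f$ and $df/dt$ without forming any Jacobian, which is what keeps the continuous-time objective affordable for large models.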
Limitations of sCM
sCM objectives are fundamentally forward-KL minimizing, i.e. they target $D_{\mathrm{KL}}(p_{\text{teacher}} \,\|\, p_{\text{student}})$, which enforces mode coverage but lets errors accumulate and smooths reconstructed details. Specifically:
- Error Accumulation: Small inaccuracies at earlier diffusion times are amplified by self-feedback, resulting in blurred fine structures and reduced temporal coherence in video synthesis.
- Mode-Covering: The forward-KL focus penalizes missed data modes, broadening density at the cost of sharpness and high-frequency features.
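The mode-covering versus mode-seeking behavior can be seen on a standard toy problem: fitting one Gaussian to a bimodal target under each KL direction. This sketch illustrates the two divergences only and is not an experiment from the paper; the target mixture, integration grid, and search ranges are arbitrary assumptions.

```python
import functools
import math

LOG_SQRT_2PI = 0.5 * math.log(2.0 * math.pi)
DX = 0.05
GRID = [-8.0 + DX * i for i in range(321)]  # integration grid on [-8, 8]


def log_normal(x, mu, sigma):
    """Log density of N(mu, sigma^2), kept in log space to avoid underflow."""
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - LOG_SQRT_2PI


def log_target(x):
    """Hypothetical bimodal 'teacher': 0.5 N(-2, 0.5^2) + 0.5 N(2, 0.5^2)."""
    a = math.log(0.5) + log_normal(x, -2.0, 0.5)
    b = math.log(0.5) + log_normal(x, 2.0, 0.5)
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))  # log-sum-exp


def kl(log_a, log_b):
    """Riemann-sum estimate of KL(a || b) = integral of a * log(a / b)."""
    total = 0.0
    for x in GRID:
        la = log_a(x)
        total += math.exp(la) * (la - log_b(x)) * DX
    return total


def fit_gaussian(direction):
    """Grid-search q = N(mu, sigma^2) minimizing forward or reverse KL."""
    best = None
    for i in range(25):
        mu = -3.0 + 0.25 * i
        for j in range(12):
            sigma = 0.25 + 0.25 * j
            log_q = functools.partial(log_normal, mu=mu, sigma=sigma)
            if direction == "forward":  # KL(p || q): penalizes missed modes
                score = kl(log_target, log_q)
            else:                       # KL(q || p): penalizes mass off the modes
                score = kl(log_q, log_target)
            if best is None or score < best[0]:
                best = (score, mu, sigma)
    return best[1], best[2]
```

Calling `fit_gaussian("forward")` returns a fit centered between the modes with a wide $\sigma$ that also covers the empty gap, whereas `fit_gaussian("reverse")` locks onto a single mode with a narrow $\sigma$: a small analogue of the blur-versus-collapse trade-off that rCM balances.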
2. Score Regularization and The rCM Loss
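rCM augments sCM with a mode-seeking reverse-KL regularization, inspired by Distribution-Matching Distillation (DMD). Defining $p^{\theta}_t$ and $p^{\text{teacher}}_t$ as the time-$t$ marginals for student and teacher, the regularizer is

$$\mathcal{L}_{\text{rKL}}(\theta) = \mathbb{E}_t\!\left[ \omega(t)\, D_{\mathrm{KL}}\!\left( p^{\theta}_t \,\big\|\, p^{\text{teacher}}_t \right) \right].$$

Its gradient requires the student's own score, which is approximated via a “fake-score” network $\mathbf{v}_\psi$ trained using flow matching on student samples, leading to the DMD loss

$$\mathcal{L}_{\text{DMD}}(\theta) = \mathbb{E}_{t,\boldsymbol{\epsilon}}\!\left[ \tfrac{1}{2} \left\| \hat{\mathbf{x}}_0 - \mathrm{sg}\!\left( \hat{\mathbf{x}}_0 - \omega(t) \big( \hat{\mathbf{x}}_0^{\psi}(\mathbf{x}_t, t) - \hat{\mathbf{x}}_0^{\text{teacher}}(\mathbf{x}_t, t) \big) \right) \right\|^2 \right],$$

where $\hat{\mathbf{x}}_0$ is a sample generated by the student $f_\theta$, $\mathbf{x}_t = \alpha_t \hat{\mathbf{x}}_0 + \sigma_t \boldsymbol{\epsilon}$ is its noised version, $\hat{\mathbf{x}}_0^{\psi}$ and $\hat{\mathbf{x}}_0^{\text{teacher}}$ are the clean-data predictions implied by the fake-score and teacher networks, and “sg” denotes stop-gradient. The surrogate's gradient with respect to $\hat{\mathbf{x}}_0$ is $\omega(t)(\hat{\mathbf{x}}_0^{\psi} - \hat{\mathbf{x}}_0^{\text{teacher}})$, which is proportional to the fake-minus-teacher score difference, i.e. the reverse-KL gradient.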
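The combined rCM objective is

$$\mathcal{L}_{\text{rCM}}(\theta) = \mathcal{L}_{\text{sCM}}(\theta) + \lambda\, \mathcal{L}_{\text{DMD}}(\theta),$$

where $\lambda > 0$ weights the score regularizer. The sCM term $\mathcal{L}_{\text{sCM}}$ preserves diversity, while the DMD term $\mathcal{L}_{\text{DMD}}$ restores sharp detail via mode-seeking correction.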
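The generator/critic interplay behind the DMD term can be sketched in one dimension, where all marginals are Gaussian and scores are analytic. This is a toy under stated assumptions, not rCM's training code: the student is a hypothetical shift generator $x_0 = \theta + \epsilon$, and the critic is fit by a closed-form mean estimate standing in for flow matching.

```python
import random
import statistics


def dmd_toy(teacher_mean=3.0, theta=-1.0, alpha=0.8, sigma=0.6,
            lr=0.5, steps=200, batch=256, seed=0):
    """1-D toy of DMD-style reverse-KL distillation with a generator/critic loop.

    Hypothetical setup: teacher data ~ N(teacher_mean, 1) and student samples
    x0 = theta + eps. Noising x_t = alpha*x0 + sigma*z gives Gaussian marginals
    with variance v = alpha**2 + sigma**2, whose score is -(x - mean) / v.
    """
    rng = random.Random(seed)
    v = alpha ** 2 + sigma ** 2

    def noised_student_samples():
        return [alpha * (theta + rng.gauss(0.0, 1.0)) + sigma * rng.gauss(0.0, 1.0)
                for _ in range(batch)]

    def teacher_score(x):
        return -(x - alpha * teacher_mean) / v  # known exactly in this toy

    for _ in range(steps):
        # Critic step (closed-form stand-in for flow matching): fit the
        # fake score to student rollouts by estimating their noised mean.
        fake_mean = statistics.fmean(noised_student_samples())

        def fake_score(x, m=fake_mean):
            return -(x - m) / v

        # Generator step: the surrogate 0.5 * (x_t - sg(x_t - g))**2 has gradient
        # g = fake_score - teacher_score w.r.t. x_t; chain rule: dx_t/dtheta = alpha.
        grad = statistics.fmean((fake_score(x) - teacher_score(x)) * alpha
                                for x in noised_student_samples())
        theta -= lr * grad
    return theta
```

Because the stop-gradient surrogate passes the fake-minus-teacher score difference straight through as the gradient, the generator drifts toward the teacher marginal; in this toy, `dmd_toy()` ends close to the teacher mean of 3.0.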
3. Implementation Techniques
FlashAttention-2 JVP Kernel
A custom Triton-based kernel fuses the JVP operation into the block-wise, tiled FlashAttention-2 forward pass. The implementation:
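- Accumulates both the standard attention output $O = PV$, with $P = \mathrm{softmax}(S)$ and $S = QK^\top/\sqrt{d}$, and the JVP output $\dot{O} = \dot{P}V + P\dot{V}$, where $\dot{P} = P \odot \big(\dot{S} - \mathrm{rowsum}(P \odot \dot{S})\big)$ and $\dot{S} = (\dot{Q}K^\top + Q\dot{K}^\top)/\sqrt{d}$ (dots denote tangents; the row sum is broadcast across each row).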
- Extends to both self- and cross-attention, preserving parallelism efficiency.
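Fusing the JVP into a tiled pass works because the tangent terms can be accumulated with the same online-softmax rescaling that tiled attention already uses. The following pure-Python sketch handles one query row and illustrates only the accumulation algebra; it is not the paper's Triton kernel, and the function name and block size are hypothetical.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def attention_jvp_tiled(q, K, V, dq, dK, dV, block=2):
    """One query row of softmax attention plus its JVP, computed block by block.

    A running max `m` rescales every accumulator, so no full row of scores is
    materialized. With P_j = softmax(s)_j and sdot_j the score tangent:
      O    = sum_j P_j V_j
      Odot = sum_j P_j sdot_j V_j + sum_j P_j dV_j - (sum_j P_j sdot_j) * O
    """
    scale = 1.0 / math.sqrt(len(q))
    dv = len(V[0])
    m = -math.inf          # running max of scores
    l = 0.0                # running softmax denominator
    acc_o = [0.0] * dv     # sum_j e_j V_j
    acc_sv = [0.0] * dv    # sum_j e_j sdot_j V_j
    acc_dv = [0.0] * dv    # sum_j e_j dV_j
    acc_s = 0.0            # sum_j e_j sdot_j
    for start in range(0, len(K), block):
        idx = range(start, min(start + block, len(K)))
        s = [dot(q, K[j]) * scale for j in idx]
        sdot = [(dot(dq, K[j]) + dot(q, dK[j])) * scale for j in idx]
        m_new = max(m, max(s))
        r = math.exp(m - m_new)  # rescale old accumulators to the new max
        l *= r
        acc_s *= r
        acc_o = [a * r for a in acc_o]
        acc_sv = [a * r for a in acc_sv]
        acc_dv = [a * r for a in acc_dv]
        for k, j in enumerate(idx):
            e = math.exp(s[k] - m_new)
            l += e
            acc_s += e * sdot[k]
            for c in range(dv):
                acc_o[c] += e * V[j][c]
                acc_sv[c] += e * sdot[k] * V[j][c]
                acc_dv[c] += e * dV[j][c]
        m = m_new
    out = [a / l for a in acc_o]
    mean_sdot = acc_s / l
    out_dot = [(acc_sv[c] + acc_dv[c]) / l - mean_sdot * out[c] for c in range(dv)]
    return out, out_dot
```

Three extra running accumulators suffice, all rescaled by the same factor as the primal ones, so the JVP adds arithmetic per tile but no extra passes over $K$ and $V$.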
Parallelism Strategies
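- FSDP: Each layer exposes a “JVP-aware” interface that computes forward (primal) and tangent outputs together over sharded parameters, without recomputing gradients.
- Ulysses Context-Parallelism: Query, key, and value (QKV) tensors are sharded along the sequence dimension and redistributed with all-to-all communication, so that each rank attends over the full sequence for a subset of attention heads; tangent vectors follow the same communication pattern.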
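The sequence-to-head resharding in Ulysses-style context parallelism can be simulated on a single process. This sketch uses nested Python lists in place of GPU tensors, and both function names are hypothetical; real implementations typically call a collective such as `torch.distributed.all_to_all_single`. Primal and tangent tensors are routed by the same function.

```python
def seq_to_head_all_to_all(shards, num_heads):
    """Simulated all-to-all: sequence-sharded -> head-sharded.

    shards[r][s][h] is rank r's local token s, head h. After the exchange,
    out[p][s] holds every token of the FULL sequence (in global order), but
    only rank p's heads p*hp ... (p+1)*hp - 1, where hp = num_heads // world.
    """
    world = len(shards)
    assert num_heads % world == 0
    hp = num_heads // world
    out = []
    for p in range(world):            # receiving rank
        full_seq = []
        for r in range(world):        # chunk sent by rank r to rank p
            for token in shards[r]:
                full_seq.append(token[p * hp:(p + 1) * hp])
        out.append(full_seq)
    return out


def head_to_seq_all_to_all(shards, seq_per_rank):
    """Inverse exchange: head-sharded -> sequence-sharded."""
    world = len(shards)
    out = []
    for r in range(world):
        local = []
        for s in range(r * seq_per_rank, (r + 1) * seq_per_rank):
            token = []
            for p in range(world):
                token.extend(shards[p][s])  # reassemble all heads of token s
            local.append(token)
        out.append(local)
    return out
```

Because the exchange is a fixed permutation of data, it is linear and commutes with taking a JVP, so tangents can be routed exactly like primals.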
Training and Inference Protocol
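Training alternates a generator step and a critic step:
- Generator: Samples a time $t$ and a noised input $\mathbf{x}_t$, queries the teacher velocity $\mathbf{v}_{\text{teacher}}(\mathbf{x}_t, t)$, computes the JVP $d f_{\theta^-}(\mathbf{x}_t, t)/dt$, and forms the sCM loss $\mathcal{L}_{\text{sCM}}$; after bootstrapping, it also forms the DMD loss $\mathcal{L}_{\text{DMD}}$ on student samples obtained by backward simulation of the student's few-step sampler.
- Critic: Updates the fake-score network $\mathbf{v}_\psi$ on student rollouts with the flow-matching loss $\mathcal{L}_{\text{FM}}(\psi) = \mathbb{E}_{t,\boldsymbol{\epsilon}} \big\| \mathbf{v}_\psi(\mathbf{x}_t, t) - (\dot{\alpha}_t \hat{\mathbf{x}}_0 + \dot{\sigma}_t \boldsymbol{\epsilon}) \big\|^2$, where $\hat{\mathbf{x}}_0$ is a student sample (with gradients stopped) and $\mathbf{x}_t = \alpha_t \hat{\mathbf{x}}_0 + \sigma_t \boldsymbol{\epsilon}$.
- Inference: Few-step sampling (1 to 4 steps) on a decreasing time grid $T = t_0 > t_1 > \dots > t_{N-1} > 0$, iterating $\hat{\mathbf{x}}_0 = f_\theta(\mathbf{x}_{t_i}, t_i)$ and, except at the last step, re-noising $\mathbf{x}_{t_{i+1}} = \alpha_{t_{i+1}} \hat{\mathbf{x}}_0 + \sigma_{t_{i+1}} \boldsymbol{\epsilon}$ with fresh noise $\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$.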
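The few-step sampling loop can be checked on a case where the ideal student is known in closed form: for 1-D Gaussian data, the PF-ODE map from $\mathbf{x}_t$ to $\mathbf{x}_0$ is a linear rescaling. This sketch assumes a TrigFlow-style schedule $\alpha_t = \cos t$, $\sigma_t = \sin t$ on $[0, \pi/2]$, and `exact_consistency_fn` is a hypothetical stand-in for a trained student $f_\theta$.

```python
import math
import random


def alpha_sigma(t):
    """Assumed TrigFlow-style schedule: x_t = cos(t) * x0 + sin(t) * eps."""
    return math.cos(t), math.sin(t)


def exact_consistency_fn(x_t, t, data_std):
    """Hypothetical ideal student for 1-D data ~ N(0, data_std**2).

    For Gaussian data the PF-ODE map x_t -> x_0 is linear: it rescales the
    marginal std sqrt(alpha^2 * data_std^2 + sigma^2) back to data_std.
    """
    a, s = alpha_sigma(t)
    return x_t * data_std / math.sqrt(a * a * data_std ** 2 + s * s)


def few_step_sample(times, data_std, rng):
    """Multistep consistency sampling over a decreasing, non-empty time grid."""
    x = rng.gauss(0.0, 1.0)  # assumes times[0] = pi/2, where x_t is pure noise
    for i, t in enumerate(times):
        x0_hat = exact_consistency_fn(x, t, data_std)  # jump to the data end
        if i == len(times) - 1:
            return x0_hat
        a, s = alpha_sigma(times[i + 1])
        x = a * x0_hat + s * rng.gauss(0.0, 1.0)  # re-noise to the next time


rng = random.Random(0)
samples = [few_step_sample([math.pi / 2, 1.2, 0.8, 0.4], 2.0, rng) for _ in range(20000)]
```

With an exact consistency function, each denoise/re-noise round keeps samples on the correct marginal, so four steps reproduce the data standard deviation of 2.0; with a learned student, sample quality depends on how well $f_\theta$ approximates this map.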
The stack employs PyTorch 2.0 (torch.func.jvp), custom Triton kernels, FSDP, Ulysses context-parallelism, BF16 mixed precision, and A100/H100 GPUs.
4. Experimental Results
Experiments span the Cosmos-Predict2 and Wan2.1 model families in T2I and T2V, with up to 14B parameters and 5-second video outputs.
| Model | Params | NFE | GenEval (T2I) | VBench (T2V) |
|---|---|---|---|---|
| Cosmos-Predict2 (teacher) | 14 B | 70 | 0.84 | — |
| + DMD2 | 14 B | 4 | 0.80 | 84.6 |
| + rCM | 14 B | 4 | 0.83 | 84.9 |
- Speedup: By cutting the number of function evaluations (NFE), rCM accelerates sampling by 15× to 50× relative to the teacher models.
- Diversity: Unlike DMD2, rCM maintains mode coverage and avoids mode collapse in T2V (e.g., varied object poses).
- Qualitative Quality: rCM sharply renders fine details (e.g., text in T2I, frame-to-frame coherence in T2V).
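As a rough consistency check, counting only network function evaluations for the 14B Cosmos-Predict2 row of the table and ignoring fixed costs such as text encoding and decoding, the ratio is

$$\frac{\mathrm{NFE}_{\text{teacher}}}{\mathrm{NFE}_{\text{rCM}}} = \frac{70}{4} = 17.5,$$

which lies within the 15× to 50× range; actual wall-clock speedups depend on the model, hardware, and implementation.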
5. Theoretical and Practical Implications
rCM leverages the complementary strengths of forward-KL (mode-covering) and reverse-KL (mode-seeking) training signals. This dual-objective strategy counters the error accumulation and smoothing artifacts observed with sCM alone, while preserving the ability to generate diverse samples in few steps. JVP-based self-feedback remains a computational bottleneck and a source of numerical instability, especially in low precision; proposed remedies include higher-order flow formulations and Hutchinson-trace approximations.
Implementation complexity increases due to custom kernel requirements and parallelism adaptation, yet rCM achieves state-of-the-art quality without requiring adversarial finetuning or extensive hyperparameter search. A residual performance gap is noted for single-step T2V sampling.
6. Extensions and Future Work
Potential avenues for improving rCM include:
- Adaptive regularization scheduling (a time-dependent weight $\lambda(t)$ on the DMD term).
- Integration with multi-step consistency trajectory approaches (e.g., sCTM).
- Incorporating adversarial objectives to further improve sample realism.
- Addressing the efficiency and error limitations intrinsic to JVP computation.
A plausible implication is that rCM’s unification of consistency and score distillation principles may generalize to other generative model distillation regimes, especially where balancing sample sharpness with output diversity remains critical (Zheng et al., 9 Oct 2025).