- The paper demonstrates that REPA, by aligning internal representations with those from a frozen SSL encoder, accelerates diffusion transformer training by over 17.5x.
- It shows that integrating REPA significantly improves generative quality, lowering FID scores and achieving state-of-the-art results on ImageNet.
- The approach provides clear guidelines on optimal layer selection and hyperparameter tuning, making it easy to incorporate into existing DiT/SiT pipelines.
This paper introduces REPresentation Alignment (REPA), a simple regularization technique designed to significantly accelerate the training and improve the performance of diffusion transformers like DiT [Peebles2022DiT] and SiT [ma2024sit]. The core idea is that learning high-quality internal representations is a major bottleneck for diffusion models, and leveraging powerful, pre-existing representations from self-supervised learning (SSL) models can alleviate this.
REPA works by adding a regularization term to the standard diffusion model training objective. This term encourages the internal hidden states of the diffusion transformer (processing noisy inputs) to align with the representations of the clean image produced by a strong, frozen, pretrained visual encoder (like DINOv2 [oquab2024dinov]).
How REPA Works (Implementation):
- Input: During training, the diffusion transformer takes a noisy latent input $\mathbf{z}_t = \alpha_t \mathbf{z}_\ast + \sigma_t \bm{\epsilon}$, where $\mathbf{z}_\ast$ is the VAE-encoded clean image latent.
- Target Representation: A pretrained SSL encoder $f$ (e.g., DINOv2) processes the clean image $\mathbf{x}_\ast$ to produce target patch representations $\mathbf{\Phi}_\ast = f(\mathbf{x}_\ast) \in \mathbb{R}^{N \times D}$. This encoder $f$ is kept frozen.
- Diffusion Model Representation: An intermediate hidden state $\mathbf{H}_t = f_\theta(\mathbf{z}_t)$ is extracted from the diffusion transformer at a specific layer (or block index).
- Projection: A trainable projection head $h_\phi$ (typically a simple MLP) maps the diffusion model's hidden state $\mathbf{H}_t$ to the same dimension as the target representation: $h_\phi(\mathbf{H}_t) \in \mathbb{R}^{N \times D}$.
- Alignment Loss: The REPA loss maximizes the patch-wise similarity between the projected diffusion representation and the target SSL representation:
$$
\mathcal{L}_{\text{REPA}}(\theta, \phi) = -\mathbb{E}_{\mathbf{z}_\ast, \bm{\epsilon}, t} \left[ \frac{1}{N}\sum_{n=1}^{N} \mathrm{sim}\!\left(\mathbf{\Phi}_\ast^{[n]}, h_{\phi}(\mathbf{H}_t^{[n]})\right) \right]
$$
Common choices for the similarity function $\mathrm{sim}$ are cosine similarity or the NT-Xent loss [chen2020simple].
- Combined Objective: The final training objective is a weighted sum of the original diffusion loss (e.g., the velocity-prediction loss $\mathcal{L}_{\text{velocity}}$ or the DDPM loss $\mathcal{L}_{\text{simple}}$) and the REPA loss:
$$
\mathcal{L} = \mathcal{L}_{\text{diffusion}} + \lambda \, \mathcal{L}_{\text{REPA}}
$$
where $\lambda$ is a hyperparameter controlling the strength of the alignment; values around 0.5-1.0 were found to be effective. A minimal code sketch of the alignment term and the combined objective is given below.
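To make the two terms concrete, here is a minimal PyTorch-style sketch, assuming cosine similarity as $\mathrm{sim}$. The names and shapes (`repa_loss`, `proj_head`, `phi_star`, `lambda_repa`, the feature dimensions) are illustrative and not taken from the official implementation, and the single linear projection stands in for the paper's small MLP head.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def repa_loss(h_t: torch.Tensor, phi_star: torch.Tensor, proj_head: nn.Module) -> torch.Tensor:
    """Negative mean patch-wise cosine similarity between projected diffusion
    hidden states h_t (B, N, D_model) and frozen SSL targets phi_star (B, N, D_ssl)."""
    pred = proj_head(h_t)                               # h_phi(H_t): (B, N, D_ssl)
    sim = F.cosine_similarity(pred, phi_star, dim=-1)   # (B, N)
    return -sim.mean()

# Toy shapes for illustration: 4 images, 256 patches, 1152-dim transformer
# hidden states, 768-dim SSL targets (both dimensions are assumptions here).
B, N, D_model, D_ssl = 4, 256, 1152, 768
proj_head = nn.Linear(D_model, D_ssl)   # stand-in for the paper's MLP head
h_t = torch.randn(B, N, D_model)        # intermediate hidden state f_theta(z_t)
phi_star = torch.randn(B, N, D_ssl)     # frozen encoder output f(x_star)

lambda_repa = 0.5                       # alignment weight lambda
diffusion_loss = torch.tensor(0.0)      # stand-in for the velocity/DDPM term
loss = diffusion_loss + lambda_repa * repa_loss(h_t, phi_star, proj_head)
```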
Key Findings and Practical Implications:
- Accelerated Convergence: REPA dramatically speeds up training. On ImageNet 256x256, SiT-XL/2 with REPA reached the FID of the vanilla model trained for 7M steps in fewer than 400K steps, a speedup of more than 17.5x.
- Improved Generation Quality: REPA also improves the final FID. Applied to SiT-XL/2, it lowered FID from 2.06 to 1.80 with classifier-free guidance (CFG), and reached a state-of-the-art FID of 1.42 when combined with guidance interval scheduling [kynkaanniemi2024applying].
- Scalability: The benefits of REPA increase with larger diffusion models and stronger pretrained SSL encoders. Aligning with better SSL representations (e.g., DINOv2 vs. MAE) leads to better generative performance.
- Partial Alignment Suffices: Applying the REPA loss only to the hidden states of early-to-mid layers (e.g., layer 8 out of 24/28) was found to be most effective. This suggests early layers learn robust semantics guided by REPA, while later layers refine details.
- Versatility: REPA improves performance across different diffusion transformer architectures (DiT, SiT), training objectives (DDPM, velocity prediction), model sizes, datasets (ImageNet, MS-COCO), resolutions (256x256, 512x512), and tasks (class-conditional, text-to-image).
- Choice of Encoder: While various SSL encoders work (DINOv2, MoCov3, CLIP, etc.), performance generally correlates with the encoder's representational quality (e.g., measured by linear probe accuracy). DINOv2 variants performed very well.
Implementation Considerations:
- Architecture: Designed for transformer-based diffusion models (DiT/SiT) operating in latent space via a pretrained VAE.
- External Encoder: Needs a frozen, pretrained SSL encoder. Ensure patch sizes and positional embeddings are handled correctly (interpolation might be needed if patch counts differ).
- Projection Head: A simple MLP (e.g., 3 layers with SiLU activation) is sufficient.
- Computational Overhead: Adds the cost of a forward pass through the frozen SSL encoder and the projection head during training. The SSL encoder features can potentially be precomputed for the dataset.
- Hyperparameters: Key parameters are the alignment weight $\lambda$ (e.g., 0.5) and the layer depth at which the loss is applied (e.g., layer 8); see the sketch below for both pieces.
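As a rough illustration of the projection head and the layer-depth choice, the sketch below defines a 3-layer SiLU MLP and captures a chosen block's hidden state with a forward hook. The hidden width, the `ProjectionHead` name, and the `model.blocks` attribute are assumptions about a typical DiT/SiT-style codebase, not the paper's official code.

```python
import torch
import torch.nn as nn

class ProjectionHead(nn.Module):
    """3-layer MLP with SiLU activations mapping diffusion hidden states
    (d_model) to the SSL encoder's feature dimension (d_ssl)."""
    def __init__(self, d_model: int, d_ssl: int, d_hidden: int = 2048):  # hidden width is an assumption
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_hidden), nn.SiLU(),
            nn.Linear(d_hidden, d_hidden), nn.SiLU(),
            nn.Linear(d_hidden, d_ssl),
        )

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        return self.net(h)

# Capturing the hidden state after block `align_depth` (e.g., 8) with a forward
# hook; `model.blocks` assumes a DiT/SiT-style ModuleList of transformer blocks.
captured = {}

def _hook(module, inputs, output):
    captured["h_t"] = output            # (B, N, d_model) hidden state H_t

align_depth = 8
# handle = model.blocks[align_depth - 1].register_forward_hook(_hook)
# ... run the diffusion forward pass; captured["h_t"] then feeds the REPA loss.
```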
In summary, REPA provides a practical and highly effective method to improve the training efficiency and generative quality of diffusion transformers by explicitly aligning their internal representations with strong, external visual features learned via self-supervision. Its simplicity makes it relatively easy to integrate into existing DiT/SiT training pipelines.