
Canonical Latent Representations in Conditional Diffusion Models

Published 11 Jun 2025 in cs.LG and cs.CV | (2506.09955v1)

Abstract: Conditional diffusion models (CDMs) have shown impressive performance across a range of generative tasks. Their ability to model the full data distribution has opened new avenues for analysis-by-synthesis in downstream discriminative learning. However, this same modeling capacity causes CDMs to entangle the class-defining features with irrelevant context, posing challenges to extracting robust and interpretable representations. To this end, we identify Canonical LAtent Representations (CLAReps), latent codes whose internal CDM features preserve essential categorical information while discarding non-discriminative signals. When decoded, CLAReps produce representative samples for each class, offering an interpretable and compact summary of the core class semantics with minimal irrelevant details. Exploiting CLAReps, we develop a novel diffusion-based feature-distillation paradigm, CaDistill. While the student has full access to the training set, the CDM as teacher transfers core class knowledge only via CLAReps, which amounts to merely 10 % of the training data in size. After training, the student achieves strong adversarial robustness and generalization ability, focusing more on the class signals instead of spurious background cues. Our findings suggest that CDMs can serve not just as image generators but also as compact, interpretable teachers that can drive robust representation learning.

Summary

  • The paper introduces Canonical Latent Representations (CLAReps) that extract essential categorical information from conditional diffusion models for robust representation learning.
  • CLARID, the proposed method, projects latent codes to remove extraneous directions, yielding canonical samples whose features show higher class separability, as reflected in improved NMI scores.
  • CaDistill uses CLAReps from only 10% of the training data to enhance adversarial robustness and generalization of downstream models.

This paper introduces a method to extract core categorical information from Conditional Diffusion Models (CDMs) and leverage it for robust representation learning. The authors identify that while CDMs excel at generating diverse and high-fidelity images, they often entangle class-defining features with irrelevant contextual information. This entanglement poses challenges for extracting robust and interpretable representations for downstream tasks.

The core contributions are:

  1. Canonical Latent Representations (CLAReps): The paper defines CLAReps as latent codes within CDMs whose internal features (Canonical Features) preserve essential categorical information while discarding non-discriminative signals. When decoded, CLAReps produce "Canonical Samples," which are representative images for each class, highlighting core semantics with minimal irrelevant details.
  2. Canonical Latent Representation Identifier (CLARID): This is the proposed method to identify CLAReps. The process involves:
    • Inverting a given training sample $\bm{x}_0$ (belonging to class $\bm{c}$) to a latent code $\bm{x}_{t_e}$ at a specific diffusion timestep $t_e$.
    • Identifying "extraneous directions" in this latent space. These directions are the right singular vectors of the Jacobian of the CDM's feature extractor $\bm{f}_{\theta, t_e}(\bm{x}_{t_e})$. Moving along these directions changes visual appearance but preserves class identity.
    • Obtaining the CLARep $\tilde{\bm{x}}_{t_e}$ by projecting $\bm{x}_{t_e}$ onto the subspace orthogonal to these $k$ extraneous directions: $\tilde{\bm{x}}_{t_e} = (\mathbf{I} - \mathbf{V}_{k} \mathbf{V}_{k}^{\mathsf{T}})\,\bm{x}_{t_e}$, where $\mathbf{V}_k$ contains the top $k$ right singular vectors.
    • Decoding $\tilde{\bm{x}}_{t_e}$ (conditioned on $\bm{c}$) to generate a Canonical Sample $\tilde{\bm{x}}_0$.
    • Extracting Canonical Features from the CDM using $\tilde{\bm{x}}_{t_e}$ at a feature extraction timestep $t_r$.

    The paper provides strategies for selecting the projection timestep $t_e$ and the number of extraneous directions to remove, $k$ (a minimal code sketch of the projection and elbow selection appears after the contributions list):

      • $t_e$ is found by identifying the saturation point of classification accuracy on samples generated via a two-stage process (unconditional generation until $t_e$, then conditional generation).
      • $k$ is chosen adaptively for each sample as the elbow point of the explained variance ratio (EVR) of the singular values of the Jacobian.

    Experiments show that Canonical Features extracted using CLARID-identified parameters exhibit higher Normalized Mutual Information (NMI) with ground-truth labels when clustered, indicating better class separability and compactness compared to features from original samples.

  3. CaDistill (Canonical Distillation): A novel diffusion-based feature distillation paradigm that uses CLAReps.

    • The student network is trained on the full training set.
    • The CDM teacher transfers essential class knowledge using only CLAReps (amounting to ~10% of the training data size).
    • The CaDistill loss function comprises several components:

      • $\mathcal{L}_{cls}$: Standard cross-entropy loss for the student on ground-truth labels.
      • $\mathcal{L}_{align}$: Encourages student features of training images to be close to student features of same-class Canonical Samples and distant from different-class ones.

        $\mathcal{L}_{align} = -\frac{1}{b} \sum_{i=1}^{b} \frac{1}{|P_i|} \sum_{j \in P_i} \log \frac{ \exp(\bm{z}_i \cdot \tilde{\bm{z}}_j / \tau) }{ \sum_{k=1}^{B} \exp(\bm{z}_i \cdot \tilde{\bm{z}}_k / \tau) }$

      • $\mathcal{L}_{cano}$: Encourages student features of same-class Canonical Samples to cluster together and separate from different-class ones.

        $\mathcal{L}_{cano} = -\frac{1}{b} \sum_{i=1}^{b} \frac{1}{|P_i| - 1} \sum_{j \in P_i,\, j \neq i} \log \frac{ \exp(\tilde{\bm{z}}_i \cdot \tilde{\bm{z}}_j / \tau) }{ \sum_{k \neq i} \exp(\tilde{\bm{z}}_i \cdot \tilde{\bm{z}}_k / \tau) }$

      • $\mathcal{L}_{dist}$: Aligns the student's representations of both training images and Canonical Samples with the Canonical Features from the CDM using Centered Kernel Alignment (CKA).

        $\mathcal{L}_{dist} = \lambda_{cka} \log (1 - \text{CKA}(\bm{Z}, \mathcal{A})) + (1 - \lambda_{cka}) \log (1 - \text{CKA}(\tilde{\bm{Z}}, \mathcal{A}))$

    • The final loss is a weighted sum: $\mathcal{L}_{CaDistill} = \mathcal{L}_{cls} + \lambda_{cs} \left( \lambda_{cf} \mathcal{L}_{align} + (1 - \lambda_{cf}) \mathcal{L}_{cano} \right) + \lambda_{dist} \mathcal{L}_{dist}$. A minimal code sketch of these loss components is given below.
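
To make the CLARID projection concrete, here is a minimal PyTorch sketch of the core step: compute the Jacobian of the CDM's feature extractor at the latent, take its right singular vectors, pick $k$ at the EVR elbow, and project the latent onto the orthogonal complement. The function and argument names (`clarid_project`, `feature_extractor`, `max_dirs`) and the largest-drop elbow heuristic are illustrative assumptions, not the authors' implementation; a practical implementation would avoid materializing the full Jacobian, which the paper notes is expensive.

```python
import torch
from torch.autograd.functional import jacobian

def clarid_project(x_te, feature_extractor, max_dirs=32):
    """Project a latent onto the subspace orthogonal to the top-k
    'extraneous' right singular vectors of the Jacobian of the CDM's
    feature extractor at timestep t_e (illustrative sketch).

    x_te: flattened latent code at timestep t_e, shape (D,).
    feature_extractor: callable mapping a (D,) latent to an (F,) feature
        vector; stands in for f_{theta, t_e} (hypothetical interface).
    max_dirs: number of leading singular directions considered when
        locating the elbow.
    """
    # Full Jacobian of the features w.r.t. the latent, shape (F, D).
    # Brute-force autodiff; only practical for small D.
    J = jacobian(feature_extractor, x_te)

    # Right singular vectors of J span latent directions ordered by how
    # strongly they perturb the features; the leading ones are treated
    # as the extraneous directions to remove.
    _, S, Vh = torch.linalg.svd(J, full_matrices=False)

    # Pick k at the elbow of the explained variance ratio (EVR) of the
    # singular values (simple largest-drop heuristic; the paper's exact
    # elbow criterion may differ).
    evr = (S ** 2) / (S ** 2).sum()
    evr = evr[:max_dirs]
    k = int(torch.argmax(evr[:-1] - evr[1:])) + 1

    # CLARep: x_tilde = (I - V_k V_k^T) x_te.
    V_k = Vh[:k].T                      # (D, k)
    x_tilde = x_te - V_k @ (V_k.T @ x_te)
    return x_tilde, k
```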
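
The CaDistill objective can likewise be sketched in PyTorch. The following is a minimal illustration under simplifying assumptions: the feature L2-normalization, the per-batch pairing of $\bm{Z}$, $\tilde{\bm{Z}}$, and the teacher features $\mathcal{A}$, the function names, and the default $\lambda$ values are placeholders rather than the paper's exact recipe.

```python
import torch
import torch.nn.functional as F

def align_loss(z, z_tilde, y, y_tilde, tau=0.1):
    """L_align: pull student features of training images (z) toward student
    features of same-class Canonical Samples (z_tilde), push away others."""
    z, z_tilde = F.normalize(z, dim=1), F.normalize(z_tilde, dim=1)
    logits = z @ z_tilde.T / tau                        # (b, B)
    pos = (y[:, None] == y_tilde[None, :]).float()      # positive set P_i
    log_prob = logits - torch.logsumexp(logits, dim=1, keepdim=True)
    return -(log_prob * pos).sum(1).div(pos.sum(1).clamp(min=1)).mean()

def cano_loss(z_tilde, y_tilde, tau=0.1):
    """L_cano: supervised-contrastive clustering of Canonical Sample features;
    the self term k = i is excluded from the denominator."""
    z_tilde = F.normalize(z_tilde, dim=1)
    logits = z_tilde @ z_tilde.T / tau
    eye = torch.eye(len(z_tilde), dtype=torch.bool, device=z_tilde.device)
    pos = ((y_tilde[:, None] == y_tilde[None, :]) & ~eye).float()
    denom = torch.exp(logits).masked_fill(eye, 0.0).sum(1, keepdim=True)
    log_prob = logits - torch.log(denom)
    return -(log_prob * pos).sum(1).div(pos.sum(1).clamp(min=1)).mean()

def linear_cka(X, Y):
    """Linear Centered Kernel Alignment between feature matrices with matching
    row counts (how rows of Z, Z_tilde, and A are paired follows the paper)."""
    X, Y = X - X.mean(0, keepdim=True), Y - Y.mean(0, keepdim=True)
    return (X.T @ Y).norm() ** 2 / ((X.T @ X).norm() * (Y.T @ Y).norm())

def cadistill_loss(student_logits, y, z, z_tilde, y_tilde, A,
                   lam_cka=0.5, lam_cs=1.0, lam_cf=0.5, lam_dist=1.0):
    """Weighted sum L_cls + lam_cs * (lam_cf * L_align + (1-lam_cf) * L_cano)
    + lam_dist * L_dist; the lambda defaults here are placeholders."""
    l_cls = F.cross_entropy(student_logits, y)
    l_align = align_loss(z, z_tilde, y, y_tilde)
    l_cano = cano_loss(z_tilde, y_tilde)
    # Clamp keeps log well-defined if CKA reaches 1 numerically.
    l_dist = (lam_cka * torch.log((1 - linear_cka(z, A)).clamp(min=1e-6))
              + (1 - lam_cka) * torch.log((1 - linear_cka(z_tilde, A)).clamp(min=1e-6)))
    return l_cls + lam_cs * (lam_cf * l_align + (1 - lam_cf) * l_cano) + lam_dist * l_dist
```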

Experimental Validation:

  • Toy Example: CLARID successfully recovers a low-dimensional class manifold in a synthetic dataset.
  • Qualitative Results: On ImageNet with DiT and Stable Diffusion models, CLARID-generated Canonical Samples effectively summarize core class semantics, removing irrelevant context compared to original images or those generated with Classifier-Free Guidance (CFG). CLARID is shown to be generalizable to text-conditioned models and different samplers.
  • Quantitative Results (CaDistill):
    • Experiments on CIFAR-10 (ResNet-18 student) and ImageNet (ResNet-50 student) show that CaDistill significantly improves adversarial robustness (against PGD, CW, APGD-DLR, APGD-CE attacks) and generalization ability (on CIFAR10-C, ImageNet-C, ImageNet-A, ImageNet-ReaL) compared to vanilla training and other diffusion-based methods like DiffAug and a CKA-based DMDistill baseline.
    • Notably, CaDistill achieves these improvements while the CDM teacher only utilizes CLAReps equivalent to 10% of the training data size for knowledge transfer.
    • On the Backgrounds Challenge, models trained with CaDistill show increased reliance on foreground objects and reduced dependence on spurious background cues.
  • Ablation Studies: Validate the necessity of each loss component in CaDistill, the sufficiency of using only 10% CLAReps, and the choice of hyperparameters.

Limitations:

  • CLARID can sometimes select a suboptimal projection timestep $t_e$ or a suboptimal number of extraneous directions $k$.
  • Calculating singular vectors of the CDM's Jacobian is computationally intensive.
  • The effectiveness on very large-scale problems (e.g., ImageNet22K) is yet to be explored.

In conclusion, the paper demonstrates that by identifying and utilizing CLAReps, CDMs can serve not just as image generators but also as compact and interpretable teachers. The proposed CLARID method effectively extracts core class semantics, and the CaDistill framework leverages this to enhance the robustness and generalization of downstream discriminative models.
