- The paper introduces Canonical Latent Representations (CLAReps) that extract essential categorical information from conditional diffusion models for robust representation learning.
- CLARID, the proposed method, projects latent codes to remove extraneous directions, yielding canonical samples with higher class separability as shown by improved NMI scores.
- CaDistill uses CLAReps from only 10% of the training data to enhance adversarial robustness and generalization of downstream models.
This paper introduces a method to extract core categorical information from Conditional Diffusion Models (CDMs) and leverage it for robust representation learning. The authors identify that while CDMs excel at generating diverse and high-fidelity images, they often entangle class-defining features with irrelevant contextual information. This entanglement poses challenges for extracting robust and interpretable representations for downstream tasks.
The core contributions are:
- Canonical Latent Representations (CLAReps): The paper defines CLAReps as latent codes within CDMs whose internal features (Canonical Features) preserve essential categorical information while discarding non-discriminative signals. When decoded, CLAReps produce "Canonical Samples," which are representative images for each class, highlighting core semantics with minimal irrelevant details.
- Canonical Latent Representation Identifier (CLARID): This is the proposed method to identify CLAReps. The process involves:
- Inverting a given training sample $x_0$ (belonging to class $c$) to a latent code $x_{t_e}$ at a specific diffusion timestep $t_e$.
- Identifying "extraneous directions" in this latent space. These directions are the right singular vectors of the Jacobian of the CDM's feature extractor $f_{\theta,t_e}(x_{t_e})$. Moving along these directions changes visual appearance but preserves class identity.
- Obtaining the CLARep $\tilde{x}_{t_e}$ by projecting $x_{t_e}$ onto the subspace orthogonal to these $k$ extraneous directions:
$\tilde{x}_{t_e} = (I - V_k V_k^T)\,x_{t_e}$, where $V_k$ contains the top $k$ right singular vectors.
- Decoding $\tilde{x}_{t_e}$ (conditioned on $c$) to generate a Canonical Sample $\tilde{x}_0$.
- Extracting Canonical Features from the CDM using $\tilde{x}_{t_e}$ at a feature extraction timestep $t_r$.
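The projection step above can be sketched in a few lines of numpy. This is an illustrative sketch, not the paper's implementation: `jacobian` stands in for the Jacobian of the CDM's feature extractor evaluated at $x_{t_e}$, and the function name is hypothetical.

```python
import numpy as np

def clarep_projection(x_te, jacobian, k):
    """Project a latent code onto the subspace orthogonal to the top-k
    right singular vectors ("extraneous directions") of the feature
    extractor's Jacobian, mirroring x_tilde = (I - Vk Vk^T) x_te.

    Sketch only: `jacobian` is a placeholder for the Jacobian of
    f_{theta,t_e} at x_te, flattened to a 2-D matrix.
    """
    # SVD J = U S Vt; the rows of Vt are the right singular vectors.
    _, _, Vt = np.linalg.svd(jacobian, full_matrices=False)
    Vk = Vt[:k]  # (k, d): top-k extraneous directions as rows
    # Remove the components of x_te along the extraneous directions.
    x_tilde = x_te - Vk.T @ (Vk @ x_te)
    return x_tilde, Vk
```

Because the rows of `Vk` are orthonormal, the result lies exactly in the orthogonal complement, and applying the projection twice changes nothing.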
The paper provides strategies for selecting the optimal $t_e$ (projection timestep) and $k$ (number of extraneous directions to remove).
* $t_e$ is found by identifying the saturation point of classification accuracy on samples generated via a two-stage process (unconditional generation until $t_e$, then conditional generation).
* $k$ is adaptively chosen for each sample as the elbow point of the explained variance ratio (EVR) of the singular values of the Jacobian.
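The adaptive choice of $k$ can be sketched as follows. The paper only states that $k$ is taken at the elbow of the EVR curve, so the specific elbow detector used here (maximum perpendicular distance to the chord between the curve's endpoints, applied to the cumulative EVR) is an assumption.

```python
import numpy as np

def elbow_k(singular_values):
    """Pick k at the elbow of the explained-variance-ratio curve of the
    Jacobian's singular values.

    Assumption: the elbow is detected as the point on the cumulative EVR
    curve farthest from the straight line joining its endpoints; the
    paper does not specify the exact elbow heuristic.
    """
    s2 = np.asarray(singular_values, dtype=float) ** 2
    evr = np.cumsum(s2) / s2.sum()          # cumulative explained variance
    n = len(evr)
    idx = np.arange(n, dtype=float)
    # Unit direction of the chord from the first to the last point.
    p0 = np.array([0.0, evr[0]])
    d = np.array([n - 1.0, evr[-1]]) - p0
    d /= np.linalg.norm(d)
    pts = np.stack([idx, evr], axis=1) - p0
    # Perpendicular distance of each curve point to the chord.
    dist = np.abs(pts[:, 0] * d[1] - pts[:, 1] * d[0])
    return int(np.argmax(dist)) + 1          # k counts directions, so 1-indexed
```

For a spectrum with a sharp drop after the third singular value, this heuristic returns $k = 3$.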
Experiments show that Canonical Features extracted using CLARID-identified parameters exhibit higher Normalized Mutual Information (NMI) with ground truth labels when clustered, indicating better class separability and compactness compared to features from original samples.
- CaDistill (Canonical Distillation): A novel diffusion-based feature distillation paradigm that uses CLAReps.
* The final loss is a weighted sum: $\mathcal{L}_{\text{CaDistill}} = \mathcal{L}_{\text{cls}} + \lambda_{cs}\left(\lambda_{cf}\mathcal{L}_{\text{align}} + (1-\lambda_{cf})\mathcal{L}_{\text{cano}}\right) + \lambda_{dist}\mathcal{L}_{\text{dist}}$.
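The weighted sum can be written out directly; the default weight values in this sketch are placeholders, not the paper's settings, and the individual loss terms are assumed to be precomputed scalars.

```python
def cadistill_loss(l_cls, l_align, l_cano, l_dist,
                   lam_cs=1.0, lam_cf=0.5, lam_dist=1.0):
    """Combine the CaDistill loss terms as the weighted sum
    L = L_cls + lam_cs * (lam_cf * L_align + (1 - lam_cf) * L_cano)
        + lam_dist * L_dist.

    Default weights are illustrative placeholders, not the paper's
    hyperparameter values.
    """
    return (l_cls
            + lam_cs * (lam_cf * l_align + (1.0 - lam_cf) * l_cano)
            + lam_dist * l_dist)
```

Note that $\lambda_{cf}$ interpolates between the alignment and canonical terms, while $\lambda_{cs}$ and $\lambda_{dist}$ scale their overall contributions against the classification loss.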
Experimental Validation:
- Toy Example: CLARID successfully recovers a low-dimensional class manifold in a synthetic dataset.
- Qualitative Results: On ImageNet with DiT and Stable Diffusion models, CLARID-generated Canonical Samples effectively summarize core class semantics, removing irrelevant context compared to original images or those generated with Classifier-Free Guidance (CFG). CLARID is shown to be generalizable to text-conditioned models and different samplers.
- Quantitative Results (CaDistill):
- Experiments on CIFAR-10 (ResNet-18 student) and ImageNet (ResNet-50 student) show that CaDistill significantly improves adversarial robustness (against PGD, CW, APGD-DLR, APGD-CE attacks) and generalization ability (on CIFAR10-C, ImageNet-C, ImageNet-A, ImageNet-ReaL) compared to vanilla training and other diffusion-based methods like DiffAug and a CKA-based DMDistill baseline.
- Notably, CaDistill achieves these improvements while the CDM teacher transfers knowledge through CLAReps amounting to only 10% of the training-set size.
- On the Backgrounds Challenge, models trained with CaDistill show increased reliance on foreground objects and reduced dependence on spurious background cues.
- Ablation Studies: Validate the necessity of each loss component in CaDistill, the sufficiency of using only 10% CLAReps, and the choice of hyperparameters.
Limitations:
- CLARID can sometimes select a suboptimal projection timestep ($t_e$) or number of removed extraneous directions ($k$).
- Calculating singular vectors of the CDM's Jacobian is computationally intensive.
- The effectiveness on very large-scale problems (e.g., ImageNet22K) is yet to be explored.
In conclusion, the paper demonstrates that by identifying and utilizing CLAReps, CDMs can serve not just as image generators but also as compact and interpretable teachers. The proposed CLARID method effectively extracts core class semantics, and the CaDistill framework leverages this to enhance the robustness and generalization of downstream discriminative models.