- The paper presents a probabilistic framework that generalizes contrastive learning by tilting the product of marginal distributions and minimizing divergences between true and learned conditionals.
- It compares cosine similarity and L2-distance tilting methods, showing how different loss functions affect the matching of conditional means and covariances in Gaussian models.
- The study demonstrates practical applications in crossmodal retrieval, classification, and Lagrangian data assimilation, validated by both theoretical analysis and numerical experiments.
This paper, "A Mathematical Perspective On Contrastive Learning" (2505.24134), provides a mathematical framework for understanding bimodal contrastive learning, interpreting it as a method for learning a joint probability distribution over two modalities by tilting the product of their marginal distributions. This probabilistic perspective allows the authors to generalize existing contrastive learning methods and analyze their properties, particularly in the tractable setting of multivariate Gaussian distributions. The paper emphasizes practical implications, analyzing how different formulations impact downstream tasks like retrieval and classification.
The core setup involves two data modalities, u and v, drawn in pairs from a joint distribution μ(u,v). Contrastive learning aims to find encoders g_u and g_v that map u and v into a common latent space R^{n_e}, typically with n_e much smaller than the original data dimensions. Standard approaches, like CLIP, normalize the encoder outputs to the unit sphere, yielding normalized encoders E_u, E_v, and use the cosine similarity ⟨E_u(u), E_v(v)⟩ as an alignment metric. The training objective, often a form of InfoNCE or cross-entropy loss, encourages high similarity for paired data and low similarity for unpaired data (negative samples from shuffled batches). The paper shows that the population limit of the standard contrastive loss minimizes the sum of KL divergences between the true conditional distributions μ_{u|v}, μ_{v|u} and the learned conditional distributions ν_{u|v}, ν_{v|u} derived from a parameterized joint distribution ν defined by an exponential tilting of the product of marginals μ_u ⊗ μ_v.
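As a concrete reference point, here is a minimal sketch of this standard batch-level objective, assuming PyTorch and placeholder encoders that already produce embeddings zu, zv; the temperature value and batch handling are illustrative, not taken from the paper.

```python
import torch
import torch.nn.functional as F

def clip_loss(zu, zv, tau=0.07):
    """Symmetric InfoNCE / cross-entropy loss for a batch of paired embeddings
    zu, zv of shape (B, n_e), using in-batch shuffles as negative samples."""
    zu, zv = F.normalize(zu, dim=-1), F.normalize(zv, dim=-1)  # project to the unit sphere
    logits = zu @ zv.T / tau             # pairwise cosine similarities, scaled by 1/τ
    targets = torch.arange(zu.shape[0])  # the i-th u is paired with the i-th v
    # The row-wise term (softmax over v for each u) targets the v|u conditional,
    # the column-wise term the u|v conditional; in the population limit their sum
    # corresponds to KL(μ_{u|v} || ν_{u|v}) + KL(μ_{v|u} || ν_{v|u}), as stated above.
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.T, targets))
```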
The probabilistic framework leads to two main classes of generalizations:
- Generalized Probabilistic Loss Functions: Instead of minimizing the sum of KL divergences between conditionals, one could:
  - Minimize a weighted sum of divergences for the conditionals, λ_u D(μ_{u|v} || ν_{u|v}) + λ_v D(μ_{v|u} || ν_{v|u}). Setting λ_v = 0 or λ_u = 0 focuses the learning on matching only one conditional, which is relevant for asymmetric tasks like classification.
  - Minimize the divergence between the true joint distribution μ and the learned joint distribution ν, D(μ || ν). The paper shows that for the KL divergence this leads to an objective that is computationally advantageous, since it requires only one batch from the joint and one batch from the product of marginals, compared to the per-sample negative batches needed for the conditional loss (see the Monte Carlo sketch after this list). The KL joint loss is also shown to provide an upper bound for the KL conditional loss.
  - Use alternative divergences or metrics, like Maximum Mean Discrepancy (MMD), which are shown to be actionable with empirical data.
- Generalized Tilting: The learned joint distribution ν is defined by a density ρ(u,v;θ) relative to μ_u ⊗ μ_v. The standard approach uses ρ(u,v;θ) ∝ exp(⟨E_u(u), E_v(v)⟩/τ). Generalizations can involve different functional forms for ρ, for instance:
  - Using unnormalized encoders g_u, g_v in the exponential tilting: ρ(u,v;θ) ∝ exp(⟨g_u(u), g_v(v)⟩/τ).
  - Using the L2 distance between latent vectors in the exponential tilting: ρ(u,v;θ) ∝ exp(−|g_u(u) − g_v(v)|²/(2τ)).
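To make the joint-loss option above concrete, the following is a minimal sketch of a Monte Carlo estimator for the θ-dependent part of KL(μ || ν), usable with either of the two tiltings just listed. It assumes PyTorch and embeddings produced by unspecified encoders; the function names, temperature values, and the use of an in-batch shuffle as the sample from the product of marginals are illustrative assumptions, not the paper's implementation.

```python
import torch

def cosine_log_tilt(zu, zv, tau=0.07):
    # log ρ(u, v) = <E_u(u), E_v(v)> / τ with unit-normalized embeddings
    zu = torch.nn.functional.normalize(zu, dim=-1)
    zv = torch.nn.functional.normalize(zv, dim=-1)
    return (zu * zv).sum(-1) / tau

def l2_log_tilt(zu, zv, tau=1.0):
    # log ρ(u, v) = -|g_u(u) - g_v(v)|^2 / (2τ) with unnormalized embeddings
    return -((zu - zv) ** 2).sum(-1) / (2 * tau)

def joint_kl_loss(zu, zv, log_tilt=cosine_log_tilt):
    """Estimate of the θ-dependent part of KL(μ || ν), namely
    -E_μ[log ρ] + log E_{μ_u⊗μ_v}[ρ]: one paired batch for the first term,
    one shuffled (unpaired) batch for the second."""
    positive = log_tilt(zu, zv).mean()       # batch drawn from the joint μ
    perm = torch.randperm(zu.shape[0])       # break the pairing -> sample from μ_u ⊗ μ_v
    log_mean_rho = torch.logsumexp(log_tilt(zu, zv[perm]), dim=0) \
        - torch.log(torch.tensor(float(zu.shape[0])))
    return -positive + log_mean_rho

# Usage with hypothetical encoders enc_u, enc_v mapping raw modalities to R^{n_e}:
#   loss = joint_kl_loss(enc_u(u_batch), enc_v(v_batch), log_tilt=l2_log_tilt)
#   loss.backward()
```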
The paper analyzes these generalizations in detail for the case where μ is a multivariate Gaussian distribution and the encoders are linear functions (g_u(u) = Gu, g_v(v) = Hv).
- Cosine Distance + Conditional Loss (Standard CLIP): Using the original exponential tilting with linear encoders, the paper shows that minimizing the conditional loss results in learning a matrix A = GᵀH that matches the conditional means of the Gaussian distribution (e.g., E[u|v] for μ) but not the conditional covariances. The learned conditional covariances are fixed to the marginal covariances of μ_u, μ_v, which are generally larger than the true conditional covariances unless u and v are independent. When restricted to low-rank matrices (corresponding to n_e < min(n_u, n_v)), the solution is related to a low-rank approximation of a matrix built from the covariance blocks of μ.
- Positive Quadratic Form + Conditional Loss: Using the L2-distance tilting, ρ(u,v;θ) ∝ exp(−|Gu − Hv|²/2), allows the model to learn matrices A = GᵀH and B = GᵀG (and C = HᵀH). Minimizing a one-sided conditional loss (e.g., matching only μ_{u|v}) in this setting allows the model to exactly match both the conditional mean and covariance of that specific conditional distribution, provided n_e is sufficiently large. The optimization for the rank-constrained case is also derived.
- Cosine Distance + Joint Loss: Using the original exponential tilting but minimizing the joint loss D(μ || ν), the paper shows that the optimal matrix A = GᵀH is obtained by applying a singular-value shrinkage function to the singular values of the matrix optimized by the conditional loss. This formulation yields a learned joint distribution whose marginal distributions are closer to the true marginals of μ than those of the distribution learned by minimizing the conditional loss.
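The mean/covariance claims in the first two cases can be checked with a few lines of linear algebra. The NumPy sketch below does so using the standard formulas for exponentially tilting a Gaussian product of marginals; the dimensions and random covariance are arbitrary, and the closed-form choices of A and B are illustrative solutions consistent with the stated results rather than the paper's derivation.

```python
import numpy as np

rng = np.random.default_rng(0)
nu, nv = 4, 3
# Random positive-definite joint covariance for (u, v) ~ μ, with blocks Σ_uu, Σ_uv, Σ_vv.
M = rng.standard_normal((nu + nv, nu + nv))
Sigma = M @ M.T + (nu + nv) * np.eye(nu + nv)
Suu, Svv, Suv = Sigma[:nu, :nu], Sigma[nu:, nu:], Sigma[:nu, nu:]

# True conditional μ_{u|v}: mean map Σ_uv Σ_vv⁻¹ and covariance Σ_uu − Σ_uv Σ_vv⁻¹ Σ_vu.
mean_map = Suv @ np.linalg.inv(Svv)
cond_cov = Suu - mean_map @ Suv.T

# (1) Bilinear (cosine-style) tilting: ν ∝ exp(uᵀAv) · μ_u ⊗ μ_v. Completing the
# square gives ν_{u|v} = N(Σ_uu A v, Σ_uu): choosing A = Σ_uu⁻¹ Σ_uv Σ_vv⁻¹ matches
# the conditional mean, but the conditional covariance stays at the marginal Σ_uu.
A_cos = np.linalg.inv(Suu) @ mean_map
print("means match:      ", np.allclose(Suu @ A_cos, mean_map))  # True
print("covariances match:", np.allclose(Suu, cond_cov))          # False (unless Σ_uv = 0)

# (2) Quadratic (L2-distance) tilting: ν ∝ exp(−½|Gu − Hv|²) · μ_u ⊗ μ_v, so the
# model also learns B = GᵀG and ν_{u|v} = N((Σ_uu⁻¹ + B)⁻¹ A v, (Σ_uu⁻¹ + B)⁻¹).
B = np.linalg.inv(cond_cov) - np.linalg.inv(Suu)   # PSD choice matching the covariance
A_l2 = (np.linalg.inv(Suu) + B) @ mean_map         # then the mean is matched as well
learned_cov = np.linalg.inv(np.linalg.inv(Suu) + B)
print("covariances match:", np.allclose(learned_cov, cond_cov))        # True
print("means match:      ", np.allclose(learned_cov @ A_l2, mean_map)) # True
```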
The practical applications discussed include:
- Crossmodal Retrieval: Given an instance of one modality (e.g., a text prompt v), find the most similar instances of the other modality (e.g., images u) in a dataset of size N. This is framed as finding the mode of the empirical conditional distribution ν^N_{u|v}, which corresponds to maximizing the cosine similarity ⟨E_u(u_i), E_v(v)⟩ over the dataset images u_i (see the sketch after this list).
- Crossmodal Classification: Given an instance of one modality (e.g., an image u), assign it a label from a predefined set of K options (e.g., text labels v_i). This is framed as finding the mode of the empirical conditional distribution ν^K_{v|u} over the label set, maximizing ⟨E_u(u), E_v(v_i)⟩. The paper shows how standard image classification networks (like LeNet on MNIST) can be interpreted within this framework using unnormalized image encoders and one-hot encoded labels with a one-sided conditional loss. The framework also supports fine-tuning to adapt pretrained models to specific classification tasks.
- Lagrangian Data Assimilation: This is presented as a novel application in science and engineering. The task is to recover an Eulerian velocity field (represented by coefficients of a potential, u) from Lagrangian trajectories of particles in the flow (v). The authors train a contrastive model with a transformer-based encoder for trajectories and a fixed encoder for potential coefficients. Experiments show that this purely data-driven approach successfully learns embeddings that enable accurate retrieval of potentials from trajectories and vice-versa.
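Here is a minimal sketch of the retrieval and classification recipes above, assuming NumPy and placeholder encoders enc_u, enc_v that map raw inputs to embedding vectors; all names and dimensions are illustrative, not components of the paper's experiments.

```python
import numpy as np

def embed(x, enc):
    """Apply an encoder and normalize to the unit sphere (the E_u, E_v maps above)."""
    z = enc(x)
    return z / np.linalg.norm(z, axis=-1, keepdims=True)

def retrieve(query_v, dataset_u, enc_u, enc_v, top_k=5):
    """Retrieval: given one v (e.g., a text prompt), rank the N dataset items u_i
    by cosine similarity <E_u(u_i), E_v(v)>, i.e. find the modes of ν^N_{u|v}."""
    zu = embed(dataset_u, enc_u)              # shape (N, n_e)
    zv = embed(query_v[None, :], enc_v)[0]    # shape (n_e,)
    scores = zu @ zv
    return np.argsort(-scores)[:top_k], scores

def classify(query_u, label_set_v, enc_u, enc_v):
    """Classification is the same argmax in the other direction: the mode of
    ν^K_{v|u} over the K candidate labels v_i."""
    zv = embed(label_set_v, enc_v)            # shape (K, n_e)
    zu = embed(query_u[None, :], enc_u)[0]    # shape (n_e,)
    return int(np.argmax(zv @ zu))

# Toy usage with identity "encoders" acting on pre-embedded vectors:
rng = np.random.default_rng(0)
identity = lambda x: x
images, prompt = rng.standard_normal((100, 8)), rng.standard_normal(8)
top_indices, _ = retrieve(prompt, images, identity, identity)
print(top_indices)
```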
Numerical experiments on Gaussian data validate the theoretical findings on mean/covariance matching and the properties of the different loss functions. Experiments on MNIST demonstrate how one-sided versus two-sided loss functions trade off classification accuracy against the diversity of images sampled from the learned conditional distribution. The Lagrangian data assimilation experiment highlights the potential of applying contrastive learning methods to scientific problems involving disparate data modalities.
In summary, the paper provides a principled probabilistic foundation for contrastive learning, generalizes existing methods via novel loss functions and tiltings, offers analytical insights through Gaussian models, and demonstrates practical applicability to traditional AI tasks and novel scientific domains. The focus on the learned joint and conditional distributions provides a valuable perspective for understanding the capabilities and limitations of different contrastive learning formulations.