SparseJEPA: Enhancing JEPA with Group Sparsity

Updated 19 January 2026
  • SparseJEPA is a self-supervised framework that adds a group-sparsity penalty to enforce semantically coherent, group-structured latent embeddings.
  • It combines reconstruction error, KL divergence, and a group lasso term to yield improved predictive performance over standard JEPA.
  • Empirical evaluations on benchmarks like CIFAR-100 demonstrate significant gains in accuracy and interpretability compared to dense methods.

SparseJEPA is an extension of Joint Embedding Predictive Architectures (JEPA) that integrates sparse representation learning to enforce semantically coherent, group-structured embeddings. By introducing a group-sparsity penalty in the latent space, SparseJEPA aims to enhance interpretability and efficiency of learned representations while maintaining or exceeding the predictive performance of standard JEPA. The architecture is theoretically justified through its impact on multiinformation in the latent variables and demonstrates empirical superiority on a range of linear-probe transfer tasks (Hartman et al., 22 Apr 2025).

1. Joint Embedding Predictive Architectures Overview

JEPA is a self-supervised learning framework where the objective is to predict a masked "target" embedding from a "context" embedding within a shared latent space. Unlike traditional generative methods that reconstruct raw pixels, JEPA formulates the task in representation space. The procedure involves dividing an input image x into non-overlapping patches. A Vision Transformer (ViT)-based context encoder E_c processes visible patches c to yield embeddings s_c, while a separate target encoder E_t processes masked patches t to produce s_t. A predictor P uses s_c, positional encodings, and mask tokens to generate predicted target embeddings ŝ_t. The loss function employed is mean squared error (MSE) between predicted and ground-truth embeddings:

L_\mathrm{JEPA} = \frac{1}{M} \sum_{i=1}^{M} \sum_{j\in B_i} \|\hat{S}_{y_j} - S_{y_j}\|^2_2

where M is the number of target blocks, B_i the set of patch indices in block i, S_{y_j} the true, and Ŝ_{y_j} the predicted, embedding for patch j (Hartman et al., 22 Apr 2025).
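As a minimal numerical sketch of this loss (illustrative NumPy code written here, not the paper's implementation; the block structure and array shapes are assumptions):

```python
import numpy as np

def jepa_loss(pred, target, blocks):
    """L_JEPA: average over target blocks of the summed squared error
    between predicted and ground-truth patch embeddings."""
    total = 0.0
    for idx in blocks:                      # idx = patch indices of block B_i
        diff = pred[idx] - target[idx]      # shape (|B_i|, dim)
        total += np.sum(diff ** 2)          # sum_{j in B_i} ||Ŝ_{y_j} - S_{y_j}||²
    return total / len(blocks)              # divide by M, the number of blocks

# Toy check: 8 patches, 4-dim embeddings, 2 target blocks of 2 patches each.
rng = np.random.default_rng(0)
S = rng.normal(size=(8, 4))                 # "true" target embeddings
S_hat = S + 0.1                             # predictions off by 0.1 in every dim
blocks = [np.array([0, 1]), np.array([4, 5])]
print(round(jepa_loss(S_hat, S, blocks), 4))  # → 0.08 (2 patches × 4 dims × 0.01, per block)
```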

2. SparseJEPA: Model Extensions and Sparsity Penalty

SparseJEPA introduces sparsity and latent-grouping in the JEPA framework via a new composite loss:

L_\mathrm{SparseJEPA} = \frac{1}{M}\sum_{i=1}^{M}\sum_{j\in B_i} \|\hat{S}_{y_j} - S_{y_j}\|^2_2 + \beta L_\mathrm{KL} + \lambda \sum_{g=1}^G\sum_{j=1}^K \|W^{(g)}_{:,j}\|_2

  • L_KL is the KL divergence between the approximate posterior over latent block-embeddings and a prior (as in oi-VAE).
  • β weights the KL term.
  • λ controls the group-sparsity penalty (group lasso).
  • G is the number of groups (semantically defined subsets); K is the number of latent dimensions.
  • W^{(g)} ∈ ℝ^{d×K} is the weight matrix mapping latent dimensions to group g, with ‖W^{(g)}_{:,j}‖_2 denoting the ℓ2-norm of column j (group lasso encourages most columns to be zero for most groups).

The grouping mechanism ensures that patches or features with high mutual information are mapped to overlapping latent dimensions. Each latent dimension is driven, via the group-sparsity penalty, to "select" only a few groups, yielding interpretable and efficient representations (Hartman et al., 22 Apr 2025).
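The group-lasso term can be sketched as follows (a NumPy illustration under the shapes assumed above, not the authors' code):

```python
import numpy as np

def group_lasso_penalty(W_list, lam):
    """λ Σ_g Σ_j ||W^{(g)}_{:,j}||_2: sum of column-wise l2 norms over
    each group's d×K weight matrix. Driving a column to zero removes
    that latent dimension from the group entirely."""
    return lam * sum(np.linalg.norm(W, axis=0).sum() for W in W_list)

# Two groups, d = 3, K = 2.
W1 = np.array([[3.0, 0.0],
               [4.0, 0.0],
               [0.0, 0.0]])    # column norms: 5 and 0 (dim 2 pruned from group 1)
W2 = np.zeros((3, 2))          # group 2 uses no latent dimensions at all
print(round(group_lasso_penalty([W1, W2], lam=0.01), 4))  # → 0.05
```

Unlike an elementwise l1 penalty, the column-wise l2 norm zeroes whole columns at once, which is what yields the group-structured sparsity pattern.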

3. Latent-Space Grouping and Training Workflow

The method assumes semantically meaningful groupings among data features (e.g., textures, contours, objects). The main steps are:

  1. Define G groups aligned with semantic subsets.
  2. Parameterize group membership via W^{(g)} matrices.
  3. For each batch:
    • Encode context and target patches using a Tiny ViT to obtain z ∈ ℝ^K.
    • Compute group activations a^{(g)} = (W^{(g)})^⊤ z for g = 1, …, G.
    • Predict target embeddings using only nonzero group activations.
    • Apply the group-sparsity penalty, ensuring each latent dimension is active in only a small subset of groups.

The pseudocode governing each iteration is as follows:

for each batch x:
  z_c = ViT_context(x_context)
  z_t = ViT_target(x_target)
  for g in 1…G:
    a_g = W[g].T @ z_c
  Ŝ_t = Predictor(concat(a_1, …, a_G))
  L_rec = MSE(Ŝ_t, z_t)
  L_kl  = KL(q(z_c) || p(z))
  L_pen = λ * sum_{g,j} ||W[g][:,j]||_2
  Loss  = L_rec + β * L_kl + L_pen
  backprop(Loss)

This framework enforces that latent variables are shared among features with high semantic correlation, improving both interpretability and downstream task performance (Hartman et al., 22 Apr 2025).

4. Theoretical Foundations: Multiinformation and Grouping

Multiinformation, defined for n random variables X_1, …, X_n as

I(X_1; \ldots; X_n) = D_{\mathrm{KL}}\!\left( p(x_1,\ldots,x_n) \,\middle\Vert\, \prod_{i=1}^n p(x_i) \right),

measures the total statistical dependence in a set of variables.
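For small discrete distributions this quantity can be evaluated directly from the definition. The following NumPy sketch (an illustration written here, in nats; not from the paper) computes it for a joint pmf stored as an n-dimensional array:

```python
import numpy as np

def multiinformation(p):
    """KL divergence between a joint pmf p (an n-dim array) and the
    product of its one-dimensional marginals, in nats."""
    n = p.ndim
    q = np.ones_like(p)
    for i in range(n):
        # marginal p(x_i): sum out every other axis, then broadcast back
        marg = p.sum(axis=tuple(a for a in range(n) if a != i))
        q = q * marg.reshape([-1 if a == i else 1 for a in range(n)])
    mask = p > 0                 # 0 * log(0) terms contribute nothing
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

# Two perfectly correlated fair bits: multiinformation reduces to
# mutual information, log 2 ≈ 0.693 nats; independent bits give 0.
p = np.array([[0.5, 0.0],
              [0.0, 0.5]])
print(round(multiinformation(p), 3))  # → 0.693
```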

A core result is that grouping via deterministic functions G_j = f_j(X_i : i ∈ S_j) causes

I(G_1; \ldots; G_m) \leq I(X_1; \ldots; X_n)

by the Data Processing Inequality for KL divergence. The reduction is strict whenever the grouping discards nontrivial dependence; in particular, any dependence among variables within the same group is removed entirely by collapsing them into a single variable.

Furthermore, if latent variables Z = (Z_1, …, Z_k) generate observations X = (X_1, …, X_n), and Z is partitioned into groups S_1, …, S_m, then if the partition reflects the underlying data structure,

I(G_1; \ldots; G_m) < I(X_1; \ldots; X_n)

and

I(Z; G_1, \ldots, G_m) \geq I(Z; X_1, \ldots, X_n)

This indicates that the grouped representation retains relevant latent information more compactly than the original data, supporting the use of sparsity-structured grouping penalties to induce compact and semantically aligned latent spaces in JEPA variants (Hartman et al., 22 Apr 2025).
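As a concrete illustration of the strict reduction (an example constructed here, not taken from the paper): let X_1, X_2, X_3, X_4 be fair bits with X_2 = X_1 and X_4 = X_3, the two pairs mutually independent. All dependence is confined within the pairs, so the multiinformation decomposes across the independent blocks:

I(X_1; X_2; X_3; X_4) = I(X_1; X_2) + I(X_3; X_4) = \log 2 + \log 2 = 2\log 2

Grouping each pair via G_1 = f_1(X_1, X_2) = X_1 and G_2 = f_2(X_3, X_4) = X_3 leaves two independent bits, so

I(G_1; G_2) = 0 < 2\log 2 = I(X_1; X_2; X_3; X_4)

showing that grouping aligned with the dependence structure removes redundant within-group dependence entirely.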

5. Empirical Evaluation and Results

SparseJEPA was evaluated on CIFAR-100 (60,000 images, 100 classes) using a Tiny ViT backbone (patch size 4 × 4, depth 6, hidden dim 384, 6 heads). The optimizer was AdamW with learning rate 3 × 10⁻⁴, weight decay 0.05, and batch size 256, over 200 epochs. Sparsity hyperparameters were set as λ = 0.01, β = 1.0, number of groups G = 8, and K = 128 latent dimensions. Masking targeted 50% of the image patches per sample.

Linear-probe transfer learning performance (Top-1 accuracy):

| Model        | CIFAR-100 | Place205 | CLEVR/Count | iNat2018 |
|--------------|-----------|----------|-------------|----------|
| JEPA (dense) | 40.01     | 21.24    | 59.13       | 16.52    |
| SparseJEPA   | 45.40     | 23.36    | 62.33       | 19.63    |

Compared to standard JEPA, SparseJEPA provides improvements of +5.4 percentage points (pp) on CIFAR-100, +2.1pp on Place205, +3.2pp on CLEVR object counting, and +3.1pp on iNaturalist fine-grained species recognition. SparseJEPA outperforms MAE, data2vec, and CAE linear-probe numbers on CIFAR-100 (all 38–42%) (Hartman et al., 22 Apr 2025).
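The linear-probe protocol itself can be sketched as follows (a simplified illustration on synthetic data: a closed-form ridge classifier on frozen embeddings stands in for the logistic probes typically trained; none of this is the paper's evaluation code):

```python
import numpy as np

def linear_probe_acc(Z_tr, y_tr, Z_te, y_te, n_classes, reg=1e-3):
    """Fit a linear classifier on frozen embeddings Z (the encoder is
    never updated) and report top-1 accuracy. Ridge regression on
    one-hot labels gives a simple closed-form probe."""
    Y = np.eye(n_classes)[y_tr]                        # one-hot targets
    A = Z_tr.T @ Z_tr + reg * np.eye(Z_tr.shape[1])
    W = np.linalg.solve(A, Z_tr.T @ Y)                 # closed-form fit
    return float(np.mean(np.argmax(Z_te @ W, axis=1) == y_te))

# Toy check: two well-separated clusters standing in for class embeddings.
rng = np.random.default_rng(0)
Z = np.vstack([rng.normal(-2.0, 1.0, (50, 8)),
               rng.normal(+2.0, 1.0, (50, 8))])
y = np.array([0] * 50 + [1] * 50)
print(linear_probe_acc(Z, y, Z, y, n_classes=2))  # well-separated clusters → accuracy near 1.0
```

Because only the linear head is trained, probe accuracy directly measures how linearly separable the frozen representations are, which is why it is the standard comparison in the table above.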

6. Interpretability, Representational Efficiency, and Future Directions

Qualitative inspection (e.g., heatmaps) reveals that sparsity leads to specialization among latent groups, such as individual groups activating on object contours or textures. The grouping penalty enforces each latent dimension's association with a limited number of semantic groups, supporting traceability of dimensions back to interpretable features.

Empirically, the sparsity constraint reduces redundancy and supports the robustness and generalizability of learned features, aligning with theoretical predictions based on multiinformation reductions.

Future research aims to further leverage the grouping mechanism in the context of object-centric representation learning, such as dynamically inferring object slots via slot-attention frameworks. This extension would facilitate unsupervised discovery of object-level attributes and compositionality within learned representations (Hartman et al., 22 Apr 2025).
