Static-Teacher Asymmetric Latent Training (SALT)
- The paper introduces SALT, a two-stage framework where a fixed, pre-trained teacher provides latent targets that the student model regresses to.
- It decouples teacher and student optimization using an asymmetric loss and reports strong results in video SSL, MLIPs, and metric learning.
- The approach simplifies training by freezing the teacher, reducing compute costs while enhancing stability and accuracy in knowledge transfer.
Static-Teacher Asymmetric Latent Training (SALT) is a two-stage, teacher-student framework for knowledge transfer that uses a frozen, static teacher to supervise a student model via latent targets. The SALT methodology decouples teacher and student optimization, avoids momentum or bi-level mechanics, and relies on asymmetric loss objectives where only the student is updated, and the teacher’s latent outputs serve as fixed regression targets. SALT has demonstrated effectiveness in large-scale video self-supervised learning (SSL) (Li et al., 29 Sep 2025), molecular dynamics via machine learning interatomic potentials (MLIPs) (Matin et al., 7 Feb 2025), and asymmetric metric learning for retrieval (Budnik et al., 2020).
1. Conceptual Overview
SALT’s core idea is to replace traditional knowledge distillation or self-distillation regimes (e.g., exponential moving average teachers) with a static, pre-trained teacher, then train a student to regress onto or match the teacher’s latent representations on masked or partial data. The approach is inherently asymmetric: the teacher is never updated with respect to the student. This removes the need for the regularization and collapse-prevention mechanisms (such as EMA target updates and stop-gradients) that online teacher-student settings typically rely on.
In SALT, the teacher encodes input (e.g., pixels, atomic environments, or full images) and generates supervisory signals—typically latent vectors or per-node decompositions. The student is then optimized to approximate these latent targets, optionally under challenging or partially observed conditions (e.g., masked inputs, hard mining, or limited capacity).
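To make the two-stage asymmetry concrete, here is a deliberately tiny NumPy sketch. The linear "teacher" and "student", the random feature masking, and all sizes are illustrative assumptions, not the setup of any of the cited papers. The teacher is fixed (its weights are marked read-only), its latents on the full input are computed once, and only the student is fitted from masked inputs.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy sizes: 16-dim inputs, 8-dim latents, 512 samples.
D_IN, D_LAT, N = 16, 8, 512
X = rng.normal(size=(N, D_IN))

# Stage 1 stand-in: a "pre-trained" teacher, here simply a fixed linear encoder.
W_teacher = rng.normal(size=(D_IN, D_LAT)) / np.sqrt(D_IN)
W_teacher.setflags(write=False)  # static teacher: any in-place update would raise


def mask_inputs(x, ratio, rng):
    """Hide a random fraction `ratio` of input features (partial observation)."""
    keep = rng.random(x.shape) >= ratio
    return x * keep


# Teacher latents on the *full* input are computed once and reused as fixed targets.
targets = X @ W_teacher

# Stage 2: only the student is optimized, regressing from masked inputs onto the targets.
W_student = np.zeros((D_IN, D_LAT))
lr, losses = 0.5, []
for step in range(2000):
    xm = mask_inputs(X, ratio=0.5, rng=rng)
    resid = xm @ W_student - targets
    losses.append(np.mean(resid ** 2))
    grad = 2.0 * xm.T @ resid / resid.size  # gradient of the mean squared error
    W_student -= lr * grad

rel_err = np.linalg.norm(W_student - W_teacher) / np.linalg.norm(W_teacher)
print(f"loss {losses[0]:.3f} -> {losses[-1]:.3f}, relative weight error {rel_err:.3f}")
```

With independent, unit-variance features and a 50% keep rate, the expected masked-regression optimum in this toy is the teacher's own weights, so the student should land close to `W_teacher`. The final loss stays above zero because hidden features cannot be recovered.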
2. Mathematical Formulation and Objectives
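SALT instantiates its main training objective as a sum of task-specific loss terms and a latent transfer (or distillation) term, with only the student parameters $\theta$ optimized:

$$\mathcal{L}(\theta) = \mathcal{L}_{\text{task}}(\theta) + \lambda\,\mathcal{L}_{\text{transfer}}(\theta)$$

In video SSL the task term is absent and the student is trained on the latent term alone. The domain-specific instances are: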
- In video SSL (Li et al., 29 Sep 2025), SALT uses two sequential stages:
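  - Teacher (stage 1) is trained on masked pixel reconstruction: the encoder $E_\phi$ sees only the visible tokens $x_{\bar M}$, and the decoder $D_\psi$ reconstructs the masked patches $i \in M$:

$$\mathcal{L}_{\text{pix}}(\phi,\psi) = \frac{1}{|M|}\sum_{i\in M}\big\lVert D_\psi\big(E_\phi(x_{\bar M})\big)_i - x_i\big\rVert_2^2$$

  - Student (stage 2) is trained by minimizing a masked latent prediction loss. Its encoder $E_\theta$ and predictor $P_\theta$ regress onto the frozen teacher’s latents $E_{\phi^\star}(x)$, where $\ell$ is a per-token regression distance:

$$\mathcal{L}_{\text{lat}}(\theta) = \frac{1}{|M|}\sum_{i\in M}\ell\Big(P_\theta\big(E_\theta(x_{\bar M})\big)_i,\; E_{\phi^\star}(x)_i\Big)$$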
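- In MLIPs (Matin et al., 7 Feb 2025), the main transfer is at the atomic energy level. The total energy is decomposed into per-atom contributions. Alongside the usual total-energy loss $\mathcal{L}_E$ and force loss $\mathcal{L}_F$, the student regresses its atomic energies $\hat\epsilon_i^{\,s}$ onto the frozen teacher’s atomic energies $\epsilon_i^{\,t}$:

$$\mathcal{L}_{\text{student}}(\theta) = \lambda_E\,\mathcal{L}_E(\theta) + \lambda_F\,\mathcal{L}_F(\theta) + \lambda_{\text{AE}}\,\frac{1}{N}\sum_{i=1}^{N}\big(\hat\epsilon_i^{\,s} - \epsilon_i^{\,t}\big)^2$$

Here, $\lambda_{\text{AE}}$ (the latent-target regression weight) is set much higher than $\lambda_E$ or $\lambda_F$ because atom-wise latent targets are abundant: every configuration contributes one target per atom.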
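- In asymmetric metric learning (Budnik et al., 2020), SALT combines a triplet (or other metric) loss with a regression term onto the teacher embeddings $z_i^{t}$:

$$\mathcal{L}(\theta) = \mathcal{L}_{\text{metric}}(\theta) + \lambda\,\mathcal{L}_{\text{transfer}}(\theta)$$

The metric term is asymmetric. The anchor $x_a$ is encoded by the student, while the positive $p$ and negative $n$ are teacher codes, with margin $\alpha$:

$$\mathcal{L}_{\text{metric}}(\theta) = \sum_{(a,p,n)}\Big[\alpha + \big\lVert f_\theta(x_a) - z_p^{t}\big\rVert_2^2 - \big\lVert f_\theta(x_a) - z_n^{t}\big\rVert_2^2\Big]_+$$

The transfer term regresses each student embedding onto its own teacher code:

$$\mathcal{L}_{\text{transfer}}(\theta) = \sum_{i}\big\lVert f_\theta(x_i) - z_i^{t}\big\rVert_2^2$$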
The teacher’s parameters are always frozen during student optimization. Only the student’s parameters are updated.
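The following is a NumPy sketch of two of the objectives above: the video-SSL masked latent loss and the MLIP student loss. The array layouts, the choice of distance for $\ell$ (exposed as a parameter `p`), and the default weights are assumptions for illustration. In both functions the teacher quantities enter only as fixed arrays.

```python
import numpy as np


def masked_latent_loss(student_pred, teacher_latents, mask, p=1):
    """Video-SSL student loss: mean per-token distance over masked tokens only.

    student_pred, teacher_latents: (num_tokens, dim) arrays; teacher_latents come
    from the frozen teacher on the full clip. mask: bool (num_tokens,), True where
    the token was hidden from the student. p=1 -> L1 distance, p=2 -> squared L2
    (which distance to use is an assumption of this sketch).
    """
    diff = student_pred[mask] - teacher_latents[mask]
    return np.mean(np.sum(np.abs(diff) ** p, axis=-1))


def mlip_student_loss(E_pred, E_ref, F_pred, F_ref, eps_student, eps_teacher,
                      lam_E=1.0, lam_F=1.0, lam_AE=100.0):
    """MLIP student loss: energy and force terms plus per-atom latent regression.

    eps_student / eps_teacher: per-atom energies of the student and of the frozen
    teacher. The default weights are illustrative placeholders that only reflect
    lam_AE >> lam_E, lam_F.
    """
    loss_E = np.mean((E_pred - E_ref) ** 2)
    loss_F = np.mean((F_pred - F_ref) ** 2)
    loss_AE = np.mean((eps_student - eps_teacher) ** 2)
    return lam_E * loss_E + lam_F * loss_F + lam_AE * loss_AE
```

Because the latent loss indexes with the mask, visible tokens contribute nothing, which matches the sum over $i \in M$ in $\mathcal{L}_{\text{lat}}$.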
3. Architectures and Implementation Details
SALT is architecture-agnostic but has been implemented with:
- Vision SSL: Vision Transformers (ViT-L/H/g/G) with rotary position encoding, spatial patch embedding, and tubelet length 2. The teacher’s decoder is a 4-layer transformer (VideoMAEv2-style); the student’s predictor is a 12-block transformer.
- MLIPs: Hierarchically Interacting Particle Neural Network (HIPNN) with variable width, depth, and tensor basis. Input embeddings include atomic number and radial basis, interaction layers use cutoff filtering, and atom layers are fully connected MLPs.
- Metric Learning: Teacher often uses ResNet-101 or VGG-16 with GeM pooling; students use efficient networks (MobileNetV2, EfficientNet) with matched projections, GeM pooling, and final normalization.
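The GeM pooling used in these retrieval backbones computes a per-channel power mean $\big(\frac{1}{|\Omega|}\sum_{u\in\Omega} x_{u,c}^{\,p}\big)^{1/p}$ over spatial positions $\Omega$. Below is a NumPy sketch of a student descriptor head (GeM, then projection to the teacher’s dimension, then L2 normalization). The default $p=3$ and the plain linear projection are assumptions of this sketch.

```python
import numpy as np


def gem_pool(feature_map, p=3.0, eps=1e-6):
    """Generalized-mean (GeM) pooling over the spatial axes of an (H, W, C) map.

    p=1 reduces to average pooling; large p approaches max pooling.
    Activations are clipped to eps so the fractional power is well defined.
    """
    x = np.clip(feature_map, eps, None)
    return np.mean(x ** p, axis=(0, 1)) ** (1.0 / p)


def student_descriptor(feature_map, projection, p=3.0):
    """GeM pooling -> linear projection to the teacher's dimension -> L2 normalization."""
    v = gem_pool(feature_map, p) @ projection
    return v / np.linalg.norm(v)
```

Normalizing the projected descriptor puts student outputs on the same unit sphere as L2-normalized teacher codes, so dot products and squared distances against those codes are comparable.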
Batch sizes and hyperparameters vary by domain but exhibit commonalities: large batch processing, Adam(W) or SGD optimizers, cosine or stepwise learning rate schedules, and regularization via weight decay.
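For reference, here is a common warmup-plus-cosine schedule of the kind mentioned above. The warmup handling and parameter names are generic assumptions, not settings taken from the three papers.

```python
import math


def lr_at(step, total_steps, base_lr, warmup_steps=0, final_lr=0.0):
    """Learning rate at `step`: linear warmup, then cosine decay to final_lr."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)  # hold at final_lr after total_steps
    return final_lr + 0.5 * (base_lr - final_lr) * (1.0 + math.cos(math.pi * progress))
```

A stepwise alternative would multiply `base_lr` by a fixed factor at chosen milestones instead.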
A summary of example architectures and parameters for video SSL is provided below:
| Model | Encoder Params | Steps | Total FLOPs (×10²¹) |
|---|---|---|---|
| SALT ViT-L | 0.303 B+ | 240k | 1.2 |
| SALT ViT-H | ... | 240k | 1.5 |
| SALT ViT-g | ... | 240k | 1.9 |
| SALT ViT-G | ... | 240k | 2.6 |
4. Training Procedure: Staging and Loss Design
SALT is invariably a two-phase process:
- Teacher Training: Optimize the teacher on a reconstruction or base task with direct targets (pixels, quantum-mechanical energies, image class labels, etc.), typically using masking or heavy data augmentation.
- Student Training: Teacher is frozen; student is trained to regress onto the teacher’s latent representations or per-entity decompositions (patch, atom, or image embeddings), possibly combined with other supervision such as metric losses, force matching, or global property prediction.
Crucially, only the student is updated in stage two. There is no feedback or co-evolution of the teacher. Loss weights for latent transfer are typically set large to amplify the effect of the teacher’s information-rich targets, especially when these are more numerous (e.g. atom-wise energies, patch-wise video features).
The general pseudocode for SALT in metric learning is as follows (Budnik et al., 2020):
```text
Algorithm SALT (asymmetric metric learning)
Input: teacher codes {z_i^t} (precomputed, frozen), student f_θ, triplet labels,
       hyperparameters {α, λ, η, E}
for epoch = 1 … E:
    1. For each anchor x_a, mine hard negatives by ranking all teacher codes z_j^t
       of other classes by sim(f_θ(x_a), z_j^t).
    2. Sample B anchors and select their positives/negatives.
    3. Compute the student embeddings f_θ(x_a).
    4. Compute L_metric and L_transfer; L = L_metric + λ L_transfer.
    5. Update θ ← θ − η ∇_θ L.
    6. Decay η according to the schedule.
return f_θ
```
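The SALT metric-learning pseudocode can be turned into a runnable NumPy sketch. Several simplifications are assumptions made here for brevity: the student is a linear map without output normalization, similarity is a plain dot product, positives are sampled uniformly within the anchor's class, several mini-batches run per epoch, and the data are synthetic. The teacher codes `Z_t` are only ever read.

```python
import numpy as np


def train_salt_metric(X, Z_t, labels, alpha=0.1, lam=1.0, lr=0.005,
                      epochs=30, steps_per_epoch=10, batch=32, decay=0.97, seed=0):
    """Toy SALT loop: linear student f(x) = x @ W, frozen teacher codes Z_t.

    X: (N, D) student inputs; Z_t: (N, d) precomputed teacher codes (never updated);
    labels: (N,) class ids, with at least two items per class.
    """
    rng = np.random.default_rng(seed)
    N, D = X.shape
    W = rng.normal(scale=0.01, size=(D, Z_t.shape[1]))
    same = labels[:, None] == labels[None, :]
    history = []
    for _ in range(epochs):
        # 1. Hard negatives: per item, the most similar teacher code of another
        #    class under sim(f(x_a), z_j^t) = f(x_a) . z_j^t.
        S = (X @ W) @ Z_t.T
        S[same] = -np.inf
        hard_neg = S.argmax(axis=1)
        for _ in range(steps_per_epoch):
            # 2. Sample B anchors; the positive is another item with the same label.
            a = rng.choice(N, size=batch, replace=False)
            p = np.array([rng.choice(np.flatnonzero(same[i] & (np.arange(N) != i)))
                          for i in a])
            n = hard_neg[a]
            # 3. Student embeddings of the anchors.
            F = X[a] @ W
            # 4. Asymmetric triplet term + regression onto each anchor's own teacher code
            #    (batch means rather than sums).
            margin = alpha + np.sum((F - Z_t[p]) ** 2, axis=1) - np.sum((F - Z_t[n]) ** 2, axis=1)
            L_metric = np.mean(np.maximum(margin, 0.0))
            L_transfer = np.mean(np.sum((F - Z_t[a]) ** 2, axis=1))
            history.append(L_metric + lam * L_transfer)
            # d(batch-mean loss)/dF, then chain rule through F = X[a] @ W.
            G = (margin > 0)[:, None] * 2.0 * (Z_t[n] - Z_t[p]) + 2.0 * lam * (F - Z_t[a])
            W -= lr * X[a].T @ (G / batch)   # 5. only the student is updated
        lr *= decay                           # 6. decay the learning rate
    return W, history


# Hypothetical synthetic data: 4 classes x 25 items, 8-dim teacher codes, 32-dim inputs.
rng = np.random.default_rng(1)
labels = np.repeat(np.arange(4), 25)
Z = rng.normal(size=(4, 8))[labels] + 0.3 * rng.normal(size=(100, 8))
Z /= np.linalg.norm(Z, axis=1, keepdims=True)
X = Z @ rng.normal(size=(8, 32)) + 0.1 * rng.normal(size=(100, 32))
W, hist = train_salt_metric(X, Z, labels)
print(f"loss: first {hist[0]:.3f}, last {hist[-1]:.3f}")
```

Setting `lam=0` leaves only the asymmetric triplet term, and dropping the triplet term leaves pure regression. That makes the sketch a convenient place to compare the two regimes the section describes.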
5. Masking, Hard Negative Mining, and Data Augmentation
- Masking for SSL: Video SALT applies V-JEPA-style "multiblock" masking. It combines short-range spatial blocks (15% scale) with long-range spatial blocks (70% scale) and extends each mask across the full temporal length, so that ≈90% of tokens are masked overall (Li et al., 29 Sep 2025). In the paper's comparisons, this scheme outperforms random-tube and causal masking.
- Atomic Decomposition: In MLIPs, a configuration of N atoms yields N per-atom energy targets from the teacher but only one total-energy label. Per-atom energies therefore supply far more latent labels, acting as implicit data augmentation that exploits the latent structure learned by the teacher (Matin et al., 7 Feb 2025).
- Hard Negative Mining: In metric learning, hard negatives are found by comparing student-encoded queries against the full corpus of precomputed teacher codes, maximizing discriminative signal under the fixed teacher distribution (Budnik et al., 2020).
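Below is a simplified NumPy sketch of block masking in the spirit of the multiblock scheme above: rectangles are sampled on the spatial patch grid, unioned, and repeated across every temporal position. The aspect-ratio range, grid sizes, and the (count, scale) pairs in the example are placeholders, not the V-JEPA or SALT configuration. The resulting masking ratio depends on those values and on how much the blocks overlap.

```python
import numpy as np


def sample_block(grid_h, grid_w, scale, rng, aspect=(0.75, 1.5)):
    """One random rectangle covering roughly `scale` of a (grid_h, grid_w) token grid."""
    area = scale * grid_h * grid_w
    ratio = rng.uniform(*aspect)  # height / width
    h = min(max(int(round(np.sqrt(area * ratio))), 1), grid_h)
    w = min(max(int(round(np.sqrt(area / ratio))), 1), grid_w)
    top = rng.integers(0, grid_h - h + 1)
    left = rng.integers(0, grid_w - w + 1)
    block = np.zeros((grid_h, grid_w), dtype=bool)
    block[top:top + h, left:left + w] = True
    return block


def multiblock_mask(t, grid_h, grid_w, block_specs, rng):
    """Union of spatial blocks, repeated over all t temporal positions.

    block_specs: iterable of (num_blocks, scale) pairs. Returns a bool array of
    shape (t, grid_h, grid_w); True marks tokens hidden from the student.
    """
    spatial = np.zeros((grid_h, grid_w), dtype=bool)
    for num_blocks, scale in block_specs:
        for _ in range(num_blocks):
            spatial |= sample_block(grid_h, grid_w, scale, rng)
    # Same spatial mask at every time step, so masked regions span the whole clip.
    return np.broadcast_to(spatial, (t, grid_h, grid_w)).copy()


# Hypothetical example: 16 frames with tubelet 2 -> 8 temporal positions, 14x14 patches.
rng = np.random.default_rng(0)
mask = multiblock_mask(8, 14, 14, block_specs=[(1, 0.15), (1, 0.70)], rng=rng)
print(mask.shape, round(mask.mean(), 3))
```

Repeating one spatial mask over time prevents the student from copying a hidden patch from an adjacent frame, which is the usual motivation for tube-shaped masks in video.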
6. Empirical Performance and Ablations
SALT consistently yields superior or Pareto-optimal compute–accuracy tradeoffs:
- In video SSL, at matched compute, SALT ViT-L achieves 74.9% SSv2 top-1 (vs. 68.2% for V-JEPA 2) with 1.2×10²¹ total FLOPs, and scales favorably to larger models (Li et al., 29 Sep 2025).
- For MLIPs, SALT-trained students achieve a strictly better force-RMSE vs. inference-cost tradeoff than controls, with up to 10% improvement at equal model capacity and a reduced memory footprint (Matin et al., 7 Feb 2025).
- Metric learning experiments demonstrate that regression alone (SALT) dominates more complex knowledge-transfer methods under asymmetric evaluation; in some instances, SALT students surpass teacher retrieval accuracy by up to 1–2% mAP (Budnik et al., 2020).
Ablations show that student performance is largely insensitive to teacher size and data mix. The optimal compute allocation overwhelmingly favors student training; for video, allocating 40k steps to the teacher and 200k to the student yields maximum accuracy. Training loss (latent regression objective) is strongly predictive of downstream performance, simplifying checkpoint selection.
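The following is a quick arithmetic check of that split. It assumes the 240k steps in the video table are the combined teacher-plus-student budget and treats step counts as a rough proxy for compute, although teacher and student steps need not cost the same:

$$\underbrace{40\text{k}}_{\text{teacher}} + \underbrace{200\text{k}}_{\text{student}} = 240\text{k steps}, \qquad \frac{40\text{k}}{240\text{k}} \approx 16.7\%, \qquad \frac{200\text{k}}{240\text{k}} \approx 83.3\%.$$

Under the same assumption, the ViT-L row's $1.2\times10^{21}$ total FLOPs average out to $1.2\times10^{21} / 2.4\times10^{5} = 5\times10^{15}$ FLOPs per step.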
7. Applications and Practical Guidance
SALT is readily applicable across domains where latent targets or structured outputs are accessible from a trained teacher:
- Video self-supervised learning with decoupled teacher-student pipelines for "frozen backbone" evaluation.
- MLIPs with per-atom energy decompositions, lowering compute demands for large-scale molecular dynamics.
- Image retrieval via asymmetric metric learning, where inference decouples query and database encoders, facilitating on-device deployment and efficient retrieval.
Practical recommendations include:
- Use small or moderately sized teachers, trained to convergence on their base task, and concentrate compute and optimization effort on the student phase.
- For node- or patch-level latent targets, set regression weights high relative to global loss terms.
- Freezing the teacher (static weights, no EMA or stop-gradient tricks) consistently stabilizes learning and simplifies training schedules.
- For "Born-Again" experiments where student equals or exceeds teacher size, schedule or anneal the latent regression weight to transition smoothly back to pure supervised loss.
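One way to implement that annealing is a step-dependent weight on the latent term. The function names, the linear and cosine shapes, and the `anneal_start` parameter are hypothetical choices for this sketch.

```python
import math


def latent_weight(step, total_steps, lam0, anneal_start=0.0, kind="linear"):
    """Weight on the latent-regression term for "Born-Again"-style runs.

    Held at lam0 until anneal_start * total_steps, then annealed to 0 by
    total_steps, so training ends on the supervised loss alone.
    """
    start = anneal_start * total_steps
    frac = (step - start) / max(total_steps - start, 1)
    frac = min(max(frac, 0.0), 1.0)
    if kind == "linear":
        return lam0 * (1.0 - frac)
    if kind == "cosine":
        return lam0 * 0.5 * (1.0 + math.cos(math.pi * frac))
    raise ValueError(f"unknown schedule kind: {kind!r}")


def total_loss(supervised, latent, step, total_steps, lam0):
    # The supervised term is always on; the teacher-matching term fades out.
    return supervised + latent_weight(step, total_steps, lam0) * latent
```

Delaying the anneal with `anneal_start` keeps the teacher signal at full strength early on, when the student benefits most from dense targets.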
SALT generalizes to any setting with node-wise or entity-wise latent decompositions (e.g., partial charges, local dipoles in molecular systems), and can be integrated with active learning or multi-task objectives.
SALT provides a unified, empirically robust framework for asymmetric latent transfer, simplifying knowledge distillation while advancing efficiency and performance in high-dimensional, structured learning tasks (Li et al., 29 Sep 2025, Matin et al., 7 Feb 2025, Budnik et al., 2020).