Bilevel Joint Unsupervised & Supervised Training
- The paper introduces BL-JUST, a bilevel optimization method that jointly trains on unsupervised and supervised objectives, preserving general representations while enabling task specificity.
- It reformulates the original constrained problem into an unconstrained penalty-based objective, enabling efficient gradient updates and theoretical convergence guarantees.
- Empirical results in ASR, low-light imaging, and medical reconstruction show BL-JUST outperforms traditional pre-training and fine-tuning pipelines.
Bilevel Joint Unsupervised and Supervised Training (BL-JUST) is a bilevel optimization framework designed to integrate unsupervised and supervised objectives into the training of machine learning models, particularly for tasks where labeled data is scarce but unlabeled data is abundant. Unlike conventional pre-training and fine-tuning (PT+FT) pipelines that treat unsupervised and supervised learning as disjoint stages, BL-JUST jointly optimizes both loss functions such that model parameters remain close to an unsupervised optimum even as they are adapted for the downstream task (Cui et al., 2024; Saif et al., 2024; Ma et al., 2023). This coordinated strategy preserves generic structure from self-supervised learning while enabling specialization to the supervised task, yielding improved generalization in domains including speech recognition, low-light image enhancement, and medical image reconstruction.
1. Bilevel Formulation and Theoretical Foundations
BL-JUST is formally characterized as a bilevel optimization problem. The lower level seeks a local minimum of an unsupervised objective, while the upper level minimizes a supervised loss, constrained such that the shared model parameters (typically the network backbone) reside at or near an unsupervised optimum:

$$\min_{\theta,\ \phi_s}\ \ell_s(\theta, \phi_s) \quad \text{s.t.} \quad \theta \in \arg\min_{\theta',\ \phi_u}\ \ell_u(\theta', \phi_u)$$

Here, $\theta$ are shared parameters, $\phi_s$ denotes supervised-head parameters, $\phi_u$ parameterizes the unsupervised head, $\ell_u$ is the unsupervised empirical risk on unlabeled data, and $\ell_s$ the supervised empirical risk on labeled data (Cui et al., 2024). In ASR, for instance, $\ell_u$ might represent contrastive losses like CPC or InfoNCE, while $\ell_s$ adopts CTC or RNN-T.
This framework generalizes to imaging tasks and other domains where the two-stage PT+FT separation induces negative transfer or sub-optimal adaptation (Ye et al., 2020). Empirically, BL-JUST achieves superior results by allowing supervised learning to “tug” the unsupervised representations without entirely discarding their generality (Saif et al., 2024).
2. Penalty-Based Reformulation and Optimization Algorithms
Directly solving the constrained bilevel problem is computationally expensive. BL-JUST converts it into an unconstrained penalty objective using the value-function gap $\ell_u(\theta, \phi_u) - \ell_u^*$, where $\ell_u^* = \min_{\theta', \phi_u'} \ell_u(\theta', \phi_u')$, with penalty coefficient $\gamma > 0$ (Cui et al., 2024):

$$\min_{\theta,\ \phi_s,\ \phi_u}\ \ell_s(\theta, \phi_s) + \gamma \left( \ell_u(\theta, \phi_u) - \ell_u^* \right)$$

The penalty coefficient $\gamma$ is increased monotonically, tightening the lower-level optimality constraint as training progresses. This strategy ensures that model updates do not drift far from an unsupervised local minimum.
Stochastic gradient descent is performed jointly on all parameters, typically via the following update rules per training epoch $k$ (the constant $\ell_u^*$ drops out of the gradients):

$$\theta^{k+1} = \theta^k - \alpha \left( \nabla_\theta \ell_s(\theta^k, \phi_s^k) + \gamma_k \nabla_\theta \ell_u(\theta^k, \phi_u^k) \right)$$
$$\phi_s^{k+1} = \phi_s^k - \alpha\, \nabla_{\phi_s} \ell_s(\theta^k, \phi_s^k)$$
$$\phi_u^{k+1} = \phi_u^k - \alpha\, \gamma_k \nabla_{\phi_u} \ell_u(\theta^k, \phi_u^k)$$

Optionally, an initial self-supervised exploration step on $\ell_u$ is performed to better approach a local minimum before joint updates (Cui et al., 2024). The penalty schedule $(\gamma_k)$ is often linear, with ablation studies confirming that gradual increases are optimal for both ASR and imaging domains (Saif et al., 2024; Ma et al., 2023).
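A minimal executable sketch of this joint penalty-based update, using toy scalar quadratic losses in place of the real unsupervised and supervised risks (all functions and constants here are illustrative stand-ins, not from the cited papers):

```python
# Toy sketch of BL-JUST's penalty-based joint update. The quadratic
# surrogates below stand in for real unsupervised/supervised objectives:
#   theta : shared backbone parameter
#   phi_s / phi_u : supervised / unsupervised head parameters

def grad_unsup(theta, phi_u):
    # gradient of 0.5*(theta - 1)^2 + 0.5*phi_u^2 (unsup optimum at theta=1)
    return (theta - 1.0), phi_u

def grad_sup(theta, phi_s):
    # gradient of 0.5*(theta - 2)^2 + 0.5*phi_s^2 (sup optimum at theta=2)
    return (theta - 2.0), phi_s

def bljust(epochs=200, lr=0.05, gamma0=0.1, gamma_step=0.05):
    theta, phi_s, phi_u = 0.0, 1.0, 1.0
    for k in range(epochs):
        gamma = gamma0 + gamma_step * k          # linear penalty schedule
        g_th_s, g_ps = grad_sup(theta, phi_s)
        g_th_u, g_pu = grad_unsup(theta, phi_u)
        theta -= lr * (g_th_s + gamma * g_th_u)  # joint update on shared params
        phi_s -= lr * g_ps                       # supervised head
        phi_u -= lr * gamma * g_pu               # unsupervised head
    return theta, phi_s, phi_u

theta, phi_s, phi_u = bljust()
```

As the schedule drives the penalty up, the shared parameter is pulled from the supervised optimum toward the unsupervised one, illustrating the tunable balance between the two objectives.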
3. Loss Functions and Model Architectures
BL-JUST accommodates diverse architectures and loss combinations, which are typically domain-specific.
Automatic Speech Recognition (ASR) (Cui et al., 2024, Saif et al., 2024)
- Backbone: Conformer or CNN-LSTM
- Unsupervised losses:
- Contrastive Predictive Coding (CPC)
- InfoNCE
- Masked cross-entropy (BEST-RQ)
- Supervised losses:
  - Connectionist Temporal Classification (CTC)
  - RNN Transducer (RNN-T)
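As an illustration of one of the unsupervised losses above, here is a minimal NumPy sketch of InfoNCE with in-batch negatives (a generic textbook formulation, not the exact implementation of any cited system):

```python
import numpy as np

def info_nce(anchors, positives, temperature=0.1):
    """Minimal InfoNCE sketch: each anchor's positive is the same-index row
    of `positives`; all other rows serve as in-batch negatives."""
    a = anchors / np.linalg.norm(anchors, axis=1, keepdims=True)
    p = positives / np.linalg.norm(positives, axis=1, keepdims=True)
    logits = a @ p.T / temperature               # (B, B) similarity matrix
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_probs))          # cross-entropy on the diagonal
```

The loss is low when each anchor is most similar to its own positive and high when positives are mismatched, which is what drives the representation learning.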
Low-Light Image Enhancement (Ma et al., 2023)
- Encoder–Decoder with Retinex architecture
- Unsupervised loss: No-reference objectives such as illuminance smoothness and structure-preserving regularization
- Supervised loss: Paired pixelwise error
X-ray CT Reconstruction (Ye et al., 2020)
- MBIR with deep network regularizer
- Unsupervised loss: Statistical priors (edge-preserving, ULTRA transform)
- Supervised loss: Data fidelity to paired high-dose images
The BL-JUST principle is agnostic to architecture; all reported applications share the key design: a backbone optimized jointly for unsupervised and supervised objectives, with head/task-specific parameters adapted as needed.
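This shared-backbone design can be sketched minimally as follows (the linear layers are illustrative stand-ins; actual systems use the Conformer, CNN-LSTM, or encoder-decoder backbones listed above):

```python
import numpy as np

rng = np.random.default_rng(0)

class LinearBackbone:
    """Stand-in backbone: a single linear map with ReLU features."""
    def __init__(self, d_in, d_hid):
        self.W = rng.normal(scale=0.1, size=(d_in, d_hid))
    def __call__(self, x):
        return np.maximum(x @ self.W, 0.0)

class Head:
    """Stand-in task head: a single linear projection."""
    def __init__(self, d_hid, d_out):
        self.W = rng.normal(scale=0.1, size=(d_hid, d_out))
    def __call__(self, h):
        return h @ self.W

backbone = LinearBackbone(d_in=32, d_hid=64)   # shared parameters (theta)
sup_head = Head(64, 10)                        # supervised head (phi_s)
unsup_head = Head(64, 16)                      # unsupervised head (phi_u)

x = rng.normal(size=(4, 32))
h = backbone(x)                                # one shared forward pass
sup_out, unsup_out = sup_head(h), unsup_head(h)
```

Both losses consume features from the same forward pass, so gradients from both objectives flow into the shared backbone while each head remains task-specific.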
4. Empirical Performance and Comparison
Extensive experimental results across domains confirm the effectiveness of BL-JUST.
ASR (LibriSpeech, Switchboard, Industrial):
| Method | LibriSpeech 100h WER (test-clean / test-other) | Switchboard (300h) WER | Payload (10kh) Avg WER |
|---|---|---|---|
| Sup CTC | 6.3 / 14.7 | 11.2 | 16.0 |
| PT+FT | 5.8 / 14.0 | 10.9 | 15.7 |
| BL-JUST | 4.9 / 12.2 | 8.2 | 14.8 |
BL-JUST outperforms PT+FT and alternative semi-supervised methods (pseudo-labeling, alternating optimization) by 0.5–1.5 points absolute WER (Cui et al., 2024; Saif et al., 2024).
Low-Light Image Enhancement:
| Method | MIT PSNR | MIT SSIM | MIT LPIPS |
|---|---|---|---|
| BL-JUST | 20.13 | 0.8413 | 0.1799 |
| RBL-JUST (reinforced) | 20.68 | 0.8352 | 0.1631 |
(R)BL-JUST attains top-2 rankings on both paired and unpaired benchmarks, and adaptation requires only a few finetuning steps (Ma et al., 2023).
X-ray CT Reconstruction:
SUPER, a structurally similar alternating bilevel algorithm, exhibits lower RMSE and higher SNR/SSIM compared to state-of-the-art model-based and deep learning baselines (Ye et al., 2020).
5. Convergence and Theoretical Guarantees
The penalty-based bilevel gradient descent (PBGD) underlying BL-JUST offers rigorous convergence guarantees. Under mild Lipschitz and Polyak–Łojasiewicz (PL) assumptions on the unsupervised loss, PBGD reaches an $\epsilon$-stationary point with an iteration complexity matching the theoretical rate of standard SGD (Cui et al., 2024; Saif et al., 2024).
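For reference, the PL inequality assumed on the unsupervised loss takes the standard form: there exists $\mu > 0$ such that for all $(\theta, \phi_u)$,

$$\frac{1}{2}\left\| \nabla \ell_u(\theta, \phi_u) \right\|^2 \;\ge\; \mu \left( \ell_u(\theta, \phi_u) - \ell_u^* \right),$$

where $\ell_u^*$ is the minimal unsupervised risk. This condition guarantees that every stationary point of $\ell_u$ is a global minimum without requiring convexity, which is what makes the penalty reformulation tractable.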
Final parameter gradients satisfy near-stationarity for both objectives, as evidenced by gradient-norm measurements: both the supervised and unsupervised gradient norms are small under BL-JUST, whereas PT+FT leaves the unsupervised gradient norm substantially larger, confirming joint optimality (Cui et al., 2024).
Ablation studies indicate that omitting either self-supervised exploration or final supervised finetuning degrades performance (+1.2% WER on LibriSpeech test-clean if exploration is skipped), underscoring the importance of each component in maintaining the balance between unsupervised structure and supervised adaptation (Cui et al., 2024).
6. Generalizations and Domain-Specific Implementations
While initially developed for ASR, BL-JUST generalizes to diverse machine learning problems:
- In low-light image enhancement, BL-JUST uses first-order hypergradient approximations and “reinforcement” steps (meta-level decoder initialization) to enable rapid adaptation to scene changes (Ma et al., 2023).
- In medical imaging, the SUPER framework alternates between exactly or approximately solved inner unsupervised steps and outer supervised updates, effectively imposing an unsupervised prior and a learnable deep regularizer in the lower level while optimizing reconstruction fidelity in the upper (Ye et al., 2020).
The flexibility of the bilevel structure accommodates a wide range of unsupervised and supervised losses and is compatible with first-order and finite-difference hypergradient estimators, as in DARTS-style (Ma et al., 2023) and model-based approaches (Ye et al., 2020).
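For concreteness, here is a toy sketch of a DARTS-style finite-difference hypergradient on scalar quadratic losses (the functions and constants are illustrative only; the exact approximations used in the cited works may differ):

```python
# DARTS-style hypergradient sketch: the upper-level variable `a` shifts the
# lower-level optimum of L_train(w, a) = 0.5*(w - a)^2, and we approximate
# d/da L_val(w - xi * dL_train/dw) with L_val(w) = 0.5*(w - 1)^2,
# avoiding explicit second derivatives via a finite difference.

def grad_train_w(w, a):   # dL_train/dw
    return w - a

def grad_train_a(w, a):   # dL_train/da
    return -(w - a)

def grad_val_w(w):        # dL_val/dw
    return w - 1.0

def darts_hypergrad(w, a, xi=0.1, eps=1e-3):
    w_step = w - xi * grad_train_w(w, a)    # one virtual SGD step on L_train
    v = grad_val_w(w_step)                  # upper-level gradient at stepped w
    # finite-difference estimate of (d^2 L_train / da dw) applied to v
    fd = (grad_train_a(w + eps * v, a) - grad_train_a(w - eps * v, a)) / (2 * eps)
    return -xi * fd   # L_val has no direct dependence on a in this toy setup

# analytic reference: d/da [0.5*(w - xi*(w - a) - 1)^2] = xi*(w - xi*(w - a) - 1)
w, a, xi = 0.3, 0.7, 0.1
analytic = xi * (w - xi * (w - a) - 1.0)
approx = darts_hypergrad(w, a, xi=xi)
```

Because both toy losses are quadratic, the finite-difference estimate here agrees with the analytic hypergradient up to floating-point error; in deep networks the same trick replaces an expensive Hessian-vector product with two extra gradient evaluations.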
7. Significance and Ongoing Developments
BL-JUST addresses the fundamental disconnect in conventional semi-supervised learning that arises from decoupled objective functions and staged training. By enforcing matched local optima and enabling a tunable balance between domain-general representations and task-specific adaptation, it delivers empirically superior and theoretically principled results. The method’s convergence rate, architectural agnosticism, and domain-transferability render it a robust paradigm for modern machine learning tasks where data heterogeneity and label scarcity prevail.
Recent work continually extends BL-JUST to new domains, improves computational efficiency via single-loop gradient updates, and explores variants (e.g., reinforced bilevel learning for meta-initialization) to further reduce adaptation time and code complexity. Empirical results consistently validate its advantages over baseline and ablation approaches (Cui et al., 2024, Saif et al., 2024, Ma et al., 2023, Ye et al., 2020).