Meta-Weight-Net: Adaptive Sample Re-weighting
- The paper introduces Meta-Weight-Net, a meta-learning framework that adaptively learns sample re-weighting functions via bilevel optimization to enhance neural network robustness.
- It employs a two-layer MLP as the weighting function, simulating one-step SGD updates to compute meta-gradients from a clean meta-set for improved training.
- Empirical results demonstrate that MW-Net outperforms conventional methods on benchmarks like CIFAR and Clothing1M, despite incurring higher computational cost.
Meta-Weight-Net (MW-Net) is a meta-learning framework designed to adaptively learn sample re-weighting functions for robust training of deep neural networks, particularly under class imbalance or label noise. Unlike traditional approaches that require manually specified weighting heuristics, MW-Net jointly learns both a base model and an explicit, data-driven weighting function via bilevel optimization, using a small unbiased meta-set to guide the weight learning. Parameterizing the weighting function explicitly as a two-layer multilayer perceptron (MLP) lets MW-Net represent a broad family of weighting strategies and empirically improves performance across a variety of benchmarks involving corrupted data distributions (Shu et al., 2019).
1. Bilevel Meta-Learning Formulation
The MW-Net approach is formalized as a bilevel optimization problem. The lower-level (inner) problem adapts the base network's parameters using weighted training losses, while the upper-level (outer) problem updates the weighting function parameters by minimizing meta-loss computed on a clean, unbiased meta-set. Formally:
- Training set: $D^{train} = \{(x_i, y_i)\}_{i=1}^{N}$ (potentially noisy or imbalanced)
- Meta-set: $D^{meta} = \{(x_j^{meta}, y_j^{meta})\}_{j=1}^{M}$ ($M \ll N$), assumed sampled from the target (clean) distribution
- Base network: $f(x; \mathbf{w})$ with parameters $\mathbf{w}$
- Per-sample training loss: $L_i^{train}(\mathbf{w}) = \ell(f(x_i; \mathbf{w}), y_i)$
- Weighting function: $\mathcal{V}(L; \Theta)$, parameterized by $\Theta$, outputs weight $v_i = \mathcal{V}(L_i^{train}(\mathbf{w}); \Theta) \in [0, 1]$
Lower-level problem (base training update):

$$\mathbf{w}^*(\Theta) = \arg\min_{\mathbf{w}} \frac{1}{N} \sum_{i=1}^{N} \mathcal{V}\big(L_i^{train}(\mathbf{w}); \Theta\big) \, L_i^{train}(\mathbf{w})$$

Upper-level problem (meta-weight learning):

$$\Theta^* = \arg\min_{\Theta} \frac{1}{M} \sum_{j=1}^{M} L_j^{meta}\big(\mathbf{w}^*(\Theta)\big)$$
This bilevel formulation enables the automatic adaptation of sample weighting to optimize model generalization on the meta-set, thereby mitigating overfitting to biased or mislabeled examples (Shu et al., 2019).
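As a concrete illustration of the two objectives, the sketch below evaluates the lower-level weighted training loss and the upper-level meta-loss; the two-parameter sigmoid weighting function and all numeric values are illustrative stand-ins for the learned MLP, not the paper's implementation.

```python
import numpy as np

def weighted_train_loss(per_sample_losses, weight_fn, theta):
    """Lower-level objective: mean of V(L_i; Theta) * L_i over training samples."""
    v = weight_fn(per_sample_losses, theta)      # v_i = V(L_i; Theta)
    return np.mean(v * per_sample_losses)

def meta_loss(per_sample_meta_losses):
    """Upper-level objective: plain (unweighted) mean loss on the clean meta-set."""
    return np.mean(per_sample_meta_losses)

def toy_weight_fn(losses, theta):
    """Hypothetical stand-in for Meta-Weight-Net: sigmoid of an affine map of the loss."""
    a, b = theta
    return 1.0 / (1.0 + np.exp(-(a * losses + b)))

losses = np.array([0.2, 1.5, 3.0])               # hypothetical per-sample losses
train_obj = weighted_train_loss(losses, toy_weight_fn, (-1.0, 0.0))
```

With a negative slope $a$, higher-loss samples receive smaller weights, which is the qualitative behavior MW-Net learns under label noise.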
2. Meta-Gradient Derivation and Optimization
In practice, MW-Net employs a one-step approximation of inner-level optimization and backpropagates through this simulated update to compute meta-gradients for the weighting function parameters. For a single training iteration:
- Let $B$ (size $n$) be a train minibatch and $B^{meta}$ (size $m$) a meta minibatch.
- Compute per-sample training losses $L_i^{train}(\mathbf{w}^{(t)})$ for $i \in B$.
- Evaluate weights $v_i = \mathcal{V}\big(L_i^{train}(\mathbf{w}^{(t)}); \Theta\big)$.
- Simulate one SGD step for the base model: $\hat{\mathbf{w}}^{(t)}(\Theta) = \mathbf{w}^{(t)} - \alpha \frac{1}{n} \sum_{i \in B} v_i \, \nabla_{\mathbf{w}} L_i^{train}(\mathbf{w}) \big|_{\mathbf{w}^{(t)}}$.
- Compute the meta-loss on $B^{meta}$ at $\hat{\mathbf{w}}^{(t)}(\Theta)$ and backpropagate to update $\Theta$.
The meta-gradient update (via the chain rule) is given by:

$$\nabla_\Theta \frac{1}{m} \sum_{j=1}^{m} L_j^{meta}\big(\hat{\mathbf{w}}^{(t)}(\Theta)\big) \Big|_{\Theta^{(t)}} = -\frac{\alpha}{n} \sum_{i=1}^{n} \Big( \frac{1}{m} \sum_{j=1}^{m} G_{ij} \Big) \frac{\partial \mathcal{V}\big(L_i^{train}(\mathbf{w}^{(t)}); \Theta\big)}{\partial \Theta} \Big|_{\Theta^{(t)}},$$

where $G_{ij} = \nabla_{\mathbf{w}} L_i^{train}(\mathbf{w}) \big|_{\mathbf{w}^{(t)}}^{\top} \nabla_{\hat{\mathbf{w}}} L_j^{meta}(\hat{\mathbf{w}}) \big|_{\hat{\mathbf{w}}^{(t)}}$ and $\nabla_{\hat{\mathbf{w}}} L_j^{meta}(\hat{\mathbf{w}}) \big|_{\hat{\mathbf{w}}^{(t)}}$ is the meta-gradient evaluated at $\hat{\mathbf{w}}^{(t)}$.

This meta-gradient is used to update $\Theta$:

$$\Theta^{(t+1)} = \Theta^{(t)} - \beta \, \nabla_\Theta \frac{1}{m} \sum_{j=1}^{m} L_j^{meta}\big(\hat{\mathbf{w}}^{(t)}(\Theta)\big) \Big|_{\Theta^{(t)}}$$
After the meta update, the actual training step updates $\mathbf{w}$ using weights recomputed with $\Theta^{(t+1)}$ on the same minibatch (Shu et al., 2019).
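The chain-rule computation can be checked numerically. The sketch below is an illustrative construction (not the paper's code): it uses a toy linear-regression base model and a two-parameter stand-in $v_i = \sigma(\theta_0 L_i + \theta_1)$ for the weighting net, writes the meta-gradient out analytically, and verifies it against central finite differences.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy setup: linear regression with squared loss L_i = 0.5 * (w @ x_i - y_i)^2.
d, n, m = 3, 8, 4
X, y = rng.normal(size=(n, d)), rng.normal(size=n)      # train minibatch
Xm, ym = rng.normal(size=(m, d)), rng.normal(size=m)    # meta minibatch
w = rng.normal(size=d)
theta = np.array([-1.0, 0.5])                           # stand-in weight-net params
alpha = 0.1

sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))

def simulated_step(theta):
    """One weighted SGD step w -> w_hat(theta); returns per-sample quantities too."""
    r = X @ w - y                        # residuals
    L = 0.5 * r**2                       # per-sample training losses
    v = sigmoid(theta[0] * L + theta[1]) # weights v_i
    g = r[:, None] * X                   # per-sample gradients grad_w L_i
    w_hat = w - alpha * np.mean(v[:, None] * g, axis=0)
    return w_hat, L, v, g

def meta_loss(w_hat):
    return np.mean(0.5 * (Xm @ w_hat - ym)**2)

# Analytic meta-gradient via the chain rule:
#   dL_meta/dtheta = -(alpha/n) * sum_i (h . g_i) * dv_i/dtheta,
# where h = grad_{w_hat} L_meta, so (h . g_i) plays the role of G_ij averaged over j.
w_hat, L, v, g = simulated_step(theta)
h = np.mean((Xm @ w_hat - ym)[:, None] * Xm, axis=0)
dv_dtheta = np.stack([v * (1 - v) * L, v * (1 - v)], axis=1)   # (n, 2)
meta_grad = -(alpha / n) * (g @ h) @ dv_dtheta

# Central finite-difference check on each component of theta.
eps = 1e-6
fd = np.array([
    (meta_loss(simulated_step(theta + eps * e)[0]) -
     meta_loss(simulated_step(theta - eps * e)[0])) / (2 * eps)
    for e in np.eye(2)
])
```

Samples whose gradients align with the meta-gradient (positive $h \cdot g_i$) are pushed toward larger weights, which is the intuition behind the $G_{ij}$ term above.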
3. Weighting Function Architecture
The explicit weighting function (the "Meta-Weight-Net") is implemented as a compact two-layer MLP:
- Input: current scalar training loss $L_i^{train}(\mathbf{w})$
- Hidden layer: 100 units, ReLU activation
- Output: scalar weight in $[0, 1]$, produced via a sigmoid activation
Weights are typically normalized per minibatch to sum to one. With sufficient hidden units, this MLP is a universal approximator of continuous weighting functions, enabling MW-Net to mimic or surpass manually designed weighting schemes across a variety of learning scenarios. Empirically, varying the hidden size between 50 and 200, or the depth from 2 to 4 layers, has limited effect on performance, indicating robustness to architectural hyperparameter selection (Shu et al., 2019).
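A forward-pass-only numpy sketch of this architecture follows; the initialization scale and the `normalize` helper are assumptions for illustration, and in the real method the parameters are trained by the meta-updates rather than fixed.

```python
import numpy as np

class MetaWeightNet:
    """Two-layer MLP weighting function: scalar loss -> 100 ReLU units -> sigmoid weight."""
    def __init__(self, hidden=100, seed=0):
        rng = np.random.default_rng(seed)
        self.W1 = rng.normal(scale=0.1, size=(1, hidden))   # input -> hidden
        self.b1 = np.zeros(hidden)
        self.W2 = rng.normal(scale=0.1, size=(hidden, 1))   # hidden -> output
        self.b2 = np.zeros(1)

    def __call__(self, losses):
        """Map (n,) per-sample losses to (n,) weights in [0, 1]."""
        h = np.maximum(0.0, losses[:, None] @ self.W1 + self.b1)  # ReLU hidden layer
        z = (h @ self.W2 + self.b2).ravel()
        return 1.0 / (1.0 + np.exp(-z))                           # sigmoid output

def normalize(v, eps=1e-12):
    """Optional per-minibatch normalization so weights sum to one."""
    return v / (v.sum() + eps)

net = MetaWeightNet()
weights = net(np.array([0.1, 1.0, 5.0]))
```

Because the input is a single scalar, the network is tiny; its cost is negligible next to the base model's forward and backward passes.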
4. Algorithm Workflow
A high-level pseudocode of MW-Net is as follows:
```
Input: training set D, meta-set D^meta, learning rates α (base), β (meta),
       initial w^(0), Θ^(0)
for t = 0 .. T-1 do
  1. Sample train minibatch B from D and meta minibatch B^meta from D^meta
  2. Compute per-sample losses L_i^train(w^(t)) for i in B
  3. Compute weights v_i = V(L_i^train; Θ^(t))
  4. Simulate SGD: ŵ = w^(t) - α · (1/|B|) · Σ_i v_i ∇_w L_i^train
  5. Compute meta-loss on B^meta at ŵ
  6. Meta-update: Θ^(t+1) = Θ^(t) - β · ∇_Θ (meta-loss)
  7. Recompute weights v_i = V(L_i^train; Θ^(t+1)) and update the base model:
     w^(t+1) = w^(t) - α · (1/|B|) · Σ_i v_i ∇_w L_i^train
end for
```
This workflow involves, per iteration, simulating one update to the base model to compute the effect of weight adjustments, followed by the meta-update and actual training step (Shu et al., 2019).
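This loop can be run end to end under strong simplifying assumptions: a linear base model with sign-flipped noisy labels, a two-parameter sigmoid stand-in for the MLP weighting net so its meta-gradient can be derived by hand, full-batch updates, a short warm-up before meta-updates begin, and hyperparameters chosen purely for the demo. With these assumptions the corrupted samples should end up down-weighted relative to clean ones.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_noisy, n_clean, n_meta = 2, 10, 30, 10
w_true = np.array([1.0, -2.0])

X = rng.normal(size=(n_noisy + n_clean, d))
y = X @ w_true
y[:n_noisy] *= -1.0                       # corrupt the first 10 labels (sign flip)
Xm = rng.normal(size=(n_meta, d))         # small clean meta-set
ym = Xm @ w_true

sigmoid = lambda z: 1.0 / (1.0 + np.exp(-np.clip(z, -60, 60)))
alpha, beta = 0.2, 0.1
w, theta = np.zeros(d), np.zeros(2)       # stand-in weight net: v = sigmoid(th0*L + th1)
n = len(y)

for t in range(220):
    r = X @ w - y                                     # steps 2-3: losses and weights
    L = 0.5 * r**2
    v = sigmoid(theta[0] * L + theta[1])
    g = r[:, None] * X                                # per-sample gradients
    if t >= 20:  # demo simplification: let the base model settle before meta-updates
        w_hat = w - alpha * np.mean(v[:, None] * g, axis=0)   # step 4: simulated step
        h = np.mean((Xm @ w_hat - ym)[:, None] * Xm, axis=0)  # step 5: meta-gradient at w_hat
        dv = np.stack([v * (1 - v) * L, v * (1 - v)], axis=1)
        theta = theta - beta * (-(alpha / n) * (g @ h) @ dv)  # step 6: meta-update
        v = sigmoid(theta[0] * L + theta[1])                  # step 7: recompute weights
    w = w - alpha * np.mean(v[:, None] * g, axis=0)           # actual base update

final_v = sigmoid(theta[0] * (0.5 * (X @ w - y)**2) + theta[1])
```

By the end of the run, the slope parameter is driven negative, so high-loss (corrupted) samples receive smaller weights, mirroring the decreasing weighting curves MW-Net learns under label noise.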
5. Experimental Results
MW-Net has been empirically evaluated on multiple benchmarks, demonstrating superior performance compared to alternative reweighting techniques in class-imbalanced and noisy-label settings:
Class Imbalance (Long-tailed CIFAR-10 / CIFAR-100)
Test accuracy on ResNet-32 (mean over 5 runs, imbalance factor 100):
| Method | CIFAR-10 | CIFAR-100 |
|---|---|---|
| BaseModel | 70.36 | 38.32 |
| Focal Loss | 70.38 | 38.41 |
| Class-Balanced | 74.57 | 39.60 |
| L2RW | 74.16 | 40.23 |
| MW-Net (ours) | 75.21 | 42.09 |
Noisy Label Benchmarks
- Uniform 40% noise, CIFAR-10 (WRN-28-10): MW-Net 89.3%, GLC 88.3%, L2RW 86.9%, Co-teaching 74.8%, BaseModel 68.1%
- Uniform 60% noise, CIFAR-100: MW-Net 58.8%, GLC 50.8%, L2RW 48.2%, Co-teaching 35.7%, BaseModel 30.9%
Real-World Data (Clothing1M)
- CrossEntropy 68.9%, S-adaptation 70.4%, Joint Opt 72.2%, MLNT 73.5%, MW-Net 73.7%
Furthermore, ablation experiments reveal that the learned weighting function recovers known weighting forms in conventional settings: monotonically increasing in the loss under class imbalance (as in focal loss), monotonically decreasing under noisy labels (as in self-paced learning), and non-monotonic in more complex cases such as Clothing1M (Shu et al., 2019).
6. Implementation Complexity and Limitations
Each MW-Net iteration costs roughly three times as much as ordinary SGD, due to (i) the simulated base-model update, (ii) the meta-update of $\Theta$ via backpropagation through the simulated step, and (iii) the actual training update with the newly estimated weights. Meta-batch sizes are typically kept small (no larger than the training batch size) to offset this cost. Key hyperparameters include the base-model learning rate, the meta learning rate (smaller for class-imbalance than for noisy-label settings), and the meta-set size (e.g., 10 images per class for class imbalance, up to 7K images for Clothing1M).
A limitation of MW-Net is its reliance on a small, clean meta-set: if these samples are not representative or are themselves biased, sample weighting may not generalize optimally. Mitigation strategies include using small meta-batches or asynchronous meta-updates. Convergence to stationary points of both meta-loss and training loss is guaranteed under mild Lipschitz and smoothness assumptions, and empirical results indicate consistent loss decrease and robust convergence (Shu et al., 2019).
7. Position in the Landscape of Meta-Weighting Methods
MW-Net is distinct from non-parametric approaches or methods with hand-designed weighting heuristics—such as Focal Loss, Self-Paced Learning (SPL), and L2RW—by virtue of its data-adaptive, function-learning paradigm driven by a meta-objective. Its MLP-based explicit weight function is more expressive than the scalar or monotonic mappings of prior techniques, and is learned via direct differentiation through meta-loss on a hold-out set. MW-Net has influenced subsequent meta-weighting methods where the assignment of sample weights is adapted through end-to-end learning in a differentiable, model-agnostic manner (Shu et al., 2019).