
Meta-Weight-Net: Adaptive Sample Re-weighting

Updated 22 January 2026
  • The paper introduces Meta-Weight-Net, a meta-learning framework that adaptively learns sample re-weighting functions via bilevel optimization to enhance neural network robustness.
  • It employs a two-layer MLP as the weighting function, simulating one-step SGD updates to compute meta-gradients from a clean meta-set for improved training.
  • Empirical results demonstrate that MW-Net outperforms conventional methods on benchmarks like CIFAR and Clothing1M, despite incurring higher computational cost.

Meta-Weight-Net (MW-Net) is a meta-learning framework designed to adaptively learn sample re-weighting functions for robust training of deep neural networks, particularly under class imbalance or label noise. Unlike traditional approaches that rely on manually specified weighting heuristics, MW-Net jointly learns both a base model and an explicit, data-driven weighting function via bilevel optimization, using a small unbiased meta-set to guide the weight learning. Parameterizing the weighting function explicitly as a two-layer multilayer perceptron (MLP) allows MW-Net to approximate a broad family of weighting strategies, and it empirically improves performance across a variety of benchmarks involving corrupted data distributions (Shu et al., 2019).

1. Bilevel Meta-Learning Formulation

The MW-Net approach is formalized as a bilevel optimization problem. The lower-level (inner) problem adapts the base network's parameters using weighted training losses, while the upper-level (outer) problem updates the weighting function parameters by minimizing meta-loss computed on a clean, unbiased meta-set. Formally:

  • Training set: $\mathcal{D} = \{(x_i, y_i)\}_{i=1}^N$ (potentially noisy or imbalanced)
  • Meta-set: $\mathcal{D}^{val} = \{(x_j^{val}, y_j^{val})\}_{j=1}^M$ with $M \ll N$, assumed sampled from the target (clean) distribution
  • Base network: $g(x; w)$ with parameters $w$
  • Per-sample training loss: $\ell_i^{train}(w) \triangleq \ell(g(x_i; w), y_i)$
  • Weighting function: $f_\theta: \mathbb{R}_+ \to [0,1]$, parameterized by $\theta$, which outputs the sample weight $v_i = f_\theta(\ell_i^{train})$

Lower-level problem (base training update):

$$w^*(\theta) = \arg\min_w \frac{1}{N} \sum_{i=1}^N f_\theta\big(\ell_i^{train}(w)\big) \cdot \ell_i^{train}(w)$$

Upper-level problem (meta-weight learning):

$$\theta^* = \arg\min_\theta \frac{1}{M} \sum_{j=1}^M \ell\big(g(x_j^{val}; w^*(\theta)), y_j^{val}\big)$$

This bilevel formulation enables the automatic adaptation of sample weighting to optimize model generalization on the meta-set, thereby mitigating overfitting to biased or mislabeled examples (Shu et al., 2019).
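The gating effect of the lower-level objective can be illustrated with a minimal sketch, assuming a toy scalar linear model and a hypothetical two-parameter sigmoid weighting function in place of the paper's MLP (all names and constants below are illustrative, not from the paper):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy base model: scalar linear regression g(x; w) = w * x with squared loss.
def per_sample_losses(w, x, y):
    return 0.5 * (w * x - y) ** 2

# Hypothetical stand-in for the weighting net: f_theta(l) = sigmoid(t0 + t1*l),
# mapping each per-sample loss into (0, 1).
def f_theta(losses, theta):
    return 1.0 / (1.0 + np.exp(-(theta[0] + theta[1] * losses)))

# Lower-level objective: (1/N) * sum_i f_theta(l_i(w)) * l_i(w).
def inner_objective(w, x, y, theta):
    l = per_sample_losses(w, x, y)
    return np.mean(f_theta(l, theta) * l)

x = rng.normal(size=8)
y = 2.0 * x                          # clean targets follow y = 2x
y[:2] += 5.0                         # two corrupted labels, which incur large losses
theta_down = np.array([0.0, -1.0])   # negative slope down-weights high-loss samples

obj = inner_objective(1.0, x, y, theta_down)
```

With a negative slope, the corrupted high-loss samples receive weights near zero, so they contribute little to the weighted empirical risk.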

2. Meta-Gradient Derivation and Optimization

In practice, MW-Net employs a one-step approximation of inner-level optimization and backpropagates through this simulated update to compute meta-gradients for the weighting function parameters. For a single training iteration:

  • Let $\mathcal{B}$ be a training minibatch and $\mathcal{B}^{val}$ a meta minibatch.
  • Compute per-sample training losses $\ell_i^{train}$ for $i \in \mathcal{B}$.
  • Evaluate weights $v_i = f_\theta(\ell_i^{train})$.
  • Simulate one SGD step for the base model: $w' = w^t - \alpha \frac{1}{|\mathcal{B}|} \sum_{i\in \mathcal{B}} v_i \nabla_w \ell_i^{train}$.
  • Compute the meta-loss on $\mathcal{B}^{val}$ at $w'$ and backpropagate to update $\theta$.

The meta-gradient update (using chain rule) is given by:

$$\nabla_\theta L^{val} = -\frac{\alpha}{n} \sum_{i\in\mathcal{B}} \left[ (\nabla_w \ell_i^{train})^T\, \nabla_{w'} L^{val} \right] \frac{\partial f_\theta(\ell_i^{train})}{\partial \theta}$$

where $n = |\mathcal{B}|$ and $\nabla_{w'} L^{val}$ is the gradient of the meta-loss evaluated at $w'$.

This meta-gradient is used to update θ\theta:

$$\theta^{t+1} \leftarrow \theta^t - \beta\, \nabla_\theta L^{val}(w')$$

After the meta-update, the actual training step updates $w$ on the same minibatch using weights recomputed with the new $\theta^{t+1}$ (Shu et al., 2019).
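As a sanity check on the chain-rule expression above, the following sketch instantiates it for a toy scalar model, again substituting a hypothetical two-parameter sigmoid for the MLP weighting net, and compares the analytic meta-gradient against central finite differences of the simulated-update pipeline:

```python
import numpy as np

rng = np.random.default_rng(1)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def meta_loss(w, theta, x, y, xv, yv, alpha):
    """Meta-loss after one simulated weighted SGD step from w."""
    l = 0.5 * (w * x - y) ** 2              # per-sample train losses at w
    g = (w * x - y) * x                     # per-sample grads d l_i / d w
    v = sigmoid(theta[0] + theta[1] * l)    # sample weights
    w1 = w - alpha * np.mean(v * g)         # simulated one-step update w'
    return np.mean(0.5 * (w1 * xv - yv) ** 2)

def analytic_meta_grad(w, theta, x, y, xv, yv, alpha):
    """Chain-rule meta-gradient from the text, specialized to this toy model."""
    n = x.size
    l = 0.5 * (w * x - y) ** 2
    g = (w * x - y) * x
    v = sigmoid(theta[0] + theta[1] * l)
    w1 = w - alpha * np.mean(v * g)
    Gp = np.mean((w1 * xv - yv) * xv)       # d meta-loss / d w'
    dv = v * (1.0 - v)                      # sigmoid derivative factor
    # -(alpha/n) * sum_i (grad_i * Gp) * d f_theta(l_i) / d theta
    return -(alpha / n) * Gp * np.array([np.sum(g * dv), np.sum(g * dv * l)])

x, y = rng.normal(size=8), rng.normal(size=8)
xv, yv = rng.normal(size=4), rng.normal(size=4)
w, theta, alpha, eps = 0.7, np.array([0.3, -0.5]), 0.1, 1e-6

ana = analytic_meta_grad(w, theta, x, y, xv, yv, alpha)
num = np.array([
    (meta_loss(w, theta + eps * e, x, y, xv, yv, alpha)
     - meta_loss(w, theta - eps * e, x, y, xv, yv, alpha)) / (2 * eps)
    for e in np.eye(2)])
```

The two gradients agree to numerical precision, confirming that backpropagating through the simulated step reduces to the formula above.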

3. Weighting Function Architecture

The explicit weighting function $f_\theta$ (the "Meta-Weight-Net") is implemented as a compact two-layer MLP:

  • Input: the current scalar training loss $\ell \in \mathbb{R}_+$
  • Hidden layer: 100 units with ReLU activation
  • Output: a scalar in $(0,1)$ produced by a sigmoid activation

Weights are typically normalized per minibatch to sum to one. With sufficient hidden units, this MLP is a universal approximator of continuous weighting functions, enabling MW-Net to mimic or surpass manually designed weighting schemes across a variety of learning scenarios. Empirically, varying the hidden size between 50 and 200 or the depth from 2 to 4 has limited effect on performance, indicating robustness to architectural hyperparameter selection (Shu et al., 2019).
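A minimal NumPy sketch of this forward pass (the class name and Gaussian initialization are illustrative assumptions, not the paper's implementation):

```python
import numpy as np

rng = np.random.default_rng(0)

class MetaWeightNet:
    """Forward pass of the 1 -> 100 (ReLU) -> 1 (sigmoid) weighting MLP.
    The Gaussian initialization here is an illustrative choice."""

    def __init__(self, hidden=100):
        self.W1 = rng.normal(scale=0.1, size=(hidden, 1))
        self.b1 = np.zeros(hidden)
        self.W2 = rng.normal(scale=0.1, size=(1, hidden))
        self.b2 = np.zeros(1)

    def __call__(self, losses):
        # losses: shape (n,) vector of per-sample scalar training losses
        h = np.maximum(0.0, losses[:, None] @ self.W1.T + self.b1)   # ReLU layer
        v = 1.0 / (1.0 + np.exp(-(h @ self.W2.T + self.b2)))         # sigmoid output
        return v.ravel()                                             # weights in (0, 1)

net = MetaWeightNet()
weights = net(np.array([0.1, 1.0, 10.0]))
# optional per-minibatch normalization so the weights sum to one
norm_weights = weights / weights.sum()
```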

4. Algorithm Workflow

A high-level pseudocode of MW-Net is as follows:

Input: Training set D, Meta-set D^val, learning rates α (classifier), β (meta), initial w^0, θ^0
for t = 0..T-1 do
   1. Sample train minibatch B from D, meta minibatch B^val from D^val
   2. Compute per-sample losses l_i^train for i in B
   3. Compute weights v_i = f_θ^t(l_i^train)
   4. Simulate SGD: w' = w^t - α * (1/|B|) * sum_i v_i * grad_w l_i^train
   5. Compute meta-loss on B^val at w'
   6. Update θ using meta-gradient: θ^{t+1} = θ^t - β * grad_θ meta-loss
   7. Recompute weights v_i = f_{θ^{t+1}}(l_i^train) and update w: w^{t+1} = w^t - α * (1/|B|) * sum_i v_i * grad_w l_i^train
end for

This workflow involves, per iteration, simulating one update to the base model to compute the effect of weight adjustments, followed by the meta-update and actual training step (Shu et al., 2019).

5. Experimental Results

MW-Net has been empirically evaluated on multiple benchmarks, demonstrating superior performance compared to alternative reweighting techniques in class-imbalanced and noisy-label settings:

Class Imbalance (Long-tailed CIFAR-10 / CIFAR-100)

Test accuracy (%) of ResNet-32 (mean over 5 runs, imbalance factor 100):

| Method | CIFAR-10 | CIFAR-100 |
|---|---|---|
| BaseModel | 70.36 | 38.32 |
| Focal Loss | 70.38 | 38.41 |
| Class-Balanced | 74.57 | 39.60 |
| L2RW | 74.16 | 40.23 |
| MW-Net | 75.21 | 42.09 |

Noisy Label Benchmarks

  • Uniform 40% noise, CIFAR-10 (WRN-28-10): MW-Net 89.3%, GLC 88.3%, L2RW 86.9%, Co-teaching 74.8%, BaseModel 68.1%
  • Uniform 60% noise, CIFAR-100: MW-Net 58.8%, GLC 50.8%, L2RW 48.2%, Co-teaching 35.7%, BaseModel 30.9%

Real-World Data (Clothing1M)

  • CrossEntropy 68.9%, S-adaptation 70.4%, Joint Opt 72.2%, MLNT 73.5%, MW-Net 73.7%

Furthermore, ablation experiments reveal that the learned $f_\theta$ recovers known weighting forms in conventional settings: monotonically increasing for class imbalance (resembling Focal Loss), monotonically decreasing for noisy labels (resembling self-paced learning, SPL), and non-monotonic for more complex cases such as Clothing1M (Shu et al., 2019).

6. Implementation Complexity and Limitations

Each MW-Net iteration costs roughly three times as much as ordinary SGD, due to (i) the simulated base-model update, (ii) the meta-update of $\theta$ via backpropagation through the simulated step, and (iii) the actual training update with the newly estimated weights. Meta-batch sizes are typically kept small (no larger than the training batch size) to offset this cost. Key hyperparameters include the base-model learning rate, the meta learning rate (typically smaller for class-imbalance than for noisy-label settings), and the meta-set size (e.g., 10 samples per class for class imbalance, up to 7K for Clothing1M).

A limitation of MW-Net is its reliance on a small, clean meta-set: if these samples are not representative or are themselves biased, sample weighting may not generalize optimally. Mitigation strategies include using small meta-batches or asynchronous meta-updates. Convergence to stationary points of both meta-loss and training loss is guaranteed under mild Lipschitz and smoothness assumptions, and empirical results indicate consistent loss decrease and robust convergence (Shu et al., 2019).

7. Position in the Landscape of Meta-Weighting Methods

MW-Net is distinct from non-parametric approaches or methods with hand-designed weighting heuristics—such as Focal Loss, Self-Paced Learning (SPL), and L2RW—by virtue of its data-adaptive, function-learning paradigm driven by a meta-objective. Its MLP-based explicit weight function is more expressive than the scalar or monotonic mappings of prior techniques, and is learned via direct differentiation through meta-loss on a hold-out set. MW-Net has influenced subsequent meta-weighting methods where the assignment of sample weights is adapted through end-to-end learning in a differentiable, model-agnostic manner (Shu et al., 2019).
