Nested MIL with Attention
- The paper introduces a nested MIL framework that employs dedicated attention mechanisms at every hierarchy level to enhance prediction accuracy and interpretability.
- NMIA is a hierarchical model that uses multi-level feature embedding and aggregation to capture complex dependencies in weakly supervised, bag-of-bags data.
- Empirical evaluations demonstrate that NMIA outperforms traditional MIL approaches, particularly in tasks requiring structured latent label inference and rule-based aggregation.
Nested Multiple Instance Learning with Attention (NMIA) extends the canonical Multiple Instance Learning (MIL) paradigm to address weakly supervised problems with complex hierarchical structures, where only bag-of-bags labels are available and neither instance nor inner-bag labels are observed. NMIA introduces levels of bag nesting and employs dedicated attention mechanisms at each level. This framework enables not only accurate prediction of the outermost bag labels but also interpretable soft predictions of latent labels at lower levels. The original model formulation and empirical analysis are detailed in "Nested Multiple Instance Learning with Attention Mechanisms" (Fuster et al., 2021).
1. Hierarchical Weak Supervision and Formal Setup
NMIA formalizes a setting where only the label of a single outermost bag is observed, but the data structure is intrinsically hierarchical:
- Level 1 (Innermost): instances $x_{1,k,l}$, grouped into inner-bags.
- Levels $2, \dots, J-1$: each level $j$ comprises bags whose elements are the (embedded) bags from level $j-1$.
- Level $J$ (Outermost): the top-level bag contains the inner-bags $X_{J-1,k}$, for $k = 1, \dots, K_{J-1}$.
Notation:
- $x_{j,k,l}$: the $l$-th element of the $k$-th bag at level $j$ (a raw instance for $j = 1$; the embedding of a sub-bag for $j > 1$).
- $X_{j,k}$: the $k$-th bag at level $j$, with $L_{j,k}$ elements.
- $X_{j,k} = \{x_{j,k,l}\}_{l=1}^{L_{j,k}}$.
- $y_{j,k}$: latent label of $X_{j,k}$ (not observed).
- Under standard MIL (a single level of bags), the observed label obeys the usual assumption $y = \max_l y_l$: the bag is positive iff at least one instance is positive.
This nested organization generalizes MIL so that models can capture complex dependencies, such as grouping similar instances or enforcing relational rules across bags.
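To make the bag-of-bags structure concrete, the following sketch builds one hypothetical two-level sample ($J = 2$): an outer bag is a list of inner-bags, each an array of instance feature vectors. The shapes and the max-rule labeling are illustrative assumptions, not the paper's benchmark construction.

```python
import numpy as np

rng = np.random.default_rng(0)

D = 8  # instance feature dimension (assumed for illustration)
# Outer bag with K1 = 3 inner-bags of 3, 5, and 2 instances respectively.
outer_bag = [rng.normal(size=(L, D)) for L in (3, 5, 2)]

# Latent instance labels: present in the data-generating process but
# never observed by the model (weak supervision).
latent = [rng.integers(0, 2, size=len(bag)) for bag in outer_bag]

# Under the standard MIL assumption, the outer label is positive iff at
# least one instance anywhere in the hierarchy is positive.
y = int(max(lab.max() for lab in latent))
```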
2. Model Architecture and Attention Mechanisms
NMIA employs a multi-tiered process for representation and aggregation, parameterized as follows:
2.1 Instance-level Feature Embedding
Each raw instance is embedded as
$$h_{1,k,l} = f(x_{1,k,l}; \theta_f),$$
where $f$ is typically a CNN or MLP with parameters $\theta_f$.
2.2 Attention from Instance to Inner-bag
Attention scores are computed for each instance within its inner-bag:
$$a_{1,k,l} = \exp\!\left(w^{\top} h_{1,k,l} + b\right), \qquad \alpha_{1,k,l} = \frac{a_{1,k,l}}{\sum_{l'} a_{1,k,l'}},$$
with weight vector $w$ and bias $b$.
A gated-attention variant is also considered:
$$a_{1,k,l} \propto \exp\!\left(w^{\top}\!\left[\tanh(V h_{1,k,l}) \odot \sigma(U h_{1,k,l})\right]\right),$$
with weight matrices $V$ and $U$, $\odot$ the element-wise product, and $\sigma$ the sigmoid function.
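The gated variant can be sketched in a few lines of NumPy; the dimensions and the random parameter initialization below are illustrative assumptions, not values from the paper.

```python
import numpy as np

def gated_attention_weights(H, V, U, w):
    """Gated attention over one inner-bag.

    H: (L, M) instance embeddings; V, U: (D, M) weight matrices; w: (D,).
    Returns normalized attention weights alpha of shape (L,).
    """
    # tanh branch gated element-wise by a sigmoid branch.
    gate = np.tanh(H @ V.T) * (1.0 / (1.0 + np.exp(-(H @ U.T))))  # (L, D)
    scores = gate @ w                                             # (L,)
    scores -= scores.max()            # numerical stability before exp
    a = np.exp(scores)
    return a / a.sum()

rng = np.random.default_rng(1)
L_, M_, D_ = 4, 6, 3
H = rng.normal(size=(L_, M_))
alpha = gated_attention_weights(
    H, rng.normal(size=(D_, M_)), rng.normal(size=(D_, M_)),
    rng.normal(size=D_))
```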
2.3 Inner-Bag Representation Aggregation
The representation of each inner-bag is the attention-weighted sum of its instance embeddings:
$$m_{1,k} = \sum_{l} \alpha_{1,k,l}\, h_{1,k,l}.$$
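The attention-plus-aggregation step amounts to a softmax over linear scores followed by a weighted sum. A minimal NumPy sketch, with dimensions assumed for illustration:

```python
import numpy as np

def attend_and_pool(H, w, b):
    """Softmax attention over instance embeddings H (L, M), then pooling.

    Computes a_l = exp(w.h_l + b), alpha_l = a_l / sum(a),
    and m = sum_l alpha_l h_l (the inner-bag representation).
    """
    scores = H @ w + b
    scores -= scores.max()      # stabilize exp against overflow
    a = np.exp(scores)
    alpha = a / a.sum()
    m = alpha @ H               # (M,) weighted sum of embeddings
    return alpha, m

rng = np.random.default_rng(2)
H = rng.normal(size=(5, 4))     # one inner-bag: 5 instances, M = 4
alpha, m = attend_and_pool(H, rng.normal(size=4), 0.1)
```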
2.4 Attention from Inner-bag to Outer-bag
Inner-bag representations are aggregated into the outer-bag analogously:
$$b_k = \exp\!\left(v^{\top} m_{1,k} + c\right), \qquad \beta_k = \frac{b_k}{\sum_{k'} b_{k'}},$$
with weight vector $v$ and bias $c$. The final bag-of-bags embedding is
$$M = \sum_{k} \beta_k\, m_{1,k}.$$
2.5 Classification Head
The outer-bag label is predicted via a classification head $\Theta_c$ (e.g., a sigmoid-activated linear layer):
$$\hat{y} = \Theta_c(M; \theta_c),$$
where $\hat{y} \in (0, 1)$ is the predicted probability of a positive outer-bag.
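Putting sections 2.1–2.5 together, a self-contained NumPy sketch of the full two-level forward pass follows. The layer shapes, the tanh embedding, and the linear classifier are illustrative assumptions; the paper's $f$ is typically a CNN.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def nmia_forward(bags, p):
    """Two-level NMIA forward pass: instances -> inner-bags -> outer bag.

    bags: list of (L_k, D) instance arrays; p: dict of parameters.
    Returns (y_hat, per-bag alpha weights, beta weights).
    """
    inner_reps, alphas = [], []
    for X in bags:                               # each inner-bag k
        H = np.tanh(X @ p["Wf"] + p["bf"])       # 2.1 instance embedding f
        alpha = softmax(H @ p["w"] + p["b"])     # 2.2 instance attention
        inner_reps.append(alpha @ H)             # 2.3 m_{1,k}
        alphas.append(alpha)
    Ms = np.stack(inner_reps)                    # (K1, M)
    beta = softmax(Ms @ p["v"] + p["c"])         # 2.4 inner-bag attention
    M_bar = beta @ Ms                            # bag-of-bags embedding
    y_hat = sigmoid(M_bar @ p["wc"] + p["bc"])   # 2.5 classification head
    return y_hat, alphas, beta

rng = np.random.default_rng(3)
D, M = 8, 6
params = {"Wf": rng.normal(size=(D, M)), "bf": np.zeros(M),
          "w": rng.normal(size=M), "b": 0.0,
          "v": rng.normal(size=M), "c": 0.0,
          "wc": rng.normal(size=M), "bc": 0.0}
bags = [rng.normal(size=(L, D)) for L in (3, 5, 2)]
y_hat, alphas, beta = nmia_forward(bags, params)
```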
3. Training Objective and Optimization
The model is trained end-to-end with combined parameters $\theta = \{\theta_f, w, b, v, c, \theta_c\}$, minimizing binary cross-entropy on the outer-bag labels with optional regularization:
$$\mathcal{L}(\theta) = \sum_{i=1}^{N} \left[-y_i \log \hat{y}_i - (1 - y_i)\log(1 - \hat{y}_i)\right] + \lambda \lVert \theta \rVert^2.$$
Early stopping is typically applied using a held-out validation set.
The full process, from instance embedding through nested attention aggregation, is differentiable and therefore amenable to optimization by SGD or Adam.
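The objective can be checked numerically; the labels, predictions, and regularization weight below are made-up illustrative values, not results from the paper.

```python
import numpy as np

def nmia_loss(y, y_hat, theta_flat, lam=1e-4):
    """Binary cross-entropy over outer-bag labels plus an L2 penalty."""
    eps = 1e-12                               # avoid log(0)
    y_hat = np.clip(y_hat, eps, 1.0 - eps)
    bce = np.sum(-y * np.log(y_hat) - (1 - y) * np.log(1 - y_hat))
    return bce + lam * np.sum(theta_flat ** 2)

y = np.array([1.0, 0.0, 1.0])         # outer-bag labels
y_hat = np.array([0.9, 0.2, 0.7])     # model predictions
theta = np.ones(10)                   # flattened parameters (toy)
loss = nmia_loss(y, y_hat, theta, lam=0.01)
# loss = -ln(0.9) - ln(0.8) - ln(0.7) + 0.01 * 10 ≈ 0.7852
```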
4. Latent Label Prediction via Hierarchical Attention
Although supervision is available only at the outer bag level, NMIA leverages nested attention for latent label inference:
- Instance-level score $\alpha_{1,k,l}$: measures the likelihood that instance $l$ is positive within its inner-bag; thresholding at $\tau_\alpha$ yields a latent positive-instance assignment $\hat{y}_{1,k,l} = \mathbb{1}[\alpha_{1,k,l} > \tau_\alpha]$.
- Inner-bag score $\beta_k$: indicates the inner-bag's contribution to a positive outer label; thresholding at $\tau_\beta$ yields a latent positive inner-bag assignment $\hat{y}_{1,k} = \mathbb{1}[\beta_k > \tau_\beta]$.
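Turning attention weights into hard latent-label guesses is a simple thresholding operation; the threshold values and example weights below are illustrative assumptions, not tuned values from the paper.

```python
import numpy as np

def latent_assignments(alphas, beta, tau_alpha=0.5, tau_beta=0.5):
    """Threshold per-level attention into hard latent-label guesses.

    alphas: list of per-inner-bag attention arrays; beta: (K1,) weights.
    tau_alpha/tau_beta are illustrative hyperparameters.
    """
    inst = [(a > tau_alpha).astype(int) for a in alphas]   # instance level
    inner = (beta > tau_beta).astype(int)                  # inner-bag level
    return inst, inner

# One peaked inner-bag and one uniform one; beta favors the first bag.
alphas = [np.array([0.7, 0.2, 0.1]), np.array([0.25, 0.25, 0.25, 0.25])]
beta = np.array([0.8, 0.2])
inst, inner = latent_assignments(alphas, beta)
# inst → [[1, 0, 0], [0, 0, 0, 0]], inner → [1, 0]
```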
This nested inference enables partial recovery of latent structure, as shown in medical whole-slide imaging (WSI) examples where attention highlights candidate lesions and regions.
5. Computational Workflow
The NMIA training/inference procedure is expressed in the following pseudocode:

```text
Given nested dataset {X_i, y_i} for i = 1...N:
Initialize θ
repeat for epoch = 1...MaxEpochs:
    for each minibatch of outer-bags {X_i, y_i}:
        for each bag X_i:
            # Level 1: embed instances
            for k = 1...K1, l = 1...L_{1,k}:
                h_{1,k,l} ← f(x_{1,k,l}; θ_f)
            # Attention + aggregation at level 1
            for each inner-bag k:
                a_{1,k,l} ← exp(wᵀ h_{1,k,l} + b)       for l = 1...L_{1,k}
                α_{1,k,l} ← a_{1,k,l} / sum_l a_{1,k,l}
                m_{1,k}   ← sum_l α_{1,k,l} h_{1,k,l}
            # Attention + aggregation at level 2
            for k = 1...K1:
                b_k ← exp(vᵀ m_{1,k} + c)
                β_k ← b_k / sum_k b_k
            M   ← sum_k β_k m_{1,k}
            ŷ_i ← Θ_c(M; θ_c)
        Compute loss L = sum_i [−y_i log ŷ_i − (1−y_i) log(1−ŷ_i)] + λ∥θ∥²
        Back-propagate ∇_θ L and update θ by SGD/Adam
    Validate on held-out bags; apply early stopping
```
At inference, the same forward pass yields $\hat{y}$ together with the attention maps $\alpha_{1,k,l}$ and $\beta_k$, supporting both outer-bag prediction and interpretation of inner structure via attention thresholding.
6. Empirical Evaluation and Comparative Results
NMIA was evaluated on two-level (MNIST, PCAM) and three-level (MNIST "odd-only" rule) benchmarks and compared against alternative MIL architectures (F1 scores):
| Dataset/Experiment | MI | MIA | NMI | NMIA |
|---|---|---|---|---|
| MNIST Exp1 (single-instance→bag) | 0.929 | 0.957 | 0.923 | 0.959 |
| MNIST Exp2 (≥2 positives in same inner-bag) | 0.345 | 0.472 | 0.855 | 0.921 |
| MNIST Exp3 (3-level "odd-only" rule) | N/A | N/A | 0.556 | 0.836 |
| PCAM Exp1 (standard MIL) | 0.957 | 0.973 | 0.964 | 0.978 |
| PCAM Exp2 (≥2 metastatic patches/region) | 0.290 | 0.286 | 0.700 | 0.734 |
- In easy tasks (Exp1), all models perform well, with NMIA slightly outperforming alternatives.
- For rule-based tasks requiring the grouping of positives (Exp2), conventional MI/MIA architectures fail, while NMI and NMIA model the required relations, with NMIA achieving superior F1.
- The three-level hierarchy (Exp3) demonstrates only the NMIA architecture's capacity to learn complex hierarchical rules (e.g., aggregating presence/absence across nested levels).
- Qualitative attention visualizations confirm that the $\alpha$ scores highlight salient instances ("9" digits, metastatic regions) and the $\beta$ scores pinpoint relevant inner-bags.
A plausible implication is that NMIA enhances interpretability for nested weakly supervised problems and is advantageous where ground truth is available only at the highest level but applications demand finer-grained insight into hierarchical structure.
7. Connections to Related Methodologies
NMIA generalizes attention-based MIL architectures via explicit hierarchical nesting, combining soft attention for instance selection with multi-level aggregation. This approach is especially pertinent for domains such as computational pathology, vision, and any application where entities are naturally grouped and only coarse labels are available. The nesting and attention extensibility allow NMIA to subsume previous MIL variants (mean aggregation, single-level attention) and outperform them in tasks necessitating hierarchical inference (Fuster et al., 2021).
The framework's broad applicability suggests future directions in further hierarchy modeling, explainable machine learning, and adaptation to domains with complex nested-label structures.