
SAFE-KD: Risk-Controlled Early Exit for Vision Models

Updated 10 February 2026
  • The framework SAFE-KD is a universal early-exit system that combines hierarchical knowledge distillation with conformal risk control to guarantee statistically bounded selective risk.
  • It attaches intermediate classifier exits to any vision backbone (CNN or ViT) and employs decoupled knowledge distillation alongside consistency regularization for calibrated risk control.
  • Empirical results show up to a 45% reduction in expected inference depth while maintaining or surpassing full-inference accuracy, ensuring efficiency and robustness.

SAFE-KD is a universal, risk-controlled early-exit framework for modern vision backbones. It combines hierarchical knowledge distillation with conformal risk control (CRC) to achieve statistically guaranteed bounds on selective misclassification risk. The framework enables substantial reductions in inference cost by stopping early on "easy" samples, while maintaining user-specified upper bounds on misclassification risk at each exit, calibrated on finite data. SAFE-KD is model-agnostic and deploys on a variety of convolutional (CNN) and transformer-based (ViT) image models (Khazem, 3 Feb 2026).

1. Architecture and Components

SAFE-KD is structured as a lightweight "wrapper" atop any standard vision backbone, supporting both CNNs and Vision Transformers. Its architecture comprises:

  • Base Backbone: Any pretrained or trainable vision model $f(\cdot)$ (e.g., ResNet, ConvNeXt, ViT, Swin).
  • Intermediate Exit Heads: At $K$ selected depths, SAFE-KD attaches classifiers producing logits $z_j(x)\in\mathbb{R}^C$ ($j=1,\ldots,K$), class probabilities $p_j(c\mid x)=\mathrm{softmax}(z_j(x))_c$, and a confidence score (typically the Maximum Softmax Probability, $\mathrm{MSP}_j(x)=\max_c p_j(c\mid x)$). For CNNs, exits use global average pooling and an optional MLP before a fully connected (FC) layer; for ViTs, exits use CLS or mean token pooling, optional LayerNorm, then FC.
  • Teacher Network: An Exponential Moving Average (EMA) of the full model serves as the teacher for knowledge transfer.

This configuration allows SAFE-KD to operate agnostically across architectures, minimally increasing inference overhead.
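As a concrete illustration, the per-exit computation (pooled features, then logits, then probabilities and the MSP confidence) can be sketched as follows; the function names and the single-FC parameterization are illustrative, not the paper's implementation.

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def exit_head(features, W, b):
    """Minimal exit head: pooled features -> logits z_j(x) -> (p_j, MSP_j).

    features: (d,) pooled vector (global average pooling for CNNs,
              CLS/mean token pooling for ViTs).
    W, b:     FC-layer weights of shape (C, d) and bias of shape (C,).
    """
    logits = W @ features + b      # z_j(x) in R^C
    probs = softmax(logits)        # p_j(c|x)
    msp = float(probs.max())       # confidence score MSP_j(x)
    return logits, probs, msp
```

Each head only adds a pooling step and one FC layer on top of features the backbone already computes, which is why the inference overhead stays small.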

2. Decoupled Knowledge Distillation and Consistency

Training leverages hierarchical Decoupled Knowledge Distillation (DKD), coupled with deep-to-shallow consistency regularization:

  • DKD Loss: For each exit $j$, knowledge is distilled from the teacher logits $z_T(x)$ to $z_j(x)$ using a split-KL loss:

\mathcal{L}_{\mathrm{DKD}}(z_j,z_T,y) = \lambda_{\text{tar}}\,\mathrm{KL}\bigl(p^T_y \,\Vert\, p^S_{j,y}\bigr) + \lambda_{\text{non}}\,\mathrm{KL}\bigl(p^T_{\backslash y} \,\Vert\, p^S_{j,\backslash y}\bigr)

where $p^T_y$ and $p^S_{j,y}$ are the teacher and student probabilities on the target class $y$ (the ground-truth label), and $p^T_{\backslash y}$ and $p^S_{j,\backslash y}$ are normalized distributions over all non-target classes.
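A minimal single-sample sketch of this split-KL loss, treating the target term as a binary KL on $(p_y, 1-p_y)$ as in standard decoupled distillation; all names and default weights here are illustrative:

```python
import numpy as np

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def kl(p, q, eps=1e-12):
    """KL(p || q) between discrete distributions."""
    p = np.clip(p, eps, 1.0)
    q = np.clip(q, eps, 1.0)
    return float(np.sum(p * np.log(p / q)))

def dkd_loss(z_student, z_teacher, y, lam_tar=1.0, lam_non=1.0):
    """Split-KL DKD loss for one sample: target-class term + non-target term."""
    p_s, p_t = softmax(z_student), softmax(z_teacher)
    # Target term: binary KL on (p_y, 1 - p_y).
    tar = kl(np.array([p_t[y], 1.0 - p_t[y]]),
             np.array([p_s[y], 1.0 - p_s[y]]))
    # Non-target term: KL over the renormalized non-target classes.
    mask = np.arange(len(p_s)) != y
    non = kl(p_t[mask] / p_t[mask].sum(),
             p_s[mask] / p_s[mask].sum())
    return lam_tar * tar + lam_non * non
```

Decoupling the two terms lets the non-target weight be tuned independently, which is the usual motivation for DKD over vanilla KL distillation.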

  • Consistency Regularization: To align intermediate exits with the final head, SAFE-KD adds

\mathcal{L}_{\mathrm{consist},j}=T^{2}\,\mathrm{KL}\bigl(\mathrm{softmax}(z_K/T)\,\Vert\,\mathrm{softmax}(z_j/T)\bigr)

with weighting $\beta$, regularizing posterior agreement between each intermediate exit and the final exit $K$.
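The temperature-scaled consistency term can be sketched for a single sample as follows (the default temperature value is an assumption, not taken from the paper):

```python
import numpy as np

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def consistency_loss(z_j, z_K, T=2.0, eps=1e-12):
    """Deep-to-shallow consistency: T^2 * KL(softmax(z_K/T) || softmax(z_j/T)).

    z_j: logits of an intermediate exit; z_K: logits of the final exit.
    """
    p_K = np.clip(softmax(z_K / T), eps, 1.0)
    p_j = np.clip(softmax(z_j / T), eps, 1.0)
    return float(T ** 2 * np.sum(p_K * np.log(p_K / p_j)))
```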

  • Total Loss: For weights $w_j$ summing to 1, full training minimizes:

\mathcal{L} = \sum_{j=1}^K w_j\bigl[\mathrm{CE}(z_j,y) + \alpha\,\mathcal{L}_{\mathrm{DKD}}(z_j,z_T,y)\bigr] + \sum_{j=1}^{K-1} w_j\,\beta\,\mathcal{L}_{\mathrm{consist},j}

where $\alpha$ is a scaling factor for the DKD term.

This hierarchical approach improves calibration and depth-to-exit consistency while maintaining high accuracy at all exits.

3. Conformal Risk Control for Early-Exit Thresholds

At inference, SAFE-KD employs Conformal Risk Control (CRC) to set data-driven confidence thresholds at each exit, guaranteeing a user-specified selective risk:

  • Nonconformity Score: $r_j(x)=1-\mathrm{MSP}_j(x)$ at exit $j$.
  • Acceptance Set: $A_j(\tau)=\{x: r_j(x)\leq\tau\}$.
  • Selective Misclassification Risk: $R_j(\tau)=P(\hat{y}_j(x)\neq y \mid x \in A_j(\tau))$.
  • Threshold Calibration: Using a held-out calibration set $\{(x_i,y_i)\}_{i=1}^n$, thresholds $\widehat{\tau}_j(\delta)$ are chosen so that the conformal upper bound

\widehat{R}_j(\tau) = \frac{1+\sum_{i=1}^n e_{j,i}\,\mathbf{1}\{r_{j,i}\leq\tau\}}{1+\sum_{i=1}^n \mathbf{1}\{r_{j,i}\leq\tau\}}

does not exceed the desired risk level $\delta$. Here $e_{j,i}=\mathbf{1}\{\hat{y}_j(x_i)\neq y_i\}$.

CRC, under the exchangeability assumption, ensures

R_j(\widehat{\tau}_j(\delta)) \leq \delta + \mathcal{O}(1/n)

for each exit, providing finite-sample statistical guarantees.
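In practice, calibrating $\widehat{\tau}_j(\delta)$ amounts to a one-dimensional search over candidate thresholds on the calibration set; a minimal sketch (function name and search strategy are illustrative, not the paper's code):

```python
import numpy as np

def calibrate_threshold(scores, errors, delta):
    """Largest tau whose conformal upper bound R_hat_j(tau) <= delta.

    scores: nonconformity scores r_{j,i} = 1 - MSP_j(x_i) on the calibration set.
    errors: e_{j,i} = 1 if exit j misclassifies x_i, else 0, on the same set.
    Returns None if no candidate threshold meets the target risk.
    """
    best = None
    for tau in np.sort(scores):                    # candidate thresholds
        accept = scores <= tau
        # Conformal upper bound with the "+1" finite-sample correction.
        r_hat = (1 + errors[accept].sum()) / (1 + accept.sum())
        if r_hat <= delta:
            best = float(tau)                      # sorted ascending: keep largest feasible
    return best
```

Choosing the largest feasible $\tau$ maximizes the fraction of samples allowed to exit early while keeping the bound at or below $\delta$.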

4. Safe Inference Policy and Practical Deployment

At test time, early exit is governed by the following procedure:

  • Proceed through exits $j=1,\ldots,K-1$, checking at each whether $r_j(x)\leq\widehat{\tau}_j(\delta)$.
  • The first such exit $j^*(x)$ is used for prediction. If none accepts, inference proceeds to the final exit $K$.
  • For every exit $j$, the empirical misclassification risk among samples exiting there is guaranteed not to exceed $\delta$ (up to sampling correction).

This allows the system designer to select $\delta$ according to operational requirements, trading off computational savings for tightly controlled selective risk. All calibration is based on a held-out set.
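The test-time policy above is a single pass over the exits. The sketch below assumes, for simplicity, that every exit's MSP and prediction are already computed; a real deployment would evaluate exits lazily, one at a time, to realize the compute savings:

```python
def safe_predict(confidences, predictions, thresholds):
    """First-exit policy: return the prediction of the first exit j with
    r_j(x) = 1 - MSP_j(x) <= tau_j; otherwise fall through to the final exit K.

    confidences: MSP_j(x) for exits j = 1..K.
    predictions: argmax class at each exit.
    thresholds:  calibrated tau_j for exits 1..K-1.
    Returns (prediction, exit_index), with a 1-based exit index.
    """
    K = len(confidences)
    for j in range(K - 1):
        if 1.0 - confidences[j] <= thresholds[j]:
            return predictions[j], j + 1   # early exit at j*(x)
    return predictions[-1], K              # no exit accepted: full depth
```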

5. Empirical Evaluation and Results

SAFE-KD has been empirically validated across six architectures (ResNet-50, MobileNetV3-S, EfficientNet-B0, ConvNeXt-T, ViT-S, Swin-T) and multiple image datasets (CIFAR-10/100, STL-10, Pets, Flowers102, Aircraft), delivering:

  • Compute-Accuracy Trade-offs: At a risk level of $\delta=5\%$, SAFE-KD achieves 40–45% lower expected depth while matching or surpassing full-inference accuracy. Baseline methods (fixed MSP or entropy thresholds) violate the risk constraint, with observed risks of 6–8%.
  • Calibration: SAFE-KD reduces negative log-likelihood (NLL) and expected calibration error (ECE) at all exits.
  • Risk Guarantee: Across sweeps in $\delta$, the observed per-exit risk tracks the theoretical bound $R_j(\widehat{\tau}_j(\delta)) \lesssim \delta$, confirming $\mathcal{O}(1/n)$ tightness.
  • Robustness: On CIFAR-10-C corrupted data (severity 3), SAFE-KD attains lower mean corruption error (mCE) at both the shallowest and deepest exits than comparable multi-exit and DKD-based models, e.g., mCE at exit 1: SAFE-KD 30.2, DKD 32.8, MultiExit 35.4; at the final exit: SAFE-KD 20.9.
  • Ablation Findings: Removing DKD degrades accuracy by 2.1% and forces more conservative thresholds, increasing average depth. Removing consistency ($\beta=0$) increases exit variance, though the risk guarantees persist.
  • Example Table (CIFAR-100, ResNet-50, $\delta=0.05$):

    Method          Accuracy   Exp. Depth   Observed Risk
    Fixed MSP       81.5%      0.72         6.8% (Unsafe)
    Entropy gate    80.9%      0.65         7.5% (Unsafe)
    SAFE-KD (CRC)   82.3%      0.59         4.8% (Safe)

SAFE-KD consistently defines the empirical Pareto frontier for the target risk constraint across tasks.

6. Calibration, Robustness, and Risk Guarantees

SAFE-KD's use of CRC enables finite-sample, statistically tight risk control that heuristic thresholds cannot provide. Reliability diagrams confirm alignment of the observed risk with the target $\delta$ across exit depths. Selective risk curves over a sweep of $\delta\in[0.01,0.10]$ show the empirical $R_j$ at or just under the $y=x$ line.

For corrupted or hard samples, the framework naturally "defers" to deeper exits, preserving the guaranteed selective risk at the cost of additional computation. This property, together with the out-of-the-box calibration provided by DKD and consistency regularization, distinguishes SAFE-KD from prior early-exit and distillation methods that lack such formal risk control.

7. Summary and Broader Impact

SAFE-KD constitutes a general-purpose, modular extension to vision models requiring minimal architectural modification and no retraining of the backbone. Its integration of CRC, DKD, and deep-to-shallow consistency provides user-tunable, quantifiable risk guarantees for early exiting, fine-grained calibration, and enhanced robustness under dataset shift or corruption. The framework supports frequent regression testing and online adaptation to evolving operational requirements, making it especially suitable for resource-constrained or safety-critical deployments (Khazem, 3 Feb 2026).
