
Distillation Scaling Laws

Published 12 Feb 2025 in cs.LG, cs.AI, cs.CL, and stat.ML | (2502.08606v1)

Abstract: We provide a distillation scaling law that estimates distilled model performance based on a compute budget and its allocation between the student and teacher. Our findings reduce the risks associated with using distillation at scale; compute allocation for both the teacher and student models can now be done to maximize student performance. We provide compute optimal distillation recipes for when 1) a teacher exists, or 2) a teacher needs training. If many students are to be distilled, or a teacher already exists, distillation outperforms supervised pretraining until a compute level which grows predictably with student size. If one student is to be distilled and a teacher also needs training, supervised learning should be done instead. Additionally, we provide insights across our large scale study of distillation, which increase our understanding of distillation and inform experimental design.

Summary

  • The paper introduces a scaling law that predicts a distilled student model's cross-entropy loss as a function of teacher loss, student parameters, and distillation tokens.
  • It demonstrates how compute-optimal resource allocation between teacher pretraining, inference, and student training can minimize loss under fixed FLOP budgets.
  • The study offers practical insights for designing distillation experiments, emphasizing overtraining, capacity gap awareness, and hyperparameter tuning for robust performance.

The paper "Distillation Scaling Laws" (2502.08606) introduces a predictive framework for estimating the performance of a distilled student model based on the compute budget and its allocation between the teacher and student components. This framework aims to mitigate risks associated with large-scale distillation efforts by enabling compute-optimal resource allocation.

The Distillation Scaling Law Formulation

The core contribution is a scaling law that predicts the student's cross-entropy loss ($L_S$) as a function of its non-embedding parameters ($N_S$), the number of tokens used for distillation ($D_S$), and the teacher's cross-entropy loss ($L_T$). The proposed functional form is given by Equation 9 in the paper:

$L_S(N_S, D_S, L_T) = L_T + \frac{1}{L_T^{c_0}} \left(1+\left(\frac{L_T}{\widetilde{L}_S\, d_1}\right)^{1/f_1}\right)^{-c_1 f_1} \times \left(\frac{A'}{N_S^{\alpha'}} + \frac{B'}{D_S^{\beta'}}\right)^{\gamma'}$

Key characteristics of this formulation include:

  • Teacher Loss Dependence: Student performance ($L_S$) is primarily determined by the teacher's performance ($L_T = L(N_T, D_T)$), not by the specific combination of teacher parameters ($N_T$) and training tokens ($D_T$) that produced that loss. This simplifies the teacher selection problem.
  • Student Capability Term: The term $(A'/N_S^{\alpha'} + B'/D_S^{\beta'})^{\gamma'}$ mirrors the structure of supervised scaling laws (e.g., Chinchilla (Hoffmann et al., 2022)), indicating that the student's intrinsic capacity improves with more parameters ($N_S$) and data ($D_S$).
  • Capacity Gap Modeling: The term involving $L_T$ and $\widetilde{L}_S$ incorporates a broken power law structure. Here, $\widetilde{L}_S = L(N_S, D_S)$ is the loss the student would achieve via supervised training on the same data volume $D_S$. This term captures the "capacity gap" phenomenon: improving the teacher (lower $L_T$) generally improves the student (lower $L_S$), but only up to a point. If the teacher becomes significantly more capable than the student's learning capacity (characterized by the ratio $L_T / \widetilde{L}_S$ crossing a threshold governed by $d_1$), further teacher improvements can degrade student performance. The gap is attributed to differences in learning capacity, not merely model size.
  • Empirical Validation: The law demonstrates a good fit to experimental data, typically achieving relative errors below 1%, and successfully extrapolates to unseen configurations. It also remains consistent with supervised scaling laws in the infinite data limit.
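
As a concrete illustration, the law above can be written directly as a function of $N_S$, $D_S$, $L_T$, and $\widetilde{L}_S$. The Python sketch below uses arbitrary placeholder coefficients rather than the values fitted in the paper; it is meant only to show how the capacity-gap factor and the Chinchilla-style capability term combine.

```python
def student_loss(N_S, D_S, L_T, L_S_tilde, coef):
    """Sketch of the distillation scaling law (Eq. 9).

    N_S       : student non-embedding parameters
    D_S       : distillation tokens
    L_T       : teacher cross-entropy loss
    L_S_tilde : loss the student would reach with supervised training on D_S tokens
    coef      : law coefficients (the values used below are placeholders, not the paper's fits)
    """
    c0, c1, f1, d1 = coef["c0"], coef["c1"], coef["f1"], coef["d1"]
    A, alpha, B, beta, gamma = coef["A"], coef["alpha"], coef["B"], coef["beta"], coef["gamma"]

    # Broken power law in L_T / L_S_tilde: stronger teachers help only until the
    # ratio crosses the threshold governed by d1 (the capacity gap).
    gap = (1.0 + (L_T / (L_S_tilde * d1)) ** (1.0 / f1)) ** (-c1 * f1)

    # Chinchilla-style student capability term in N_S and D_S.
    capability = (A / N_S ** alpha + B / D_S ** beta) ** gamma

    return L_T + gap * capability / L_T ** c0


# Example call with placeholder coefficients.
coef = dict(c0=0.5, c1=1.0, f1=0.3, d1=1.1,
            A=400.0, alpha=0.3, B=4e3, beta=0.3, gamma=1.0)
print(student_loss(N_S=3e8, D_S=2e10, L_T=2.1, L_S_tilde=2.4, coef=coef))
```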

Compute Budget Allocation and Optimization

The total computational cost (FLOPs) associated with distillation is approximated by Equation 10:

$\mathrm{FLOPs} \approx 3F(N_S)D_S + F(N_T)\left(\delta_T^{\mathrm{Lgt}} D_S + 3\,\delta_T^{\mathrm{Pre}} D_T\right)$

Here:

  • $F(N)$ is the FLOPs per token for a forward pass of a model with $N$ non-embedding parameters. The paper recommends the refined estimate $F(N) \approx 2N(1 + c_1 N^{-1/3} + c_2 N^{-2/3})$ for fixed-aspect-ratio models, which is particularly relevant for accurately estimating the cost of smaller models with large context windows.
  • $3F(N_S)D_S$ is the student training cost (forward + backward pass).
  • $F(N_T)\delta_T^{\mathrm{Lgt}}D_S$ is the cost of teacher inference to generate logits for distillation, incurred if $\delta_T^{\mathrm{Lgt}}=1$.
  • $3F(N_T)\delta_T^{\mathrm{Pre}}D_T$ is the cost of pretraining the teacher model, incurred if $\delta_T^{\mathrm{Pre}}=1$.
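
For reference, this cost model can be written out directly. The sketch below (plain Python, with $c_1 = c_2 = 0$ as placeholder correction terms; the paper fits these for its model family) mirrors Equation 10 with the two indicator flags exposed as booleans.

```python
def flops_per_token(N, c1=0.0, c2=0.0):
    """Forward-pass FLOPs per token, F(N) ~ 2N(1 + c1*N^(-1/3) + c2*N^(-2/3)).
    c1 and c2 default to zero here (placeholders); the refined estimate requires
    fitting them for a fixed-aspect-ratio model family."""
    return 2.0 * N * (1.0 + c1 * N ** (-1.0 / 3.0) + c2 * N ** (-2.0 / 3.0))


def distillation_flops(N_S, D_S, N_T, D_T, logits=True, pretrain=True):
    """Total distillation cost following Eq. 10: student training, plus optional
    teacher inference (logit generation) and optional teacher pretraining."""
    student_training = 3.0 * flops_per_token(N_S) * D_S               # forward + backward
    teacher_inference = flops_per_token(N_T) * D_S if logits else 0.0
    teacher_pretraining = 3.0 * flops_per_token(N_T) * D_T if pretrain else 0.0
    return student_training + teacher_inference + teacher_pretraining


# Example: distilling from an existing teacher (pretraining cost amortized).
print(distillation_flops(N_S=3e8, D_S=2e10, N_T=3e9, D_T=0.0, pretrain=False))
```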

The flags $\delta_T^{\mathrm{Lgt}}$ and $\delta_T^{\mathrm{Pre}}$ allow analysis under different cost assumptions. The paper analyzes compute-optimal strategies by minimizing $L_S$ subject to a fixed total FLOPs budget $C$ (a brute-force version of this optimization is sketched after the list of scenarios below). Key findings for different scenarios include:

  • Full Cost (Teacher Pretraining + Inference; $\delta_T^{\mathrm{Lgt}}=1$, $\delta_T^{\mathrm{Pre}}=1$): This scenario applies when a teacher must be trained specifically for a single student distillation task.
    • Optimal $D_S$ and $D_T$ scale as power laws of the total compute $C$.
    • The optimal teacher size ($N_T$) initially grows with $C$ but eventually plateaus or even decreases at very high compute budgets. This counter-intuitive result arises because the teacher inference cost ($F(N_T)D_S$) scales with $D_S$. As the optimal $D_S$ grows with $C$, it becomes more compute-efficient to reach the target $L_T$ by overtraining the teacher (using a larger $D_T$ relative to $N_T$) than by following the supervised compute-optimal point, which limits the growth of $N_T$ and its associated inference cost.
    • The distribution of compute allocation shifts: teacher pretraining dominates at low $C$, while student training and teacher inference costs become relatively larger at high $C$, especially for smaller students.
  • Inference Cost Only ($\delta_T^{\mathrm{Lgt}}=1$, $\delta_T^{\mathrm{Pre}}=0$): This scenario applies when a pretrained teacher exists, but its inference cost during distillation is counted.
    • The optimal strategy involves selecting or fine-tuning a teacher that is overtrained (i.e., achieves the target $L_T$ with the smallest possible $N_T$). This minimizes the inference cost component $F(N_T)D_S$.
    • As total compute increases, the proportion allocated to teacher inference grows relative to student training.
  • Pretraining Cost Only ($\delta_T^{\mathrm{Lgt}}=0$, $\delta_T^{\mathrm{Pre}}=1$): This scenario might apply if teacher inference is effectively free (e.g., logits are precomputed and stored, or inference hardware is distinct and unconstrained), but the teacher still needs to be trained.
    • The optimal teacher follows the compute-optimal supervised training strategy (e.g., a roughly constant $D_T/N_T$ ratio, similar to Chinchilla). Since there is no penalty for a large $N_T$ during inference, the most efficient way to reach a target $L_T$ is standard scaling.
    • As total compute increases, the proportion allocated to student training grows relative to teacher pretraining.
  • No Teacher Cost ($\delta_T^{\mathrm{Lgt}}=0$, $\delta_T^{\mathrm{Pre}}=0$): This applies if a suitable teacher exists and inference costs are negligible or amortized. Here, the goal is simply to maximize $D_S$ for the given student budget $C_S = 3F(N_S)D_S$.
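
The constrained minimization referenced above can be illustrated with a brute-force sweep (the paper derives its allocations from the fitted laws; the grid search below is only a sketch). It reuses the `student_loss` and `flops_per_token` sketches from earlier and assumes a fitted supervised scaling law `supervised_loss(N, D)` is available to supply $L_T$ and $\widetilde{L}_S$.

```python
def best_allocation(C, N_S, N_T_grid, D_T_grid, supervised_loss, coef,
                    logits=True, pretrain=True):
    """Brute-force sketch of compute-optimal allocation under a total budget C:
    for each candidate teacher (N_T, D_T), spend the remaining FLOPs on
    distillation tokens D_S and keep the configuration with the lowest
    predicted student loss."""
    best = None
    for N_T in N_T_grid:
        for D_T in D_T_grid:
            fixed = 3.0 * flops_per_token(N_T) * D_T if pretrain else 0.0
            per_token = 3.0 * flops_per_token(N_S) + (flops_per_token(N_T) if logits else 0.0)
            D_S = (C - fixed) / per_token        # remaining budget -> distillation tokens
            if D_S <= 0:
                continue
            L_T = supervised_loss(N_T, D_T)      # teacher loss from the supervised law
            L_S_tilde = supervised_loss(N_S, D_S)
            L_S = student_loss(N_S, D_S, L_T, L_S_tilde, coef)
            if best is None or L_S < best[0]:
                best = (L_S, N_T, D_T, D_S)
    return best  # (predicted L_S, teacher size, teacher tokens, distillation tokens)
```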

Optimal Distillation Recipes

Based on the compute optimization analysis, the paper proposes practical recipes depending on the availability and cost accounting of the teacher model:

  • Case 1: A Teacher Exists or its Cost is Amortized ($\delta_T^{\mathrm{Pre}}=0$): This is the most common scenario in which distillation provides benefits.
    • Strategy: Given a student size $N_S$ and a compute budget $C$ (which primarily determines $D_S$, potentially constrained by $F(N_T)D_S$ if $\delta_T^{\mathrm{Lgt}}=1$), distillation can outperform supervised pretraining up to a certain compute threshold.
    • Selecting Among Existing Teachers: Use the scaling law (Eq. 9) to predict $L_S$ for each available teacher, characterized by $(L_T^{(i)}, N_T^{(i)})$; a selection sketch follows this list.
    • If the budget is defined by fixed distillation tokens $D_S$ (and $\delta_T^{\mathrm{Lgt}}=0$), choose the teacher minimizing $L_S(N_S, D_S, L_T^{(i)})$. The optimal $L_T^*$ generally decreases as the student size $N_S$ increases.
    • If the budget is a fixed compute $C$ and teacher inference cost is included ($\delta_T^{\mathrm{Lgt}}=1$), the choice involves a trade-off: a better teacher (lower $L_T$) may have a larger $N_T$, increasing $F(N_T)$ and reducing the affordable distillation tokens $D_S = C / (3F(N_S) + F(N_T))$. Overtrained teachers (smaller $N_T$ for a given $L_T$) become preferable.
    • Ideal Teacher: If one could create the ideal teacher and amortize its cost, it should typically be overtrained when inference costs ($\delta_T^{\mathrm{Lgt}}=1$) are counted, balancing $L_T$ and $N_T$ to optimize $L_S$ under the compute constraint. If inference cost is negligible ($\delta_T^{\mathrm{Lgt}}=0$), simply aim for the optimal $L_T^*$.
  • Case 2: Teacher Needs Training for a Single Student ($\delta_T^{\mathrm{Lgt}}=1$, $\delta_T^{\mathrm{Pre}}=1$):
    • Key Finding: For a fixed total compute budget $C$, supervised pretraining of the student always achieves a better (lower) $L_S$ than training a teacher and then distilling. The compute spent on teacher training and inference would yield better student performance if spent directly on supervised student training.
    • Recommendation: If the objective is the single best model of size $N_S$ for a budget $C$, and no suitable teacher exists whose cost can be amortized, perform supervised pretraining instead of distillation. Distillation is only compute-optimal in this context if the teacher pretraining cost is effectively zero for the specific student being trained (i.e., amortized).
    • Forced Distillation: If distillation must be performed under this full cost model, the optimal allocation balances student training, teacher pretraining, and teacher inference costs, with the optimal teacher configuration shifting from compute-optimal (low $C$) towards overtrained (high $C$), as described previously.
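
For Case 1, the teacher-selection step referenced in the list can be made concrete. The sketch below assumes an existing pool of teachers described by pairs $(L_T^{(i)}, N_T^{(i)})$ and, when teacher logits are generated on the fly, accounts for the inference FLOPs when computing the affordable $D_S$; it reuses the earlier `student_loss`, `flops_per_token`, and assumed `supervised_loss` sketches.

```python
def pick_teacher(teachers, N_S, C, supervised_loss, coef, logits=True):
    """Sketch of teacher selection under an amortized-teacher budget C.
    `teachers` is a list of (L_T, N_T) pairs.  When logits=True, a larger N_T
    reduces the affordable D_S, so the teacher with the lowest L_T is not
    necessarily the best choice."""
    best = None
    for L_T, N_T in teachers:
        per_token = 3.0 * flops_per_token(N_S) + (flops_per_token(N_T) if logits else 0.0)
        D_S = C / per_token                      # D_S = C / (3F(N_S) + F(N_T))
        L_S = student_loss(N_S, D_S, L_T, supervised_loss(N_S, D_S), coef)
        if best is None or L_S < best[0]:
            best = (L_S, L_T, N_T, D_S)
    return best  # (predicted L_S, chosen teacher loss, teacher size, distillation tokens)
```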

Distillation vs. Supervised Pretraining Efficiency

The comparative efficiency of distillation and supervised pretraining hinges critically on how teacher costs are accounted for. Distillation offers superior compute/data efficiency over supervised pretraining for a student of size $N_S$ if and only if:

  1. Student Resource Constraint: The compute ($C_S \approx 3F(N_S)D_S$) or data ($D_S$) dedicated specifically to the student's training via distillation is below a predictable threshold. This threshold increases with student size $N_S$. Given sufficient student-specific resources, supervised learning will eventually match and surpass distillation.
  2. Amortized Teacher Pretraining Cost: The FLOPs required for teacher pretraining ($3\,\delta_T^{\mathrm{Pre}} F(N_T)D_T$) are not counted towards the budget of the single student distillation run. This occurs when the teacher model already exists or when its training cost is distributed across multiple uses (e.g., distilling many students, direct deployment).

Crucially, if the full cost of pretraining a teacher from scratch is attributed solely to the distillation of one student model, supervised pretraining of that student model is always the more compute-efficient approach.
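
Under the amortized-teacher accounting, the crossover point can be located numerically by comparing the two predicted losses as the student-side token budget grows. The sketch below again relies on the earlier `student_loss` sketch and an assumed fitted supervised law `supervised_loss(N, D)`.

```python
def crossover_tokens(N_S, L_T, supervised_loss, coef, token_grid):
    """Sketch: return the smallest D_S in token_grid at which supervised
    pretraining of the student matches or beats distillation from a teacher
    with loss L_T (teacher costs amortized, so both approaches get the same
    student-side token budget)."""
    for D_S in sorted(token_grid):
        L_sup = supervised_loss(N_S, D_S)                      # supervised student loss
        L_dist = student_loss(N_S, D_S, L_T, L_sup, coef)      # predicted distilled loss
        if L_sup <= L_dist:
            return D_S
    return None  # supervised never catches up within the examined range
```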

Practical Insights and Experimental Design

The study provides several recommendations for practitioners conducting distillation experiments:

  • Target Teacher Loss ($L_T$): Teacher quality affects the student primarily through $L_T$. Experimental sweeps can be simplified by targeting specific $L_T$ values rather than exhaustively exploring combinations of $N_T$ and $D_T$.
  • Capacity Gap Awareness: Recognize that the gap depends on the ratio of teacher performance to student potential ($L_T / \widetilde{L}_S$). Very strong teachers can harm weak students. Monitor the KL divergence during training or perform calibration analysis to diagnose issues.
  • Coefficient Estimation: To robustly fit the parameters of the distillation scaling law (Eq. 9), adapt the methodology from Chinchilla (Hoffmann et al., 2022). Design experiments that vary teacher and student characteristics systematically, such as fixing the teacher and running IsoFLOP experiments across student sizes, and vice versa.
  • Leverage $\mu$P: Maximal Update Parametrization facilitates hyperparameter transfer (e.g., learning rates) across model scales, a property shown to hold for distillation as well, simplifying large-scale studies.
  • Hyperparameter Defaults: Pure distillation (setting the interpolation weight $\lambda = 1$ between the teacher KL loss and the ground-truth CE loss) is generally robust and near-optimal, especially if the teacher is reasonably strong. Unit temperature ($\tau = 1$) is preferred, as it preserves the structure of the teacher's probability distribution. Standard forward KL divergence is recommended; a minimal loss sketch follows this list.
  • Logit Truncation: Avoid aggressive top-k or top-p truncation of teacher logits for pure distillation ($\lambda = 1$), as it significantly degrades performance. If storage or bandwidth constraints necessitate truncation (e.g., storing top-k logits), consider a non-zero weight on the ground-truth cross-entropy loss ($\lambda < 1$, e.g., $\lambda \approx 0.7$) to compensate.
  • Fixed-Aspect-Ratio Models: Using architectures in which the ratio of embedding dimension to number of layers ($d_{\text{model}} / n_{\text{layers}}$) is held constant simplifies the relationship between parameter count ($N$) and FLOPs per token ($F(N)$), leading to more accurate compute estimates, especially for smaller models.
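
The loss configuration described in the hyperparameter bullets ($\lambda$, $\tau$, forward KL) corresponds to a standard token-level distillation objective. The PyTorch sketch below is a minimal illustration, not the paper's training code; the $\tau^2$ scaling of the KL term is the usual convention when $\tau \neq 1$.

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, targets, lam=1.0, tau=1.0):
    """Forward KL from the (temperature-scaled) teacher distribution to the student,
    mixed with ground-truth cross-entropy via `lam`.  lam=1 is pure distillation;
    lam<1 (e.g. ~0.7) adds a CE term, useful when teacher logits are truncated."""
    log_p_student = F.log_softmax(student_logits / tau, dim=-1)
    p_teacher = F.softmax(teacher_logits / tau, dim=-1)
    kd = F.kl_div(log_p_student, p_teacher, reduction="batchmean") * tau ** 2
    ce = F.cross_entropy(student_logits, targets)
    return lam * kd + (1.0 - lam) * ce

# Example with random logits over a small vocabulary.
s = torch.randn(8, 128)             # student logits: (batch*tokens, vocab)
t = torch.randn(8, 128)             # teacher logits
y = torch.randint(0, 128, (8,))     # ground-truth token ids
print(distillation_loss(s, t, y, lam=1.0, tau=1.0))
```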

In conclusion, the distillation scaling laws presented offer a quantitative framework for predicting distillation performance and optimizing compute allocation. The key practical insight is the critical dependence on how teacher costs are accounted for: distillation is advantageous when teacher pretraining costs are amortized and student resources are constrained, whereas supervised learning is superior when the full cost of teacher creation is borne by a single distillation task.
