Task-Aware Curriculum Distillation (TAPIR)
- The paper presents a multi-round distillation framework that transfers instruction-following capabilities from proprietary teacher models to smaller student LLMs.
- It introduces a task-aware curriculum utilizing Model Fitting Difficulty for smart seed selection and progressive difficulty escalation.
- Empirical evaluations show that TAPIR outperforms larger state-of-the-art baselines on benchmarks like AlpacaEval and MT-Bench with notable gains in roleplay, reasoning, and coding.
Task-Aware Curriculum Distillation for LLMs (TAPIR) is a principled multi-round distillation framework designed to transfer instruction-following abilities from proprietary oracle LLMs to smaller, open-weight student models. TAPIR leverages a curriculum planned around task-aware sampling and explicit difficulty escalation, targeting improved generalization, balanced task competencies, and effective knowledge transfer using less curated data compared to conventional distillation approaches. The methodology, metrics, and results outlined in the TAPIR framework highlight its ability to produce student models that outperform even larger state-of-the-art baselines on standard LLM instruction-following benchmarks (Yue et al., 2024).
1. Theoretical Framework and Problem Formalization
Let S_θ denote a pre-trained student LLM parameterized by θ, T an oracle/teacher LLM (e.g., GPT-4, ChatGPT), and J a judge LLM for scoring. The supervision corpus is D = {(x_i, y_i)}, where x_i is an instruction and y_i the teacher response. TAPIR’s central objective is to minimize the weighted autoregressive cross-entropy loss over a refined training set, L(θ) = −Σ_i w(x_i) log p_θ(ỹ_i | x_i), where w(x_i) assigns weights according to a rebalanced task distribution P_task, and ỹ_i is a refined gold response. The curriculum is realized as multi-round distillation, with α_t controlling the mix ratio of hard samples at round t, and datasets escalated in size and difficulty over rounds.
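The weighted objective can be sketched as a toy computation. Here `task_weights` and the per-example token log-probabilities are hypothetical stand-ins for the paper's actual weighting function and student model; this is a minimal illustration, not the reference implementation:

```python
def weighted_distillation_loss(examples, task_weights):
    """Weighted autoregressive cross-entropy: L = -sum_i w(x_i) * log p(y~_i | x_i).

    Each example carries the student's per-token log-probabilities for the
    refined teacher response; task_weights rebalances tasks toward the
    target distribution (both are illustrative stand-ins).
    """
    total = 0.0
    for ex in examples:
        # log p(y~ | x) under the student = sum of token log-probs
        log_p_response = sum(ex["token_logprobs"])
        total -= task_weights[ex["task"]] * log_p_response
    return total

# Toy usage: an oversampled "coding" task receives a higher weight.
examples = [
    {"task": "chat",   "token_logprobs": [-0.5, -0.3]},
    {"task": "coding", "token_logprobs": [-1.0, -0.7]},
]
loss = weighted_distillation_loss(examples, {"chat": 0.8, "coding": 1.5})
```

Upweighting a task scales its contribution to the loss, which is how the rebalanced distribution P_task steers training toward underrepresented categories.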
2. TAPIR Workflow and Algorithmic Structure
The core TAPIR algorithm consists of two phases: (a) seed selection based on a Model Fitting Difficulty (MFD) metric; (b) iterative multi-round fine-tuning (curriculum), during which the student is progressively exposed to harder and more diverse instructions.
A summary of the algorithm:
- Seed Dataset Selection: Fine-tune the student on the full corpus D to produce an initial model S_0. Compute MFD(x_i) = J(x_i, y_i^T) − J(x_i, y_i^{S_0}), where the judge J outputs a score in [1, 10]. Instructions with MFD(x_i) > τ (for a threshold τ) become the hard-seed set D_seed.
- Multi-Round Curriculum: For rounds t = 1, …, R, expand D_seed using the teacher T, constrain expansion to the same set of task types, and mix with non-seed data according to the current hard-sample ratio α_t. Refined responses ỹ_i are used as training targets. α_t is incremented by a fixed step δ each round, starting from an initial ratio α_0.
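The two-phase loop above can be sketched as follows. The `expand` and `finetune` hooks, along with the α_0 and δ values in the demo, are assumed placeholders for teacher-based instruction expansion and one epoch of supervised fine-tuning; the paper fixes its own constants:

```python
def run_tapir_curriculum(seed_hard, pool_easy, rounds, alpha0, delta,
                         expand, finetune):
    """Multi-round curriculum: each round mixes an increasing share alpha_t
    of hard (teacher-expanded) samples with easy pool samples, then
    fine-tunes the student on the mixture (hooks are caller-supplied)."""
    hard = list(seed_hard)
    student = None
    for t in range(rounds):
        alpha_t = alpha0 + t * delta          # deterministic escalation schedule
        if t > 0:
            hard += expand(hard)              # teacher expands the hard set,
                                              # restricted to the same task types
        # choose easy samples so the hard fraction is ~alpha_t
        n_easy = max(0, round(len(hard) * (1 - alpha_t) / alpha_t))
        batch = hard + pool_easy[:n_easy]
        student = finetune(student, batch)    # one epoch per round
    return student

# Toy demo: 10 hard seeds, a large easy pool, expansion doubles the hard set.
batch_sizes = []
run_tapir_curriculum(
    seed_hard=["h"] * 10, pool_easy=["e"] * 100,
    rounds=3, alpha0=0.5, delta=0.25,
    expand=lambda hard: ["h"] * len(hard),
    finetune=lambda student, batch: batch_sizes.append(len(batch)) or "student",
)
```

As α_t grows, the easy share shrinks and the round's dataset both expands and hardens, mirroring the escalation described above.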
3. Difficulty Assessment and Task Distribution Control
TAPIR’s selection of “hard” instructions leverages the Model Fitting Difficulty (MFD) metric, MFD(x_i) = J(x_i, y_i^T) − J(x_i, y_i^S), the gap between the judge’s scores for the teacher’s response and the fine-tuned student’s response; a higher MFD indicates greater student-teacher response disparity. The initial seed set is chosen by thresholding this metric.
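A minimal sketch of seed selection under this metric, assuming judge scores have already been collected per instruction (the threshold value τ here is illustrative, not the paper's cutoff):

```python
def model_fitting_difficulty(judge_score_teacher, judge_score_student):
    """MFD(x) = J(x, y_teacher) - J(x, y_student); judge scores lie in [1, 10].
    Higher values mean the fine-tuned student still lags the teacher on x."""
    return judge_score_teacher - judge_score_student

def select_hard_seed(scored, tau):
    """Keep instructions whose MFD exceeds the threshold tau.

    scored maps each instruction to its (teacher_score, student_score) pair,
    as produced by the judge LLM."""
    return [x for x, (jt, js) in scored.items()
            if model_fitting_difficulty(jt, js) > tau]
```

Instructions the student already fits well (small score gap) are filtered out, concentrating the seed set on genuinely hard cases.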
To achieve balanced coverage, each instruction is categorized by a DeBERTa-v3 classifier into one of 33 predefined task categories and sampled according to a user-designed task distribution P_task, oversampling underrepresented but crucial areas (math, reasoning, coding).
Response refinement is performed by rewriting teacher outputs with task-specific chain-of-thought or stepwise prompts to generate the refined responses ỹ_i, giving the student higher-quality learning targets.
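The rebalanced sampling step can be sketched as below. The pool layout and target distribution are assumed illustrative inputs; in TAPIR the task labels come from the DeBERTa-v3 classifier and P_task is user-designed:

```python
import random

def rebalance_by_task(pool, target_dist, n, seed=0):
    """Sample n instructions so task shares approximate target_dist
    (task name -> probability). Draws with replacement per task, which
    oversamples tasks that are rare in the raw pool."""
    rng = random.Random(seed)
    by_task = {}
    for item in pool:
        by_task.setdefault(item["task"], []).append(item)
    sample = []
    for task, p in target_dist.items():
        bucket = by_task.get(task, [])
        if bucket:
            # allocate round(p * n) slots to this task
            sample += [rng.choice(bucket) for _ in range(round(p * n))]
    return sample
```

Even if "math" makes up a tiny fraction of the raw pool, a target share of 0.5 forces half of the sampled batch to be math instructions.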
4. Curriculum Escalation Mechanism
Curriculum planning is implemented via a deterministic schedule, α_t = α_0 + t·δ. The share of “hard” instructions, α_t, is thus systematically increased with each round, raising the expected difficulty and promoting robust student capabilities. Empirical analysis shows monotonic gains on both AlpacaEval 2.0 and MT-Bench as curriculum rounds proceed.
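The schedule itself is a one-liner; the α_0 and δ values used in the assertions below are illustrative, not the paper's constants, and the cap at 1.0 is an assumed safeguard:

```python
def hard_mix_ratio(t, alpha0, delta):
    """Deterministic escalation: alpha_t = alpha0 + t * delta, capped at 1.0
    so the hard share never exceeds the whole round's dataset."""
    return min(1.0, alpha0 + t * delta)
```

Because the schedule is deterministic and monotone, each round's dataset is guaranteed to be at least as hard (in expected mix) as the previous one.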
5. Experimental Setup and Implementation Specifics
Backbones tested include LLaMA2-7B (primary) and Qwen1.5-Chat spanning 1.8B–14B. Teacher and judge models are proprietary LLMs (ChatGPT or GPT4-turbo). Implementation details:
- Optimizer: AdamW, with backbone-specific learning rates for LLaMA2 and Qwen1.5; 3% warmup; weight decay disabled.
- Data/Batching: Batch 32, max sequence 2048 tokens, bfloat16 precision.
- Training: Three rounds (one epoch/round); 11K initial seeds, 30K after round 1, expanding by ~20K each round, total ~70K samples.
- Compute: 200 GPU-hours on A100 80GB.
6. Benchmarking, Results, and Comparative Evaluation
TAPIR was evaluated on AlpacaEval 2.0 (GPT4-turbo preference win rate) and MT-Bench (GPT4-turbo, across eight sub-tasks). Baseline comparisons included Stanford Alpaca 7B, Vicuna (v1.5, 7B/13B), LLaMA2-Chat (7B/13B), (s)Recycled WizardLM 7B, and Lion 7B.
Key empirical findings for LLaMA2-7B:
- TAPIR-7B-M achieves a 7.80% AlpacaEval win rate (exceeds LLaMA2-Chat 13B at 7.70%).
- MT-Bench overall: 6.74 (vs. 6.41 for LLaMA2-Chat 7B, 6.50 sRecycled WizardLM 7B).
- Largest improvements on Roleplay (+0.8), Reasoning (+2.4), Coding (+0.4), Humanities (+0.2).
For Qwen1.5, consistent 2–4 point gains were observed across all sizes (1.8B–14B). TAPIR does not degrade zero-shot multiple-choice performance on out-of-distribution tasks, as confirmed via the Open LLM Leaderboard.
Ablation studies isolate the impact of each TAPIR component: smart seed selection yields large gains over unfiltered data; task-aware expansion and rewriting add +1.22 on AlpacaEval relative to direct expansion; and the curriculum strategy itself adds a further +0.75 AlpacaEval over a single round.
| Model Setting | #Train Samples | AlpacaEval (%) | MT-Bench |
|---|---|---|---|
| Full Implement (TAPIR-7B-M) | 70K | 7.80 | 6.74 |
| Single Round (no multi-round curriculum) | 70K | 7.05 | 6.71 |
| Direct Expansion (no task control/rewriting) | 70K | 5.83 | 6.43 |
| Seed Alpaca (rewriting only) | 11K | 5.17 | 6.28 |
| Seed Alpaca (no rewriting) | 11K | 4.76 | 6.23 |
| Full Alpaca (52K) | 52K | 2.28 | 5.07 |
7. Contributions, Limitations, and Prospective Work
TAPIR’s principal methodological advances include:
- Model Fitting Difficulty metric for hard sample selection.
- Task-aware data rebalancing explicitly prioritizing underrepresented yet vital disciplines.
- Explicit, data-efficient, multi-round curriculum construction improving generalization and convergence.
Strengths:
- Consistently outperforms LLaMA2-Chat 13B with ~70K distilled samples.
- Robust across student sizes and architectures.
- Ablations isolate the utility of each design choice.
Limitations identified:
- Reliance on API-based proprietary teacher and judge LLMs (potential cost, repeatability, and bias constraints).
- Significant compute requirement (200 GPU-hours for the full curriculum).
- Task classifier (DeBERTa-v3) misclassifies ~8% of instances, potentially distorting the target distribution.
Future exploration is suggested in:
- Incorporating human- or semi-supervised task classifiers.
- Developing adaptive schedules or more granular curriculum strategies.
- Extending TAPIR to specialized domains with unique task distributions (e.g., legal, medical contexts).
TAPIR establishes a rigorously evaluated, task- and difficulty-aware approach to LLM distillation that mitigates data and performance imbalances endemic in previous open distillation pipelines (Yue et al., 2024).