Multi-Level Optimal Transport for Universal Cross-Tokenizer Knowledge Distillation on Language Models (2412.14528v2)

Published 19 Dec 2024 in cs.CL

Abstract: Knowledge distillation (KD) has become a prevalent technique for compressing LLMs. Existing KD methods are constrained by the need for identical tokenizers (i.e., vocabularies) between teacher and student models, limiting their versatility in handling LLMs of different architecture families. In this paper, we introduce the Multi-Level Optimal Transport (MultiLevelOT), a novel approach that advances the optimal transport for universal cross-tokenizer knowledge distillation. Our method aligns the logit distributions of the teacher and the student at both token and sequence levels using diverse cost matrices, eliminating the need for dimensional or token-by-token correspondence. At the token level, MultiLevelOT integrates both global and local information by jointly optimizing all tokens within a sequence to enhance robustness. At the sequence level, we efficiently capture complex distribution structures of logits via the Sinkhorn distance, which approximates the Wasserstein distance for divergence measures. Extensive experiments on tasks such as extractive QA, generative QA, and summarization demonstrate that the MultiLevelOT outperforms state-of-the-art cross-tokenizer KD methods under various settings. Our approach is robust to different student and teacher models across model families, architectures, and parameter sizes. Codes and models are available at https://github.com/2018cx/Multi-Level-OT.

Summary

  • The paper introduces MultiLevelOT, a novel method for cross-tokenizer knowledge distillation in language models by aligning logit distributions at both token and sequence levels.
  • MultiLevelOT overcomes the limitation of needing identical tokenizers and integrates global sequence context with local token information using optimal transport.
  • Validated across diverse model architectures and sizes, MultiLevelOT outperforms existing cross-tokenizer distillation techniques on various NLP tasks like QA and summarization.

The paper "Multi-Level Optimal Transport for Universal Cross-Tokenizer Knowledge Distillation on Language Models" (2412.14528) introduces MultiLevelOT, a knowledge distillation (KD) approach that lifts a key limitation of existing methods by enabling cross-tokenizer KD. It aligns the logit distributions of the teacher and student models at both the token and sequence levels, removing the need for identical tokenizers.

Key Innovations of MultiLevelOT

  • Addressing Tokenizer Limitations: MultiLevelOT overcomes the constraint of identical tokenizers, a common requirement in existing KD methods, by aligning logit distributions at both token and sequence levels. This eliminates the necessity for direct dimensional or token-by-token correspondence between teacher and student models.
  • Sequence-Aware Token-Level Optimal Transport: The method jointly optimizes all tokens within a sequence, integrating global (sequence-wide) and local (per-token) information to improve robustness. This contrasts with methods such as ULD (Universal Logit Distillation), which use only local, per-token information and neglect global distributional properties.
  • Sequence-Level Optimal Transport: MultiLevelOT captures complex logit distribution structures at the sequence level using the Sinkhorn distance, approximating the Wasserstein distance. This is crucial for addressing token order misalignment caused by differing tokenizations.
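The claim that optimal transport needs no dimensional or token-by-token correspondence can be seen in a small case. The sketch below is a generic illustration, not the MultiLevelOT method: it computes the exact OT cost between two discrete distributions with different numbers of atoms, treating a hypothetical teacher's top-4 probabilities and a student's top-3 probabilities as points on the real line with uniform mass. It relies on the standard one-dimensional result that, for cost $|x - y|$, sorting both supports and greedily matching cumulative mass yields an optimal plan. The function name `ot_1d` and the example values are assumptions for illustration.

```python
def ot_1d(x, a, y, b):
    """Exact optimal transport cost between sum_i a[i]*delta(x[i]) and
    sum_j b[j]*delta(y[j]) on the real line with cost |x - y|.

    x and y may have different (non-zero) lengths; a and b must each sum to 1.
    Uses the monotone (sorted, north-west-corner) coupling, which is
    optimal in one dimension for this cost.
    """
    xs = sorted(zip(x, a))  # (position, mass) pairs, sorted by position
    ys = sorted(zip(y, b))
    i, j = 0, 0
    ra, rb = xs[0][1], ys[0][1]  # mass still unassigned at atoms i and j
    cost = 0.0
    while i < len(xs) and j < len(ys):
        moved = min(ra, rb)  # move as much mass as both atoms allow
        cost += moved * abs(xs[i][0] - ys[j][0])
        ra -= moved
        rb -= moved
        if ra <= 1e-12:  # source atom i is exhausted: go to the next one
            i += 1
            if i < len(xs):
                ra = xs[i][1]
        if rb <= 1e-12:  # target atom j is filled: go to the next one
            j += 1
            if j < len(ys):
                rb = ys[j][1]
    return cost


if __name__ == "__main__":
    # Hypothetical top-k probabilities: 4 teacher values vs 3 student values.
    teacher = [0.6, 0.2, 0.15, 0.05]
    student = [0.5, 0.3, 0.2]
    print(ot_1d(teacher, [1 / 4] * 4, student, [1 / 3] * 3))
```

For these values the cost is 2/15 (about 0.133), obtained even though there is no one-to-one pairing between the four teacher entries and the three student entries.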

Technical Details of MultiLevelOT

MultiLevelOT employs several key techniques to achieve effective cross-tokenizer knowledge distillation:

  • Diverse Cost Matrices: The method uses two types of cost matrices, an absolute-difference form and a logarithmic form, to capture both fine-grained token-wise nuances and holistic sequence-scale context.
  • Sinkhorn Distance: The Sinkhorn distance is used to approximate the Wasserstein distance, enabling the capture of complex distribution structures of logits at the sequence level, which is vital for addressing token order misalignment.
    • The Sinkhorn distance is defined as follows. Given two discrete probability measures $\mu = \sum_{i=1}^m a_i \delta_{x_i}$ and $\nu = \sum_{j=1}^n b_j \delta_{y_j}$ with $a, b > 0$ and $\sum_i a_i = \sum_j b_j = 1$, the squared Wasserstein distance $W_2^2(\mu, \nu)$ is defined as:

      $$W_2^2(\mu, \nu) = \inf_{\gamma \in \Pi(a,b)} \sum_{i=1}^m \sum_{j=1}^n \gamma_{ij}\, c(x_i, y_j)$$

      where $\Pi(a,b) = \{\gamma \in \mathbb{R}_{\ge 0}^{m \times n} : \gamma \mathbf{1}_n = a,\ \gamma^\top \mathbf{1}_m = b\}$ is the set of transport plans with marginals $a$ and $b$, and $c(x_i, y_j) = \|x_i - y_j\|^2$ is the cost function. The Sinkhorn distance adds an entropic regularization term to this problem, which makes it strictly convex and efficiently solvable by matrix-scaling (Sinkhorn) iterations:

      $$W_{\epsilon}(\mu, \nu) = \inf_{\gamma \in \Pi(a,b)} \sum_{i=1}^m \sum_{j=1}^n \gamma_{ij}\, c(x_i, y_j) - \epsilon\, \mathbb{H}(\gamma)$$

      where $\mathbb{H}(\gamma) = -\sum_{i,j} \gamma_{ij} \log \gamma_{ij}$ is the entropy of $\gamma$ and $\epsilon > 0$ is the regularization parameter. Subtracting $\epsilon\, \mathbb{H}(\gamma)$ rewards smoother, higher-entropy plans; as $\epsilon \to 0$ the unregularized problem is recovered.

  • Joint Optimization: MultiLevelOT jointly optimizes all tokens within a sequence to integrate both global and local information, enhancing robustness. This is particularly effective in capturing the relationships between tokens and their context within the sequence.
  • Cost Matrix Formulation: MultiLevelOT uses two types of cost matrices:
    • Absolute Difference: $C_{ij} = |l_i^t - l_j^s|$
    • Logarithmic Form: $C_{ij} = -\log\big(\text{cosine\_similarity}(l_i^t, l_j^s)\big)$

      where $l_i^t$ and $l_j^s$ are the logits of the teacher and student models, respectively.

Advantages and Performance

  • No Architectural Constraints: MultiLevelOT requires neither additional modules nor task-specific modifications to the output format, making it versatile and easy to implement.
  • Robustness and Generalizability: The method has been validated with students and teachers that differ in model family, architecture, and parameter size, and it remains robust across these combinations.
  • Superior Performance: MultiLevelOT outperforms state-of-the-art cross-tokenizer KD (CTKD) methods on extractive QA, generative QA, and summarization under both labeled and unlabeled distillation settings. Ablation studies and hyper-parameter analyses support the contribution of each component.