Better Estimation of the KL Divergence Between Language Models (2504.10637v2)

Published 14 Apr 2025 in cs.CL, cs.AI, and cs.LG

Abstract: Estimating the Kullback--Leibler (KL) divergence between LLMs has many applications, e.g., reinforcement learning from human feedback (RLHF), interpretability, and knowledge distillation. However, computing the exact KL divergence between two arbitrary LLMs is intractable. Thus, practitioners often resort to the use of sampling-based estimators. While it is easy to fashion a simple Monte Carlo (MC) estimator that provides an unbiased estimate of the KL divergence between LLMs, this estimator notoriously suffers from high variance, and can even result in a negative estimate of the KL divergence, a non-negative quantity. In this paper, we introduce a Rao--Blackwellized estimator that is also unbiased and provably has variance less than or equal to that of the standard Monte Carlo estimator. In an empirical study on sentiment-controlled fine-tuning, we show that our estimator provides more stable KL estimates and reduces variance substantially in practice. Additionally, we derive an analogous Rao--Blackwellized estimator of the gradient of the KL divergence, which leads to more stable training and produces models that more frequently appear on the Pareto frontier of reward vs. KL compared to the ones trained with the MC estimator of the gradient.

Summary

  • The paper proposes a novel unbiased Rao–Blackwellized estimator that decomposes KL divergence into step-wise conditional terms for reduced variance.
  • It theoretically proves and empirically validates lower variance compared to Monte Carlo methods, ensuring stable gradient estimates in RLHF applications.
  • The method retains the same computational complexity as standard approaches while guaranteeing non-negative outputs for reliable KL-constrained optimization.

This paper addresses the challenge of accurately and stably estimating the Kullback--Leibler (KL) divergence between LLMs. Estimating this divergence is central to several downstream tasks in natural language processing, including reinforcement learning from human feedback (RLHF), model interpretability, knowledge distillation, and evaluation.

Computing the exact KL divergence, defined as $\mathrm{KL}(p \mid\mid q) = \sum_{\boldsymbol{y} \in \Sigma^*} p(\boldsymbol{y}) \log \frac{p(\boldsymbol{y})}{q(\boldsymbol{y})}$, is generally intractable for neural LLMs because it requires summing over the infinite set of possible strings $\Sigma^*$. Practitioners typically resort to sampling-based estimators, most commonly the naive Monte Carlo (MC) estimator $\widehat{\mathrm{KL}}_{\mathrm{MC}} = \frac{1}{M} \sum_{m=1}^M \log \frac{p(\boldsymbol{y}^{(m)})}{q(\boldsymbol{y}^{(m)})}$, where $\boldsymbol{y}^{(m)} \sim p$. While unbiased, the MC estimator suffers from high variance and can produce negative estimates, which is problematic in applications like RLHF where the KL term serves as a non-negative regularizer. Existing alternatives, such as the control variate approach proposed by Schulman in a widely used blog post and adopted in several RLHF libraries, do not guarantee variance reduction and can even exhibit unbounded variance in practice.
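
For concreteness, a minimal sketch of the naive MC estimator is shown below; the callables p_logprob, q_logprob, and sample_from_p are hypothetical stand-ins for model-specific code, not an interface defined in the paper.

def estimate_kl_mc(p_logprob, q_logprob, sample_from_p, num_samples=64):
    """Naive MC estimate: average of log p(y) - log q(y) over samples y ~ p."""
    total = 0.0
    for _ in range(num_samples):
        y = sample_from_p()                  # draw one sequence from p
        total += p_logprob(y) - q_logprob(y)  # individual log-ratios may be negative
    return total / num_samples               # unbiased, but high variance; can be negative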

The paper introduces a novel estimator based on Rao--Blackwellization, a classical statistical technique for variance reduction. The key idea is to decompose the KL divergence between LLMs into a sum of step-wise KL divergences between conditional next-token distributions: $\mathrm{KL}(p \mid\mid q) = \sum_{\boldsymbol{y} \in \Sigma^*} \pi_p(\boldsymbol{y})\, \mathrm{KL}\bigl(p(\cdot \mid \boldsymbol{y}) \mid\mid q(\cdot \mid \boldsymbol{y})\bigr)$, where $\pi_p(\boldsymbol{y})$ is the probability that $\boldsymbol{y}$ occurs as a prefix under $p$ and $p(\cdot \mid \boldsymbol{y})$ is the conditional distribution over the next token given prefix $\boldsymbol{y}$. The Rao--Blackwellized (RB) estimator leverages this decomposition: instead of averaging the log-ratio of full-sequence probabilities (as MC does), it averages the exact step-wise conditional KL divergences along prefixes sampled from $p$. The RB estimator is defined as $\widehat{\mathrm{KL}}_{\mathrm{RB}} = \frac{1}{M}\sum_{m=1}^M \sum_{n=1}^{\infty} \mathbb{E}_{Y_n \sim p(\cdot \mid \boldsymbol{y}^{(m)}_{<n})} \left[\log \frac{p(Y_n \mid \boldsymbol{y}^{(m)}_{<n})}{q(Y_n \mid \boldsymbol{y}^{(m)}_{<n})}\right]$. In practice, the sum over $n$ is truncated at the length of the sampled sequence. The inner expectation, i.e., the exact KL divergence between the two conditional next-token distributions, can be computed efficiently because it is a finite sum over the alphabet of size $|\Sigma|$.
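
The inner expectation is simply an exact KL between two categorical next-token distributions and can be computed in closed form from the two models' logits. A minimal PyTorch sketch (tensor names are illustrative, not an API from the paper):

import torch
import torch.nn.functional as F

def stepwise_kl(policy_logits: torch.Tensor, ref_logits: torch.Tensor) -> torch.Tensor:
    """Exact KL( p(. | prefix) || q(. | prefix) ) from next-token logits of shape (vocab_size,)."""
    log_p = F.log_softmax(policy_logits, dim=-1)
    log_q = F.log_softmax(ref_logits, dim=-1)
    return (log_p.exp() * (log_p - log_q)).sum(dim=-1)  # non-negative by construction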

The paper theoretically proves that the RB estimator is unbiased and has variance less than or equal to that of the standard Monte Carlo estimator [(2504.10637), Theorem 2.2]. Crucially, the RB estimator is guaranteed to be non-negative because it is a sum of non-negative conditional KL terms. Its computational complexity matches that of the MC estimator, $\mathcal{O}(M \cdot \mathbb{E}[N] \cdot d\, |\Sigma|)$, where $M$ is the number of samples, $\mathbb{E}[N]$ is the expected sequence length, $d$ is the model dimension, and $|\Sigma|$ is the alphabet size. The dominant cost is typically model inference for sampling and computing probabilities, not the additional arithmetic required by the RB estimator.

The paper also derives an analogous Rao--Blackwellized estimator of the gradient of the KL divergence, which is needed to optimize KL-regularized objectives such as those arising in RLHF. The RB gradient estimator is proven to have expected squared error (a quantity closely related to variance) less than or equal to that of the standard MC gradient estimator [(2504.10637), Theorem 4.2].
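
For orientation, the standard MC gradient estimator that the RB version is compared against can be written via the usual score-function (REINFORCE-style) identity; this is the textbook form, stated here for context rather than quoted from the paper:

$$
\nabla_\theta \, \mathrm{KL}(p_\theta \mid\mid q)
= \mathbb{E}_{\boldsymbol{y} \sim p_\theta}\!\left[ \nabla_\theta \log p_\theta(\boldsymbol{y}) \, \log \frac{p_\theta(\boldsymbol{y})}{q(\boldsymbol{y})} \right]
\approx \frac{1}{M} \sum_{m=1}^{M} \nabla_\theta \log p_\theta(\boldsymbol{y}^{(m)}) \, \log \frac{p_\theta(\boldsymbol{y}^{(m)})}{q(\boldsymbol{y}^{(m)})},
\qquad \boldsymbol{y}^{(m)} \sim p_\theta .
$$

The paper's RB gradient estimator Rao--Blackwellizes this quantity analogously, at the level of per-step conditional distributions.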

For practical implementation in a deep learning framework, the RB estimator involves:

  1. Sampling sequences $\boldsymbol{y}^{(m)}$ from the policy model $p$.
  2. For each sequence $m$ and each position $n$:
    • Computing the conditional next-token distributions $p(\cdot \mid \boldsymbol{y}^{(m)}_{<n})$ and $q(\cdot \mid \boldsymbol{y}^{(m)}_{<n})$ over the vocabulary (typically including an end-of-sequence symbol). This requires a forward pass through each model on the prefix $\boldsymbol{y}^{(m)}_{<n}$ to obtain logits/probabilities for all tokens.
    • Computing the exact conditional KL divergence $\sum_{y \in \Sigma} p(y \mid \boldsymbol{y}^{(m)}_{<n}) \log \frac{p(y \mid \boldsymbol{y}^{(m)}_{<n})}{q(y \mid \boldsymbol{y}^{(m)}_{<n})}$.
  3. Summing these conditional KL terms over $n$ for each sample $m$, and then averaging over $m$.

Pseudocode for implementing the core RB KL estimation for a single sample might look like:

import math

def estimate_kl_rb(policy_model, ref_model, prompt_token_ids, eos_token_id):
    """Rao--Blackwellized KL estimate for a single sampled sequence.

    The model interfaces (generate, get_next_token_probs) are illustrative.
    """
    sequence = policy_model.generate(prompt_token_ids)  # sample a sequence from the policy
    kl_estimate = 0.0
    current_prefix_ids = list(prompt_token_ids)

    for token_id in sequence:
        # Conditional next-token distributions given the current prefix
        policy_probs = policy_model.get_next_token_probs(current_prefix_ids)  # probs over the vocabulary
        ref_probs = ref_model.get_next_token_probs(current_prefix_ids)        # probs over the vocabulary

        # Exact conditional KL for this step: a finite sum over the vocabulary
        conditional_kl = 0.0
        for p, q in zip(policy_probs, ref_probs):
            if p > 0.0 and q > 0.0:
                conditional_kl += p * math.log(p / q)
            # if p > 0 and q == 0, the KL is infinite; finite KL is assumed throughout
            # if p == 0, the term contributes 0

        kl_estimate += conditional_kl
        current_prefix_ids.append(token_id)
        if token_id == eos_token_id:
            break  # sequence ended

    return kl_estimate
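
In practice, the loops over the vocabulary and over positions would be vectorized. A minimal batched sketch in PyTorch, assuming both models have been run under teacher forcing on the sampled sequences to produce per-position logits (tensor shapes and the masking convention are illustrative assumptions, not the paper's implementation):

import torch
import torch.nn.functional as F

def estimate_kl_rb_batched(policy_logits, ref_logits, mask):
    """policy_logits, ref_logits: (batch, seq_len, vocab_size) logits for the sampled sequences;
    mask: (batch, seq_len), 1.0 at generated (non-prompt, non-padding) positions, else 0.0."""
    log_p = F.log_softmax(policy_logits, dim=-1)
    log_q = F.log_softmax(ref_logits, dim=-1)
    per_step_kl = (log_p.exp() * (log_p - log_q)).sum(dim=-1)  # (batch, seq_len), each term >= 0
    per_sequence_kl = (per_step_kl * mask).sum(dim=-1)         # sum of conditional KLs per sample
    return per_sequence_kl.mean()                              # average over the M samples
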
For gradient estimation in RLHF, the corresponding expectation over $Y_n$ in the RB gradient formula is likewise computed exactly over the vocabulary, and gradients must be propagated through this expectation.

Empirical evaluation on a sentiment-controlled fine-tuning task (GPT-2 IMDB) confirms the theoretical advantages. The RB estimator demonstrates significantly lower standard deviation and greater stability compared to MC and the tested control variate estimator, especially on challenging or adversarial prompts where other methods exhibit high variance. When used in the RLHF training loop (specifically with the RLOO algorithm), the RB gradient estimator leads to more stable training runs and produces models that appear more frequently on the Pareto frontier balancing reward maximization and KL divergence minimization. This suggests that reducing variance in KL estimation improves the reliability and effectiveness of KL-constrained optimization methods.

The paper also briefly discusses off-policy KL estimation using importance sampling and notes a common deviation in trust-region RL algorithms like PPO, where a specific biased estimator related to a Bregman divergence is used, cautioning against naive application of RB to this biased estimator.
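
As a point of reference, the generic importance-sampling form of the MC estimator for samples drawn from a behavior distribution $\mu$ is the following (written in standard notation; the paper's exact off-policy weighting may differ):

$$
\mathrm{KL}(p \mid\mid q)
= \mathbb{E}_{\boldsymbol{y} \sim \mu}\!\left[ \frac{p(\boldsymbol{y})}{\mu(\boldsymbol{y})} \log \frac{p(\boldsymbol{y})}{q(\boldsymbol{y})} \right]
\approx \frac{1}{M} \sum_{m=1}^{M} \frac{p(\boldsymbol{y}^{(m)})}{\mu(\boldsymbol{y}^{(m)})} \log \frac{p(\boldsymbol{y}^{(m)})}{q(\boldsymbol{y}^{(m)})},
\qquad \boldsymbol{y}^{(m)} \sim \mu .
$$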

In summary, the paper provides a theoretically sound and empirically validated Rao--Blackwellized estimator for KL divergence and its gradient between LLMs. This estimator offers guaranteed variance reduction over standard Monte Carlo methods without added computational overhead, leading to more stable and reliable estimates and improved performance in applications like RLHF. The work highlights a practical improvement for tasks requiring KL divergence estimation in the context of LLMs.
