- The paper proposes a novel unbiased Rao--Blackwellized estimator that decomposes the KL divergence into step-wise conditional terms to reduce variance.
- The estimator is proven to have variance no greater than that of standard Monte Carlo estimation and is empirically shown to yield more stable gradient estimates in RLHF applications.
- The method retains the same computational complexity as standard approaches while guaranteeing non-negative outputs for reliable KL-constrained optimization.
This paper addresses the challenge of accurately and stably estimating the Kullback--Leibler (KL) divergence between LMs. Estimating this divergence is crucial for various downstream tasks in natural language processing, including reinforcement learning from human feedback (RLHF), model interpretability, knowledge distillation, and evaluation metrics.
Computing the exact KL divergence, defined as $\kl(\plm \mid\mid \qlm) = \sum_{\str \in \kleene{\alphabet}} \plm(\str) \log \frac{\plm(\str)}{\qlm(\str)}$, is generally intractable for neural LLMs because it requires summing over the infinite set of possible strings $\kleene{\alphabet}$. Practitioners typically resort to sampling-based estimators, most commonly the naive Monte Carlo (MC) estimator:
$\klmc = \frac{1}{M} \sum_{m=1}^M \log \frac{\plm(\yvar^{(m)})}{\qlm(\yvar^{(m)})}$, where $\yvar^{(m)} \sim \plm$.
While unbiased, the MC estimator suffers from high variance and can produce undesirable negative estimates, which is problematic in applications like RLHF where the KL is used as a non-negative regularization term. Existing alternatives, such as the control variate estimator of Schulman (klblog) that is widely used in RLHF libraries, do not guarantee variance reduction and can even exhibit unbounded variance in practice.
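For reference, here is a minimal sketch of this MC baseline, using the same illustrative model interface as the pseudocode later in this section; `sequence_log_prob` is an assumed helper (not from the paper) that returns the total log-probability a model assigns to a sampled sequence:

```python
def estimate_kl_mc(policy_model, ref_model, prompt, num_samples):
    """Naive Monte Carlo KL estimate: average log-ratio of full-sequence probabilities."""
    total = 0.0
    for _ in range(num_samples):
        sequence = policy_model.generate(prompt)  # y^(m) sampled from the policy
        # sequence_log_prob is an assumed helper, not a real library call
        log_p = policy_model.sequence_log_prob(prompt, sequence)
        log_q = ref_model.sequence_log_prob(prompt, sequence)
        total += log_p - log_q  # a single term can be negative
    return total / num_samples  # unbiased, but high variance and possibly negative
```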
The paper introduces a novel estimator based on Rao--Blackwellization, a statistical technique for variance reduction. The key idea is to decompose the KL divergence for LLMs into a sum of step-wise KL divergences between conditional distributions:
$\kl(\plm \mid\mid \qlm) = \sum_{\str \in \Sigma^*} \prefixprob(\str) \kl\mleft(\pprob \mleft(\cdot \mid \str \mright) \mid\mid \qpprob \mleft(\cdot \mid \str \mright)\mright)$, where $\pprob(\cdot \mid \str)$ is the conditional distribution over the next token given prefix $\str$.
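This identity follows from the autoregressive factorization of both models. By the chain rule, $\log \frac{\plm(\yvar)}{\qlm(\yvar)} = \sum_{n} \log \frac{\pprob(\yvar_n \mid \yvar_{<n})}{\qpprob(\yvar_n \mid \yvar_{<n})}$; taking the expectation under $\plm$ and applying the tower property to each summand, conditioning on the prefix $\yvar_{<n}$, turns the $n$-th term into an expected step-wise KL divergence, and grouping prefixes by their probability $\prefixprob(\str)$ yields the sum above.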
The Rao--Blackwellized (RB) estimator leverages this decomposition. Instead of averaging the log-ratio of full-sequence probabilities, as MC does, it sums the exact step-wise conditional KL divergences along each sequence sampled from $\plm$ and averages these sums across samples.
The RB estimator is defined as:
$\rbestimator = \frac{1}{M}\sum_{m=1}^M \sum_{n=1}^\infty \E_{\Yb_n \sim \pprob(\cdot \mid \yvarb^{(m)}_{<n})} \left[\log \frac{\pprob(\Yb_n \mid \yvarb^{(m)}_{< n})}{\qpprob(\Yb_n \mid \yvarb^{(m)}_{< n})}\right]$.
In practice, this sum is truncated at the sampled sequence length. The inner expectation, which is the exact KL divergence between the next-token conditional distributions given a prefix, can be computed exactly and efficiently, since it is a finite sum over the alphabet of size $|\alphabet|$.
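Spelled out with the truncation and the finite inner sum (writing $N^{(m)}$, a notation introduced here, for the length of the $m$-th sampled sequence, with the EOS step included), the quantity computed in practice is
$\frac{1}{M}\sum_{m=1}^M \sum_{n=1}^{N^{(m)}} \sum_{\yb \in \alphabar} \pprob(\yb \mid \yvarb^{(m)}_{< n}) \log \frac{\pprob(\yb \mid \yvarb^{(m)}_{< n})}{\qpprob(\yb \mid \yvarb^{(m)}_{< n})}$.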
The paper theoretically proves that the RB estimator is unbiased and has variance less than or equal to that of the standard Monte Carlo estimator [(2504.10637), Theorem 2.2]. Crucially, the RB estimator is guaranteed to be non-negative because it sums non-negative conditional KL terms. The computational complexity of the RB estimator is shown to be the same as the MC estimator's, $\bigo{M \cdot \mathbb{E}[N] \cdot d |\alphabet|}$, where $M$ is the number of samples, $\mathbb{E}[N]$ is the expected sequence length, $d$ is the model dimension, and $|\alphabet|$ is the alphabet size. The dominant cost is typically model inference for sampling and computing probabilities, not the additional computation for the RB estimator.
The paper also derives an analogous Rao--Blackwellized estimator for the gradient of the KL divergence, which is essential for optimizing KL-regularized objectives like in RLHF. The RB gradient estimator is proven to have a lower or equal expected squared error (related to variance) compared to the standard MC gradient estimator [(2504.10637), Theorem 4.2].
For practical implementation in a deep learning framework, the RB estimator involves:
- Sample $M$ sequences $\yvar^{(m)}$ from the policy $\policy$.
- For each sequence $m$ and each position $n$:
  - Compute the conditional next-token distributions $\pprob(\cdot \mid \yvar^{(m)}_{<n})$ and $\qpprob(\cdot \mid \yvar^{(m)}_{<n})$ (over $\alphabar$). This requires a forward pass through each model on the prefix $\yvar^{(m)}_{<n}$ to obtain the logits/probabilities for every token in $\alphabar$.
  - Compute the exact conditional KL divergence: $\sum_{\yb \in \alphabar} \pprob(\yb \mid \yvar^{(m)}_{< n}) \log \frac{\pprob(\yb \mid \yvar^{(m)}_{< n})}{\qpprob(\yb \mid \yvar^{(m)}_{< n})}$.
- Sum these conditional KL terms over $n$ for each sample $m$, then average over $m$.
Illustrative Python pseudocode for the core RB KL estimate of a single sample (using assumed `generate` and `get_next_token_probs` model interfaces) might look like:
```python
import math

def estimate_kl_rb(policy_model, ref_model, prompt_token_ids, vocab_size, eos_token_id):
    sequence = policy_model.generate(prompt_token_ids)  # sample a sequence from the policy
    kl_estimate = 0.0
    current_prefix_ids = list(prompt_token_ids)
    for token_id in sequence:
        # Get conditional next-token distributions given the current prefix
        policy_probs = policy_model.get_next_token_probs(current_prefix_ids)  # probs over alphabet
        ref_probs = ref_model.get_next_token_probs(current_prefix_ids)        # probs over alphabet
        # Compute the exact conditional KL for this step (sum over the whole alphabet)
        conditional_kl = 0.0
        for vocab_id in range(vocab_size):
            p = policy_probs[vocab_id]
            q = ref_probs[vocab_id]
            if p > 0 and q > 0:  # handle log(0)
                conditional_kl += p * math.log(p / q)
            # if p > 0 and q == 0, the KL is infinite; we assume finite KL
            # if p == 0, the term is 0
        kl_estimate += conditional_kl
        current_prefix_ids.append(token_id)
        if token_id == eos_token_id:
            break  # sequence ended
    return kl_estimate
```
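The per-position loop above vectorizes naturally. Below is a sketch of a batched variant, assuming both models expose per-position next-token logits over $\alphabar$ for the sampled batch as tensors of shape `[batch, seq_len, vocab]`; the function and tensor names are illustrative rather than taken from the paper:

```python
import torch
import torch.nn.functional as F

def rb_kl_from_logits(policy_logits, ref_logits, mask):
    """Batched RB KL estimate computed from per-position logits.

    policy_logits, ref_logits: [batch, seq_len, vocab] next-token logits at every
        position of the sampled sequences.
    mask: [batch, seq_len], 1.0 for real positions (up to and including EOS), 0.0 for padding.
    Returns the batch mean of the summed step-wise KL divergences.
    """
    log_p = F.log_softmax(policy_logits, dim=-1)
    log_q = F.log_softmax(ref_logits, dim=-1)
    # Exact conditional KL at each position: sum_y p(y | prefix) * (log p - log q)
    step_kl = (log_p.exp() * (log_p - log_q)).sum(dim=-1)  # [batch, seq_len]
    per_sequence_kl = (step_kl * mask).sum(dim=-1)         # [batch]
    return per_sequence_kl.mean()
```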
For gradient estimation in RLHF, the inner expectation over $\Yb_n$ in the RB gradient formula must itself be differentiated with respect to the policy parameters; since it is an explicit finite sum over $\alphabar$ computed from the policy's logits, automatic differentiation can propagate gradients through it directly.
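Continuing the batched sketch above, here is a minimal check that these step-wise terms are differentiable (random tensors stand in for real model outputs; backpropagating through `kl_term` alone captures only the dependence of the step-wise KLs on the policy's logits, not the paper's full RB gradient estimator, which also accounts for the prefixes being sampled from the policy):

```python
import torch

# Stand-in shapes; in practice the logits come from forward passes of the policy
# and the frozen reference model on the sampled batch.
batch, seq_len, vocab = 2, 5, 11
policy_logits = torch.randn(batch, seq_len, vocab, requires_grad=True)
ref_logits = torch.randn(batch, seq_len, vocab)  # no grad: reference is frozen
mask = torch.ones(batch, seq_len)

kl_term = rb_kl_from_logits(policy_logits, ref_logits, mask)
kl_term.backward()               # gradient of the summed step-wise KLs
print(policy_logits.grad.shape)  # torch.Size([2, 5, 11])
```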
Empirical evaluation on a sentiment-controlled fine-tuning task (GPT-2 IMDB) confirms the theoretical advantages. The RB estimator demonstrates significantly lower standard deviation and greater stability compared to MC and the tested control variate estimator, especially on challenging or adversarial prompts where other methods exhibit high variance. When used in the RLHF training loop (specifically with the RLOO algorithm), the RB gradient estimator leads to more stable training runs and produces models that appear more frequently on the Pareto frontier balancing reward maximization and KL divergence minimization. This suggests that reducing variance in KL estimation improves the reliability and effectiveness of KL-constrained optimization methods.
The paper also briefly discusses off-policy KL estimation using importance sampling and notes a common deviation in trust-region RL algorithms like PPO, where a specific biased estimator related to a Bregman divergence is used, cautioning against naive application of RB to this biased estimator.
In summary, the paper provides a theoretically sound and empirically validated Rao--Blackwellized estimator for KL divergence and its gradient between LLMs. This estimator offers guaranteed variance reduction over standard Monte Carlo methods without added computational overhead, leading to more stable and reliable estimates and improved performance in applications like RLHF. The work highlights a practical improvement for tasks requiring KL divergence estimation in the context of LLMs.