Adaptive Joint Decoding Loss
- The paper presents a novel loss function that jointly optimizes insertion and unmasking policies to fine-tune any-length discrete diffusion models under reward-guided objectives.
- It combines a Radon–Nikodym-weighted cross-entropy objective with learned quality estimates that adaptively schedule decoding steps, and it comes with a guarantee of convergence to a reward-tilted distribution.
- Empirical evaluations in tasks like drug design and language reasoning demonstrate notable improvements in sample quality and reward alignment compared to conventional methods.
Adaptive Joint Decoding (AJD) Loss is a loss function and training principle developed for fine-tuning any-length discrete diffusion models under reward-guided objectives. It enables joint optimization of insertion and unmasking policies in sequence generation tasks, and it incorporates adaptive scheduling driven by learned quality estimates for both unmasked tokens and inserted gaps. AJD Loss is proven to converge to a reward-tilted sequence distribution, allowing diffusion-based generative models to align closely with desired metrics or performance rewards while still producing flexible, variable-length outputs (Tang et al., 11 Jun 2026).
1. Formalization of Adaptive Joint Decoding Loss
The Adaptive Joint Decoding loss framework addresses the task of fine-tuning any-length discrete diffusion models such that their decoding paths (trajectories through joint insertion and unmasking operations) generate samples from an intractable, reward-tilted target distribution. This is achieved by optimizing a weighted cross-entropy loss over continuous-time Markov chain (CTMC) path measures, in which sampled trajectories are reweighted using the Radon–Nikodym (RN) derivative, adapting each trajectory’s contribution according to its importance relative to the reward objective. The loss function decomposes into four tractable terms:
- Denoising cross-entropy for unmasking,
- Bregman-divergence loss for insertion counts,
- Binary cross-entropy loss for unmasking quality (UQL),
- Binary cross-entropy loss for insertion quality (IQL).
This joint structure supports provable convergence to the reward-tilted distribution and enables an adaptive inference schedule informed by learned quality scores for tokens and gaps (Tang et al., 11 Jun 2026).
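The four terms can be sketched as plain functions. This is a minimal illustration under stated assumptions, not the authors' implementation: the function names are hypothetical, predictions are passed in as already-computed probabilities and counts, and the insertion term assumes the Bregman divergence generated by $F(x) = x\log x - x$ (the generalized KL, or Poisson, loss), one common choice for non-negative count targets.

```python
import math
from typing import Sequence

EPS = 1e-12  # floor that keeps log() finite


def unmasking_ce(pred_probs: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Denoising cross-entropy summed over masked positions.

    pred_probs[i] is the predicted distribution over the vocabulary at the
    i-th masked position; targets[i] is the token id in the clean sequence.
    """
    return -sum(math.log(max(p[t], EPS)) for p, t in zip(pred_probs, targets))


def insertion_bregman(true_counts: Sequence[int], pred_counts: Sequence[float]) -> float:
    """Bregman divergence of F(x) = x log x - x, summed over gaps.

    D_F(n, m) = n log(n / m) - n + m, where n is the number of tokens truly
    missing from a gap and m > 0 is the predicted count (assumed choice).
    """
    total = 0.0
    for n, m in zip(true_counts, pred_counts):
        m = max(m, EPS)
        total += (n * math.log(n / m) if n > 0 else 0.0) - n + m
    return total


def binary_ce(pred_probs: Sequence[float], labels: Sequence[int]) -> float:
    """Binary cross-entropy; the same form serves UQL (tokens) and IQL (gaps)."""
    total = 0.0
    for q, y in zip(pred_probs, labels):
        q = min(max(q, EPS), 1.0 - EPS)  # clip away from 0 and 1
        total -= y * math.log(q) + (1 - y) * math.log(1.0 - q)
    return total
```

For example, `insertion_bregman([0], [1.5])` evaluates to 1.5: predicting insertions for a gap that needs none is penalized by the predicted count itself, while an exact count prediction costs zero.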
2. Path Measures, Radon–Nikodym Derivative, and Reward-Tilted Objectives
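The formulation relies on defining a CTMC trajectory $X = (X_t)_{t\in[0,1]}$ with $X_t \in \mathcal{X}$, where $\mathcal{X}$ is the set of all variable-length sequences over the vocabulary $\mathcal{V}$. The reference (pre-trained) path measure $\mathbb{P}^{\mathrm{pre}}$ is induced by the unmasking and insertion rates $(u^{\mathrm{pre}}_t, \lambda^{\mathrm{pre}}_t)$, while the reward-tilted measure is
$$\frac{d\mathbb{P}^*}{d\mathbb{P}^{\mathrm{pre}}}(X) = \frac{1}{Z}\exp\big(r(X_1)\big), \qquad Z = \mathbb{E}_{X\sim\mathbb{P}^{\mathrm{pre}}}\big[\exp\big(r(X_1)\big)\big],$$
with $Z$ as the partition function and $r$ the reward. The marginal distribution at $t = 1$ is thus $p^*_1(x) = \frac{1}{Z}\,p^{\mathrm{pre}}_1(x)\exp\big(r(x)\big)$. The RN derivative between $\mathbb{P}^*$ and a fine-tuned measure $\mathbb{P}^\theta$ quantifies the relative importance of each trajectory under the two measures; it includes terms for reward increments, relative log-rate ratios at jump times, and an integral correcting for differences in exit rates.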
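For intuition, the standard change-of-measure formula for two CTMCs with the same initial distribution gives $\log\frac{d\mathbb{P}^\theta}{d\mathbb{P}^{\mathrm{pre}}}(X) = \sum_k \log\frac{R^\theta_{t_k}(x_{k-1},x_k)}{R^{\mathrm{pre}}_{t_k}(x_{k-1},x_k)} - \int_0^1\big(\bar R^\theta_t(X_t)-\bar R^{\mathrm{pre}}_t(X_t)\big)\,dt$, where $R$ are the joint (unmasking plus insertion) transition rates, $\bar R$ the total exit rates, and $t_k$ the jump times. Combining it with the tilt gives $\log\frac{d\mathbb{P}^*}{d\mathbb{P}^\theta}(X) = r(X_1) - \log Z - \log\frac{d\mathbb{P}^\theta}{d\mathbb{P}^{\mathrm{pre}}}(X)$. The sketch below evaluates this for one recorded path. It assumes exit rates are piecewise constant between recorded time points (so the integral becomes a sum), drops the intractable constant $-\log Z$, and omits any intermediate reward terms the paper may use; the `Jump`/`Holding` records are hypothetical.

```python
import math
from dataclasses import dataclass
from typing import List


@dataclass
class Jump:
    rate_theta: float  # rate of the transition actually taken, fine-tuned model
    rate_pre: float    # rate of the same transition, pre-trained model


@dataclass
class Holding:
    duration: float    # time spent in a state (assumed constant exit rates)
    exit_theta: float  # total exit rate of that state, fine-tuned model
    exit_pre: float    # total exit rate of that state, pre-trained model


def log_rn_theta_vs_pre(jumps: List[Jump], holdings: List[Holding]) -> float:
    """log dP^theta/dP^pre for one path, assuming identical initial laws."""
    jump_term = sum(math.log(j.rate_theta / j.rate_pre) for j in jumps)
    # Piecewise-constant approximation of the exit-rate integral.
    exit_term = sum(h.duration * (h.exit_theta - h.exit_pre) for h in holdings)
    return jump_term - exit_term


def log_rn_tilted_vs_theta(reward: float, jumps: List[Jump],
                           holdings: List[Holding]) -> float:
    """log dP*/dP^theta, up to the additive constant -log Z."""
    return reward - log_rn_theta_vs_pre(jumps, holdings)
```

When the fine-tuned rates equal the pre-trained ones, the path term vanishes and the log-weight reduces to the terminal reward, as expected from the definition of $\mathbb{P}^*$.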
3. Decomposition and Computation of the AJD Loss
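The AJD loss is expressed as a cross-entropy from the fine-tuned policy, with path measure $\mathbb{P}^\theta$, to the reward-tilted measure $\mathbb{P}^*$:
$$\mathcal{L}_{\mathrm{CE}}(\theta) = -\,\mathbb{E}_{X\sim\mathbb{P}^*}\left[\log\frac{d\mathbb{P}^\theta}{d\mathbb{P}^{\mathrm{pre}}}(X)\right],$$
where $\mathbb{P}^{\mathrm{pre}}$ is the reference (pre-trained) path measure. This differs from $\mathrm{KL}(\mathbb{P}^*\,\|\,\mathbb{P}^\theta)$ only by a term independent of $\theta$.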
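In practical optimization, trajectories are sampled off-policy from a fixed measure $\bar{\mathbb{P}}$ (a frozen copy of the model), and each is reweighted by the RN derivative $w(X) = \frac{d\mathbb{P}^*}{d\bar{\mathbb{P}}}(X)$, yielding the expectation
$$\mathcal{L}(\theta,\phi) = \mathbb{E}_{X\sim\bar{\mathbb{P}}}\Big[w(X)\Big(\mathcal{L}_{\mathrm{unmask}}(\theta;X) + \mathcal{L}_{\mathrm{ins}}(\theta;X) + \gamma\big(\mathcal{L}_{\mathrm{UQL}}(\phi;X) + \mathcal{L}_{\mathrm{IQL}}(\phi;X)\big)\Big)\Big],$$
where $\theta$ denotes the policy (unmasking and insertion) parameters, $\phi$ the quality-head parameters, and $\gamma$ balances policy and quality learning. Each component supervises one aspect of decoding: the unmasking and insertion losses are supervised on the ground-truth (final, fully unmasked) sequence $X_1$ of each sampled trajectory, whereas UQL and IQL train auxiliary heads to predict whether individual unmasking and insertion actions are correct.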
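Because the partition function is intractable, one practical option (an assumption here, not something the text above states) is to self-normalize the RN weights over a batch of off-policy trajectories, which cancels the unknown constant at the cost of a small bias. The sketch takes per-trajectory log-weights and precomputed per-trajectory loss values; `gamma` weights the two quality terms against the policy terms, and all names are hypothetical.

```python
import math
from typing import List, Sequence


def self_normalized_weights(log_w: Sequence[float]) -> List[float]:
    """Turn unnormalized log-weights into weights that sum to 1.

    Any constant shared by all entries (such as -log Z) cancels here.
    """
    m = max(log_w)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in log_w]
    total = sum(exps)
    return [e / total for e in exps]


def weighted_ajd_objective(
    log_w: Sequence[float],
    l_unmask: Sequence[float],
    l_ins: Sequence[float],
    l_uql: Sequence[float],
    l_iql: Sequence[float],
    gamma: float = 1.0,
) -> float:
    """Importance-weighted sum of per-trajectory losses (hypothetical sketch)."""
    weights = self_normalized_weights(log_w)
    return sum(
        w * (lu + li + gamma * (uq + iq))
        for w, lu, li, uq, iq in zip(weights, l_unmask, l_ins, l_uql, l_iql)
    )
```

Adding the same constant to every log-weight leaves the result unchanged, which is why the unknown $\log Z$ can be dropped when computing the weights.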
4. Joint Policy Optimization and Adaptive Inference
The AJD loss enables simultaneous optimization of the generative policy and the adaptive inference schedule:
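- Unmasking Quality ($q^{\mathrm{unm}}_\phi$): the predicted probability that a token produced by an unmasking jump is correct, trained via the binary cross-entropy $\mathcal{L}_{\mathrm{UQL}}$. At inference, tokens with low predicted quality $q^{\mathrm{unm}}_\phi$ are re-masked to limit error propagation.
- Insertion Quality ($q^{\mathrm{ins}}_\phi$): the predicted probability that a mask inserted into a gap belongs in the final sequence, trained via the binary cross-entropy $\mathcal{L}_{\mathrm{IQL}}$. Low-quality inserted masks are deleted so that the sequence length tracks the correct target length.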
- Adaptive Schedule: In each step of decoding: (1) parallel sampling of unmasking jumps followed by re-masking the lowest-quality tokens (justified by maximization of joint correctness); (2) parallel insertions with deletion of the lowest-quality gaps to target the correct length.
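The two pruning operations of one decoding step can be sketched as below. This is a hypothetical illustration: how many tokens to re-mask or gaps to delete (`k`), which positions are eligible, and the `"<m>"` mask symbol are assumptions, and the quality dictionaries stand in for the outputs of the learned heads.

```python
from typing import Dict, Iterable, List

MASK = "<m>"  # hypothetical mask symbol


def remask_lowest(tokens: List[str], quality: Dict[int, float],
                  candidates: Iterable[int], k: int) -> List[str]:
    """Re-mask the k candidate positions with the lowest unmasking quality."""
    worst = sorted(candidates, key=lambda i: quality[i])[:k]
    out = list(tokens)
    for i in worst:
        out[i] = MASK
    return out


def delete_lowest_gaps(tokens: List[str], gap_quality: Dict[int, float],
                       inserted: Iterable[int], k: int) -> List[str]:
    """Delete the k freshly inserted masks with the lowest insertion quality."""
    drop = set(sorted(inserted, key=lambda i: gap_quality[i])[:k])
    return [t for i, t in enumerate(tokens) if i not in drop]
```

Re-masked positions are simply resampled at a later step, so low-confidence tokens are revisited rather than kept.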
These mechanisms couple policy learning to quality-aware scheduling, allowing diffusion models to adaptively determine not only the next token/gap, but also which partial generations should be revised at each stage.
5. Theoretical Convergence Guarantees
Minimizing the AJD loss via the weighted cross-entropy is equivalent, up to a constant independent of $\theta$, to minimizing the KL divergence $\mathrm{KL}(\mathbb{P}^*\,\|\,\mathbb{P}^\theta)$ under off-policy sampling. The unique minimizer consists of joint unmasking and insertion rates $(u^*_t, \lambda^*_t)$ inducing a path measure equal to $\mathbb{P}^*$. Consequently, as the loss approaches its minimum, the model's marginal at $t = 1$ converges in distribution to the reward-tilted target $p^*_1$. This provides a direct probabilistic guarantee for reward-guided fine-tuning of discrete diffusion models (Tang et al., 11 Jun 2026).
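A finite toy example makes the target concrete. Restricting attention to terminal samples over a two-element set, the sketch computes $p^*_1(x) \propto p^{\mathrm{pre}}_1(x)\exp(r(x))$ and checks that $\mathrm{KL}(p^*_1\,\|\,q)$ vanishes at $q = p^*_1$ and is positive at the pre-trained marginal. It illustrates only the terminal marginal, not the path-space objective, and the numbers are made up for illustration.

```python
import math
from typing import List, Sequence


def tilt(p_pre: Sequence[float], rewards: Sequence[float]) -> List[float]:
    """Reward-tilted distribution p*(x) = p_pre(x) exp(r(x)) / Z."""
    unnorm = [p * math.exp(r) for p, r in zip(p_pre, rewards)]
    z = sum(unnorm)  # partition function Z
    return [u / z for u in unnorm]


def kl(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q) for discrete distributions with q > 0 wherever p > 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


p_pre = [0.5, 0.5]
rewards = [0.0, math.log(3.0)]  # exp(r) is 3x larger for the second sequence
p_star = tilt(p_pre, rewards)
print(p_star)                 # approximately [0.25, 0.75]
print(kl(p_star, p_star))     # 0.0
print(kl(p_star, p_pre) > 0)  # True
```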
6. Algorithmic Implementation
AJD loss is optimized using a replay-buffer-based approach with alternating updates to the policy parameters $\theta$ and the quality parameters $\phi$. The main process is summarized as follows:
| Phase | Operation | Frequency/Hints |
|---|---|---|
| Buffer generation | Adaptive any-length inference to generate a trajectory buffer $\mathcal{B}$ | every $K$ steps |
| Policy update | Compute RN weights and the unmasking/insertion losses | $E_\theta$ epochs, $\phi$ frozen |
| Quality update | Compute the UQL and IQL losses | $E_\phi$ epochs, $\theta$ frozen |
Key hyperparameters include the buffer size $|\mathcal{B}|$, the number of trajectory replicates, the reward scale, the alternation frequency, and separate learning rates for $\theta$ and $\phi$. Off-policy sampling from frozen models is used for stable collection of training trajectories, with the policy and quality heads updated in alternation to keep their gradient flows decoupled.
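The alternation in the table can be sketched as a training skeleton. The sampling and update routines are passed in as hypothetical callables (`sample_trajectory`, `policy_step`, `quality_step`), a fixed-capacity `collections.deque` stands in for the replay buffer, and reading "number of trajectory replicates" as the number of trajectories added per refresh is an assumption; each outer round corresponds to one buffer refresh.

```python
from collections import deque
from typing import Callable, Deque, List


def alternating_training(
    num_rounds: int,
    buffer_size: int,
    replicates: int,
    policy_epochs: int,
    quality_epochs: int,
    sample_trajectory: Callable[[], object],
    policy_step: Callable[[List[object]], None],
    quality_step: Callable[[List[object]], None],
) -> List[str]:
    """Run the alternating schedule and return a log of the phases executed."""
    buffer: Deque[object] = deque(maxlen=buffer_size)  # oldest entries drop out
    log: List[str] = []
    for _ in range(num_rounds):
        # Buffer generation with the current model, frozen during sampling.
        for _ in range(replicates):
            buffer.append(sample_trajectory())
        log.append("generate")
        batch = list(buffer)
        # Policy phase: policy_step is expected to update theta only.
        for _ in range(policy_epochs):
            policy_step(batch)
            log.append("policy")
        # Quality phase: quality_step is expected to update phi only.
        for _ in range(quality_epochs):
            quality_step(batch)
            log.append("quality")
    return log
```

Refreshing the buffer with the current model keeps the off-policy samples close to the model being trained, while the bounded buffer discards the oldest trajectories.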
7. Empirical Evaluation and Benefits
Adaptive Joint Decoding achieves notable improvements across several domains:
- Drug-like molecule design (SAFE strings): QED increases from 0.641 to 0.762, synthetic accessibility (SA) decreases from 3.40 to 2.87, and the overall composite metric (valid, unique, drug-like, synthesizable) improves from 44% to 71%. Generation remains any-length with validity/diversity comparable to SOTA fixed-length models.
- Therapeutic peptide multi-objective design: Simultaneous improvements are observed in binding affinity, solubility, and non-hemolysis, along with a validity increase from approximately 10% to 48%, with superior performance over fixed-length and RL-based methods.
- Language reasoning (GSM8K, code infill): GSM8K Pass@1 rises from 35.7% to 60.9% (128 steps), HumanEval-infill exact match rises from 44.1% to 49.4% at 128 steps and 57.1% at 1024 steps, demonstrating improved reward alignment and faster inference.
The AJD loss framework thus enables flexible, reward-driven generation in discrete diffusion, supporting robust sample quality and adaptive, data-driven inference procedures (Tang et al., 11 Jun 2026).