FUJI-LDDMs: Fully Joint Latent Discrete Diffusion Models
- FUJI-LDDMs are fully joint latent discrete diffusion models that integrate masked token updates with a continuous latent channel to capture global context.
- The model architecture interleaves masked discrete diffusion with Gaussian-scheduled latent diffusion, enabling robust joint denoising and efficient context propagation.
- Empirical results demonstrate that FUJI-LDDMs achieve lower perplexity and enhanced fluency, leading to improved performance in language and reasoning tasks.
FUJI-LDDMs (Fully Joint Latent Discrete Diffusion Models) are a class of discrete diffusion models that couple masked discrete diffusion over language or other categorical data with a continuous latent diffusion channel. They were developed to address a fundamental limitation of masked denoising diffusion models, namely that token-level updates factorize independently across sequence positions. The latent channel propagates joint contextual information across positions, improving sample fidelity and efficiency, especially in parallel or few-step generation settings (Shariatian et al., 20 Oct 2025, Jo et al., 22 Oct 2025).
1. Motivation and Background
Masked discrete diffusion models (MDLMs), as applied to language and categorical data, rely on a sequence of noising and denoising operations in which tokens are progressively masked and then reconstructed. In these models, reverse (denoising) transitions are typically performed independently at each sequence position, so complex dependencies among tokens (such as global syntax or semantic structure) are poorly captured, especially when many tokens are unmasked in a single step. The result is reduced coherence and degraded sample quality in non-autoregressive, parallel generation.
FUJI-LDDMs are formulated to address these shortcomings by introducing a continuous latent channel that is diffused in tandem with the masked discrete variables. The latent embeddings carry cross-token dependencies, resolve ambiguities between competing completions, and let the model maintain and propagate richer context across denoising steps.
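The failure mode can be made concrete with a toy two-position distribution (invented for illustration, not taken from the cited papers). If a single step unmasks both positions and samples each from its own marginal, it draws from the product of marginals rather than from the joint:

```python
from itertools import product

# Hypothetical joint distribution over two positions: only two coherent phrases exist.
joint = {("New", "York"): 0.5, ("San", "Francisco"): 0.5}

def marginal(joint, position):
    """Per-position marginal: all that an independent per-position update can use."""
    out = {}
    for tokens, p in joint.items():
        out[tokens[position]] = out.get(tokens[position], 0.0) + p
    return out

m0, m1 = marginal(joint, 0), marginal(joint, 1)

# Unmasking both positions in one step, each sampled from its own marginal,
# draws from the product of marginals instead of the joint distribution.
factorized = {(a, b): m0[a] * m1[b] for a, b in product(m0, m1)}

incoherent = sum(p for pair, p in factorized.items() if pair not in joint)
print(f"P(incoherent pair) = {incoherent:.2f}")
```

Half of the probability mass lands on pairs such as ("New", "Francisco"). In this toy case, conditioning both positions on a shared latent that encodes which phrase is intended makes the per-position distributions deterministic, so the product of conditionals matches the joint again.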
2. Generative Mechanism and Model Architecture
In FUJI-LDDMs, the state at each diffusion timestep $t$ is a pair $(x_t, z_t)$, where $x_t$ is the (potentially masked) discrete sequence and $z_t$ is a vector of continuous latent embeddings. The generative process consists of two interleaved Markov chains:
- Forward (noising) process: Both the discrete tokens and the latent are progressively corrupted.
- The data channel ($x_t$) follows a masked diffusion trajectory.
- The latent channel ($z_t$) is diffused according to a Gaussian schedule.
- Reverse (denoising) process: At each reverse timestep, the joint model (typically a Transformer with multi-modal projections and shared attention) predicts both:
- The distribution over unmasked (clean) tokens, $p_\theta(x_0 \mid x_t, z_t)$, conditioned on the current tokens $x_t$ and latent $z_t$.
- The denoised latent $\hat{z}_\theta(x_t, z_t)$, again conditioned jointly on both channels.
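The forward corruption of both channels can be sketched as follows. The schedules are assumptions chosen for illustration, not those of the cited papers: tokens are masked independently with probability t (a linear masking schedule over t in [0, 1]), and the latent follows a variance-preserving Gaussian path with a cosine schedule. `MASK = -1` is a hypothetical mask id.

```python
import math
import random

MASK = -1  # hypothetical id for the [MASK] token

def forward_noise(x0, z0, t, rng):
    """Corrupt a clean pair (x0, z0) to timestep t in [0, 1].

    Assumptions (illustrative, not taken from the source):
      - token channel: each position is kept with prob. alpha_t = 1 - t, else masked;
      - latent channel: variance-preserving Gaussian with a cosine schedule,
        z_t = sqrt(abar_t) * z0 + sqrt(1 - abar_t) * eps,  eps ~ N(0, I).
    """
    alpha_t = 1.0 - t
    x_t = [tok if rng.random() < alpha_t else MASK for tok in x0]

    abar_t = math.cos(0.5 * math.pi * t) ** 2
    z_t = [math.sqrt(abar_t) * z + math.sqrt(1.0 - abar_t) * rng.gauss(0.0, 1.0)
           for z in z0]
    return x_t, z_t

rng = random.Random(0)
x0 = [12, 7, 42, 3, 9]
z0 = [0.3, -1.2, 0.8, 0.0]
x_t, z_t = forward_noise(x0, z0, t=0.5, rng=rng)
print(x_t, [round(v, 2) for v in z_t])
```

At t = 0 the pair is returned unchanged; at t = 1 every token is masked and the latent is pure noise.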
Critically, the reverse transitions factorize across positions in each channel but are parameterized jointly: the model updates all tokens and the latent embedding in a fully coupled manner, sharing contextual information through attention layers. This is in contrast to alternative architectures where discrete and continuous channels are resolved sequentially or in isolation.
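One way shared attention can couple the channels is to append the latent vectors to the token sequence as extra positions, so every token attends to them. The sketch below assumes exactly that, with a single head and identity query/key/value projections; the function name and toy vectors are hypothetical, and the text above specifies only shared attention with multi-modal projections.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

def joint_self_attention(token_emb, latent_emb):
    """Single-head self-attention over the concatenation [tokens; latents].

    Simplifying assumption: queries, keys and values are the inputs themselves
    (identity projections), so only the cross-channel mixing is illustrated.
    Returns the updated token positions only.
    """
    seq = token_emb + latent_emb  # list concatenation: one shared attention context
    d = len(seq[0])
    out = []
    for q in token_emb:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in seq]
        w = softmax(scores)
        out.append([sum(wj * v[i] for wj, v in zip(w, seq)) for i in range(d)])
    return out

tokens = [[1.0, 0.0], [0.0, 1.0]]  # two token embeddings
latent_a = [[2.0, 0.0]]            # one latent vector
latent_b = [[0.0, 2.0]]            # a different latent vector

out_a = joint_self_attention(tokens, latent_a)
out_b = joint_self_attention(tokens, latent_b)
print(out_a)
print(out_b)
```

Both calls use identical token embeddings, yet the token outputs differ because every token's update also averages over the latent position; this is the route by which global context reaches each position's prediction.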
3. Mathematical Formulation and ELBO Objective
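The learning objective is based on a variational Evidence Lower Bound (ELBO) over the joint diffusion process. Written schematically as a loss to minimize (the negative ELBO), it combines a masked-token term and a latent term:
$$\mathcal{L}(\theta) = \mathbb{E}_{t,\,(x_t, z_t)}\Big[\lambda_x(t) \sum_{i:\,x_t^i = \mathbf{m}} -\log p_\theta\big(x_0^i \mid x_t, z_t\big) + \lambda_z(t)\,\big\lVert \hat{z}_\theta(x_t, z_t) - z_0 \big\rVert_2^2\Big]$$
where:
- $p_\theta(\cdot \mid x_t, z_t)$: predicted categorical distribution over tokens at time $t$ (the sum runs over positions holding the mask token $\mathbf{m}$),
- $\hat{z}_\theta(x_t, z_t)$: predicted denoised latent,
- $x_0$, $z_0$: ground truth sequence and latent embedding (from an encoder $E$),
- $\lambda_x(t)$, $\lambda_z(t)$: loss weighting coefficients derived from the ELBO/KL structure.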
The continuous latent term is a standard Gaussian (squared-error) reconstruction loss with timestep-specific weighting, while the discrete term arises from the masked diffusion dynamics; explicit KL and reconstruction formulas are given in Table 1 of (Shariatian et al., 20 Oct 2025).
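The two loss terms can be illustrated on a toy example. This is a minimal sketch, assuming the per-position predictions are already available as probability tables and treating the timestep weights `w_x` and `w_z` as given constants (hypothetical names; the exact ELBO-derived weights are in the cited paper). Only masked positions enter the cross-entropy, as in masked diffusion.

```python
import math

MASK = -1  # hypothetical [MASK] id

def joint_loss(x0, x_t, token_probs, z0, z_hat, w_x, w_z):
    """Schematic per-example loss: weighted masked cross-entropy + weighted latent MSE.

    token_probs[i] maps token id -> predicted probability at position i.
    Unmasked positions are skipped, since they are copied through unchanged.
    """
    ce = -sum(math.log(token_probs[i][x0[i]])
              for i in range(len(x0)) if x_t[i] == MASK)
    mse = sum((a - b) ** 2 for a, b in zip(z_hat, z0))
    return w_x * ce + w_z * mse

x0 = [5, 8, 2]
x_t = [5, MASK, MASK]
token_probs = [{5: 1.0}, {8: 0.5, 3: 0.5}, {2: 0.25, 7: 0.75}]
z0, z_hat = [1.0, 0.0], [0.5, 0.5]
loss = joint_loss(x0, x_t, token_probs, z0, z_hat, w_x=1.0, w_z=2.0)
print(round(loss, 4))
```

Here the discrete term is -log 0.5 - log 0.25 (about 2.079) and the latent term is 2.0 × 0.5 = 1.0.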
The reverse chain is initialized with the all-mask token sequence and a Gaussian latent with fixed variance, so generation starts from a state that carries no information about the data.
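Combining this initialization with the reverse process gives a few-step sampler, sketched below. Everything beyond the all-mask / Gaussian initialization is an assumption for illustration: `toy_denoiser` is a hypothetical stand-in for the joint network, and the rule that unmasks an equal share of the remaining positions per step, together with the deterministic DDIM-style latent update under a cosine schedule, are common diffusion-sampler choices rather than the cited papers' exact procedure.

```python
import math
import random

MASK = -1  # hypothetical [MASK] id

def abar(t):
    """Assumed cosine noise schedule for the latent channel (signal level at time t)."""
    return math.cos(0.5 * math.pi * t) ** 2

def toy_denoiser(x_t, z_t, t):
    """Hypothetical stand-in for the joint network.

    Returns per-position token distributions and a predicted clean latent.
    A real model would compute both from shared attention over (x_t, z_t).
    """
    vocab = [0, 1, 2]
    probs = [{v: 1.0 / len(vocab) for v in vocab} for _ in x_t]
    z0_hat = [math.tanh(v) for v in z_t]
    return probs, z0_hat

def sample(seq_len, latent_dim, num_steps, denoiser, rng, sigma=1.0):
    # Initialization: all-mask tokens and a Gaussian latent with fixed variance.
    x = [MASK] * seq_len
    z = [rng.gauss(0.0, sigma) for _ in range(latent_dim)]
    for step in range(num_steps, 0, -1):
        t, s = step / num_steps, (step - 1) / num_steps
        probs, z0_hat = denoiser(x, z, t)
        # Token channel: unmask an equal share of the remaining masked positions,
        # sampling each one independently from its predicted distribution.
        masked = [i for i, tok in enumerate(x) if tok == MASK]
        for i in rng.sample(masked, math.ceil(len(masked) / step)):
            tokens, weights = zip(*probs[i].items())
            x[i] = rng.choices(tokens, weights=weights)[0]
        # Latent channel: deterministic DDIM-style step from t to s using z0_hat.
        eps_hat = [(zi - math.sqrt(abar(t)) * z0) / math.sqrt(1.0 - abar(t))
                   for zi, z0 in zip(z, z0_hat)]
        z = [math.sqrt(abar(s)) * z0 + math.sqrt(1.0 - abar(s)) * e
             for z0, e in zip(z0_hat, eps_hat)]
    return x, z

x, z = sample(seq_len=8, latent_dim=4, num_steps=2, denoiser=toy_denoiser,
              rng=random.Random(0))
print(x, [round(v, 3) for v in z])
```

With `num_steps=1` every position is unmasked in a single step, which is the regime where per-position independence hurts most and where the shared latent is meant to help.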
4. Design Principles and Training Considerations
Fundamental to FUJI-LDDM performance are several design choices:
- Joint Self-Attention: A model backbone that processes the discrete input and the continuous latent via shared multi-head self-attention, enabling cross-modal information flow.
- Fixed Latent Variance and Normalization: To prevent pathologies in which the encoder "cheats" by inflating latent magnitude, latent vectors are normalized and the encoder variance is held fixed rather than learned.
- Two-Stage Curriculum: Early in training, the latent channel loss is downweighted, allowing the token channel to stabilize before the latent's influence is ramped up.
- Choice of Encoder: Frozen pre-trained encoders (e.g., Qwen3-Embedding) may enhance performance by providing fixed semantic representations, while learned encoders can adapt to specific datasets.
- Efficient Decoding: FUJI-LDDMs are particularly effective at low sampling budgets (i.e., when fewer denoising steps are taken), as the latent channel provides global guidance that counteracts the loss of context inherent in simultaneous token unmasking.
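The normalization, fixed-variance, and curriculum choices in the list above can be sketched as small helper functions. The concrete choices are assumptions: unit-L2 normalization, a reparameterized Gaussian encoder with a fixed std `sigma`, and a hold-then-linear-ramp weight schedule with hypothetical `warmup_steps` / `ramp_steps` parameters. The text states only that latents are normalized, the encoder variance is controlled, and the latent loss is downweighted early.

```python
import math
import random

def normalize_latent(z, eps=1e-8):
    """Rescale a latent vector to unit L2 norm (an assumed normalization choice)."""
    norm = math.sqrt(sum(v * v for v in z))
    return [v / (norm + eps) for v in z]

def sample_encoder_latent(mu, sigma, rng):
    """Reparameterized encoder sample with a fixed, non-learned std `sigma`.

    Because mu is normalized and sigma is fixed, the encoder in this sketch
    cannot inflate the latent's magnitude to make it survive forward noise.
    """
    mu = normalize_latent(mu)
    return [m + sigma * rng.gauss(0.0, 1.0) for m in mu]

def latent_loss_weight(step, warmup_steps, ramp_steps, w_min=0.0, w_max=1.0):
    """Two-stage curriculum: hold the latent-loss weight at w_min, then ramp linearly."""
    if step < warmup_steps:
        return w_min
    frac = min(1.0, (step - warmup_steps) / ramp_steps)
    return w_min + frac * (w_max - w_min)

for step in (0, 500, 1000, 1500, 3000):
    print(step, latent_loss_weight(step, warmup_steps=1000, ramp_steps=1000))
```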
5. Loopholing and Deterministic Latent Pathways
A closely related mechanism, described as "Loopholing" in (Jo et al., 22 Oct 2025), introduces an explicit deterministic latent pathway that bypasses the "sampling wall": the collapse of distributional information each time a categorical sample is drawn. In this framework:
- The forward pass at each timestep computes both the projected one-hot tokens and a deterministic latent $h_t$.
- At the next timestep, $h_t$ is normalized and added to the embedded tokens before the sequence is reprocessed, so the latent carries distributional context that would otherwise be lost upon sampling.
- A self-conditioning strategy is employed: at each training step, the model first predicts with a zero latent input, then runs a second pass conditioned on the first pass's latent (with gradients stopped). This allows training at randomly sampled timesteps without unrolling the full chain.
This mechanism ensures that context is preserved across denoising steps, mitigating idle steps and oscillatory behavior that afflict prior masked diffusion models.
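The loopholing update and the two-pass self-conditioning step can be sketched in plain Python. `toy_model` is a hypothetical stand-in for the denoiser, and the parameter-free LayerNorm is an assumed instance of the normalization mentioned above; the stop-gradient is only indicated in a comment, since plain Python has no autodiff.

```python
import math

def layer_norm(h, eps=1e-5):
    """Parameter-free LayerNorm over one position's latent vector."""
    mean = sum(h) / len(h)
    var = sum((v - mean) ** 2 for v in h) / len(h)
    return [(v - mean) / math.sqrt(var + eps) for v in h]

def loopholing_input(token_emb, h_prev):
    """Next-step input: token embedding plus the normalized latent from the previous step."""
    return [[e + n for e, n in zip(emb, layer_norm(h))]
            for emb, h in zip(token_emb, h_prev)]

def toy_model(inputs):
    """Hypothetical network returning (outputs, deterministic latent h) per position."""
    h = [[math.tanh(v) for v in x] for x in inputs]
    outputs = [[2.0 * v for v in x] for x in h]  # stands in for token logits
    return outputs, h

def self_conditioned_forward(token_emb, model):
    d = len(token_emb[0])
    zero_h = [[0.0] * d for _ in token_emb]
    # Pass 1: zero latent, so the input is just the token embeddings.
    _, h1 = model(loopholing_input(token_emb, zero_h))
    # In an autodiff framework, h1 would be detached here (stop-gradient).
    # Pass 2: condition on h1; the training loss would use these outputs only.
    return model(loopholing_input(token_emb, h1))

emb = [[0.5, -0.2, 0.1], [1.0, 0.3, -0.4]]
outputs, h2 = self_conditioned_forward(emb, toy_model)
print([[round(v, 3) for v in row] for row in outputs])
```

Only two forward passes are needed per training example regardless of the sampled timestep, which is why self-conditioning avoids unrolling the whole denoising chain.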
6. Experimental Performance and Empirical Findings
FUJI-LDDMs achieve consistent improvements across language modeling and reasoning benchmarks:
- Language Modeling (LM1B, OpenWebText): Lower validation and generative perplexity relative to masked discrete diffusion baselines (MDLM). At reduced sampling steps, FUJI-LDDMs maintain lower perplexity and competitive entropy, reflecting robust parallel sample quality (Shariatian et al., 20 Oct 2025).
- Reasoning Tasks (Countdown, Game of 24): In arithmetic reasoning, LDDMs with loopholing improve success rates (Countdown 4: 94.4% vs 86.5% for baseline; Game of 24: 63% vs 47%) (Jo et al., 22 Oct 2025).
- Human and GPT-4.1 Aligned Evaluation: Higher fluency and coherence scores, with substantial reductions in generative perplexity (Gen PPL) over both discrete diffusion and autoregressive baselines.
The improvements are attributed to fewer "idle" denoising steps and to more faithful propagation of semantic context via the deterministic latent channel.
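The reasoning results can also be restated as reductions in failure rate. The computation below uses only the success rates quoted in this section; the helper name is hypothetical.

```python
def improvement(ours, base):
    """Absolute gain (percentage points) and relative reduction in failure rate."""
    gain = ours - base
    return gain, gain / (100.0 - base)

# Success rates (%) quoted above: (with loopholing, baseline)
results = {"Countdown 4": (94.4, 86.5), "Game of 24": (63.0, 47.0)}

for task, (ours, base) in results.items():
    gain, err_cut = improvement(ours, base)
    print(f"{task}: +{gain:.1f} points, failure rate reduced by {err_cut:.1%}")
```

Expressed this way, loopholing removes roughly 58% of the baseline's Countdown 4 failures and about 30% of its Game of 24 failures.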
7. Applications and Implications
FUJI-LDDMs are broadly applicable to non-autoregressive generation tasks where joint structure and fast decoding are essential:
- Parallel Language Generation: For applications such as translation, summarization, or story generation, FUJI-LDDMs’ parallel, joint-denoising architecture enables coherent outputs in reduced step budgets.
- Categorical Data Modeling: Symbolic music, structured code, and other high-cardinality data types benefit from the model's capacity to model global joint dependencies and diverse outputs.
- Multi-Modal and Hybrid Generative Models: The fusion of discrete and continuous diffusion channels positions FUJI-LDDMs as a template for hybrid and multi-modal generative tasks, where one may desire semantics to guide token-level synthesis.
The approach is efficient and scalable, with only moderate training overhead from self-conditioning. It narrows the generative quality gap with autoregressive models (traditionally the benchmark for sequence fidelity) while retaining the practical benefits of parallel, non-autoregressive decoding.
Summary Table: Key Elements of FUJI-LDDMs
| Aspect | FUJI-LDDMs Approach | Benefit |
|---|---|---|
| Denoising | Fully joint over tokens & latents | Captures cross-token dependencies |
| Latent Path | Deterministic, propagated per step | Preserves contextual information post-sampling |
| Loss Function | ELBO with discrete + latent terms | Balanced, efficient training |
| Training | Two-stage, self-conditioning | Stability and efficiency |
| Performance | Lower perplexity, improved fluency | Effective at low sampling budgets |
By introducing fully joint discrete and latent denoising, and deterministic context propagation, FUJI-LDDMs enable high-fidelity, parallel generation in categorical domains, achieving strong empirical results and opening a pathway for further refinements in non-autoregressive generative modeling (Shariatian et al., 20 Oct 2025, Jo et al., 22 Oct 2025).