- The paper introduces MMaDA, a unified diffusion model that employs a shared transformer architecture with discrete tokenization for both text and image data.
- The paper leverages mixed long Chain-of-Thought finetuning to align reasoning across modalities, enhancing performance in textual understanding and generation.
- The paper integrates a novel reinforcement learning strategy, UniGRPO, with structured noising to achieve competitive results in multimodal reasoning and image generation tasks.
Here is a detailed summary of the paper "MMaDA: Multimodal Large Diffusion Language Models" (2505.15809):
The paper introduces MMaDA, a novel multimodal diffusion foundation model designed to unify and achieve strong performance across diverse tasks including textual reasoning, multimodal understanding, and text-to-image generation. Existing multimodal models often rely on autoregressive architectures or separate models per modality, and effective post-training recipes for diffusion-based models in non-autoregressive settings remain underexplored. MMaDA addresses these limitations through a unified diffusion architecture and novel post-training strategies.
MMaDA's approach is built upon three key innovations:
- Unified Diffusion Architecture and Objective: MMaDA employs a single diffusion-based transformer architecture with a shared probabilistic formulation. This modality-agnostic design eliminates the need for separate components for text and vision. To achieve this, it uses a consistent discrete tokenization strategy for both modalities: text is tokenized with the LLaDA tokenizer, and images are converted into discrete tokens by a pretrained image quantizer based on MAGVIT-v2 (downsampling factor of 16, codebook size of 8192). The model is pretrained using a unified mask-token prediction objective, optimizing a cross-entropy loss computed only on masked tokens across both image and text data. This trains the model to predict the original tokens $x_0^i$ at masked positions, given a noised version $x_t$ of the input sequence $x_0$ at noise level $t$.
The unified objective function is:
$\mathcal{L}_{\text{unify}}(\theta) = - \mathbb{E}_{t,\, x_0,\, x_t} \left[ \frac{1}{t} \sum_{i=1}^{L} \mathbb{I}[x_t^i = \text{[MASK]}] \log p_{\theta}(x_0^i \mid x_t) \right]$
where $\mathbb{I}[\cdot]$ is the indicator function selecting masked tokens.
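As a concrete illustration, here is a minimal PyTorch-style sketch of this objective, assuming a `model` that maps a token sequence to per-position logits over the shared vocabulary and a reserved `mask_id`; the names, the uniform sampling of $t$, and the batching are illustrative, not the paper's implementation.

```python
import torch
import torch.nn.functional as F

def unified_mask_loss(model, x0, mask_id, eps=1e-3):
    """Unified mask-token-prediction loss on a mixed text/image token sequence.

    x0: LongTensor [B, L]; e.g. a row could be [text ids | 32x32 image ids]
    for a 512x512 image at downsampling 16 (illustrative layout).
    """
    B, L = x0.shape
    # Sample a noise level t per sequence; t doubles as the masking probability.
    t = torch.rand(B, device=x0.device) * (1 - eps) + eps            # [B]
    masked = torch.rand(B, L, device=x0.device) < t.unsqueeze(1)     # [B, L]
    xt = torch.where(masked, torch.full_like(x0, mask_id), x0)
    logits = model(xt)                                               # [B, L, V]
    # Cross-entropy on masked positions only, weighted by 1/t as in L_unify.
    ce = F.cross_entropy(logits.transpose(1, 2), x0, reduction="none")
    return ((ce * masked).sum(dim=1) / t).mean()
```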
- Mixed Long Chain-of-Thought (CoT) Finetuning: To enable effective post-training, especially for reasoning-intensive tasks, MMaDA utilizes a mixed long CoT fine-tuning strategy. A unified CoT format is curated across textual and visual domains:
`|<special_token>|<reasoning_process>|<special_token>|<result>`
This format aligns reasoning processes across modalities, facilitating knowledge transfer. High-quality, long-form CoT samples are generated with open-source LLMs/VLMs and filtered by verifiers. During finetuning, the model is trained to reconstruct masked tokens in the result segment $r_t$, conditioned on the original prompt $p_0$ and the corrupted result, using the same mask-token prediction objective.
The objective for mixed-task finetuning is:
$\mathcal{L}_{\text{Mixed-SFT}}(\theta) = - \mathbb{E}_{t,\, p_0,\, r_0,\, r_t} \left[ \frac{1}{t} \sum_{i=1}^{L'} \mathbb{I}[r_t^i = \text{[MASK]}] \log p_{\theta}(r_0^i \mid p_0, r_t) \right]$
This stage serves as a crucial cold-start for the subsequent reinforcement learning stage.
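A matching sketch of the Mixed-SFT stage: the only change from pretraining is that noise is applied to the response tokens $r_0$ while the prompt $p_0$ stays clean. The `prompt_len` boundary marker is an assumption for illustration.

```python
import torch
import torch.nn.functional as F

def mixed_sft_loss(model, tokens, prompt_len, mask_id, eps=1e-3):
    """Mask-prediction loss on the response only; the prompt p0 is never noised.

    tokens: LongTensor [B, L] laid out as [prompt p0 | response r0] per row.
    """
    B, L = tokens.shape
    t = torch.rand(B, device=tokens.device) * (1 - eps) + eps
    masked = torch.rand(B, L, device=tokens.device) < t.unsqueeze(1)
    masked[:, :prompt_len] = False                      # keep the prompt visible
    rt = torch.where(masked, torch.full_like(tokens, mask_id), tokens)
    logits = model(rt)                                  # conditions on p0 and corrupted r_t
    ce = F.cross_entropy(logits.transpose(1, 2), tokens, reduction="none")
    return ((ce * masked).sum(dim=1) / t).mean()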
- Unified Reinforcement Learning (UniGRPO): The paper proposes UniGRPO, a policy-gradient-based RL algorithm specifically adapted for diffusion foundation models. It overcomes challenges of adapting AR-based GRPO (like DeepSeek-Math (2402.03300)) to diffusion models, such as local masking dependency and the absence of an autoregressive chain rule for sequence likelihoods. UniGRPO introduces a structured noising strategy during RL, sampling a random masking ratio uniformly for the response segment in each gradient step. This exposes the model to various denoising stages, leveraging the multi-step nature of diffusion. The sequence-level log-likelihood is approximated by averaging over masked tokens.
The UniGRPO objective integrates a clipped surrogate reward and KL regularization:
$\mathcal{J}_{\text{UniGRPO}}(\theta) = \mathbb{E}_{(q,a)\sim \mathcal{D},\; \{o_i\}_{i=1}^{G}\sim \pi_{\theta_{\text{old}}}(\cdot\mid q),\; \{p_i \in [0,1]\}_{i=1}^{G}} \Bigg[ \frac{1}{G}\sum_{i=1}^{G} \frac{1}{|o_i|}\sum_{t=1}^{|o_i|} \Big( \min\big( r_{i,t}^{\prime}(\theta)\, \hat{A}_{i,t},\; \text{clip}\big( r_{i,t}^{\prime}(\theta),\, 1-\varepsilon,\, 1+\varepsilon \big)\, \hat{A}_{i,t} \big) - \beta\, D_{\text{KL}}\big( \pi_{\theta}^{\prime} \,\|\, \pi_{\text{ref}}^{\prime} \big) \Big) \Bigg]$
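A hedged sketch of the two diffusion-specific ingredients: structured noising (a fresh masking ratio drawn uniformly each gradient step, with the question left unmasked) and a sequence log-likelihood approximated by averaging per-token log-probs over masked positions. The clipped GRPO ratio $r_{i,t}'(\theta)$ is then formed from these approximate likelihoods under $\pi_\theta$ and $\pi_{\theta_\text{old}}$; all names and shapes here are illustrative.

```python
import torch

def unigrpo_logprob(model, prompt, response, p, mask_id):
    """Approximate log-likelihood of a rollout under structured noising.

    prompt:   LongTensor [B, Lp] -- left unmasked (questions stay visible).
    response: LongTensor [B, Lr] -- masked i.i.d. with ratio p.
    p:        masking ratio, drawn uniformly from [0, 1] once per gradient step.
    """
    B, Lr = response.shape
    masked = torch.rand(B, Lr, device=response.device) < p
    noised = torch.where(masked, torch.full_like(response, mask_id), response)
    logits = model(torch.cat([prompt, noised], dim=1))[:, prompt.shape[1]:]
    logp = torch.log_softmax(logits, dim=-1)
    tok = logp.gather(-1, response.unsqueeze(-1)).squeeze(-1)        # [B, Lr]
    # Average over masked positions as the sequence-level likelihood proxy.
    return (tok * masked).sum(dim=1) / masked.sum(dim=1).clamp(min=1)
```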
UniGRPO uses diversified reward modeling tailored to each task (a toy sketch follows the list):
- Textual Reasoning: Composite reward (Correctness + Format).
- Multimodal Reasoning: Correctness, Format, and a scaled CLIP Reward for caption-based tasks.
- Text-to-Image Generation: Scaled CLIP Reward and ImageReward (human preference score).
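To make the task-dependent weighting concrete, a toy composite reward along these lines; the specific weights and the CLIP scaling factor are placeholders, not the paper's values.

```python
def reasoning_reward(answer, gold, well_formatted):
    """Textual reasoning: correctness plus a format bonus."""
    correct = 1.0 if answer == gold else 0.0   # verifier / exact-match stand-in
    fmt = 0.5 if well_formatted else 0.0       # e.g., CoT wrapped in special tokens
    return correct + fmt

def t2i_reward(clip_score, image_reward, clip_scale=0.01):
    """Text-to-image: scaled CLIP similarity plus an ImageReward-style score."""
    return clip_scale * clip_score + image_reward
```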
Implementation Details and Experiments:
MMaDA is initialized from LLaDA-8B-Instruct pretrained weights and Show-o image tokenizer weights. The training proceeds in three stages:
- Stage 1 (Foundational Pretraining): 200K steps on general text and multimodal data (RefinedWeb, ImageNet, diverse image-text datasets), followed by 400K steps with more diverse image-text pairs.
- Stage 2 (Instruction Tuning & CoT Finetuning): 50K steps using textual (Alpaca) and visual (LLaVA-1.5) instruction tuning data, combined with the curated Mixed Long-CoT reasoning data.
- Stage 3 (UniGRPO Training): 50K steps using Reinforcement Learning Data derived from mathematical and logical datasets.
Training was performed on 64 A100 (80GB) GPUs.
Flexible Sampling:
MMaDA supports different sampling strategies at inference time.
- Text Generation: Uses a semi-autoregressive denoising strategy (based on LLaDA), where the sequence is divided into blocks, and tokens within each block are denoised iteratively based on confidence. This yields more detailed text compared to fixed-length generation.
- Image Generation: Employs parallel non-autoregressive sampling with a low-confidence remasking strategy and a cosine noise schedule, consistent with MAGVIT-v2. Classifier-free guidance is applied, and the entire image token sequence is treated as a single block for parallel generation (both strategies are sketched below).
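A hedged sketch unifying the two strategies: denoise block by block with confidence-based remasking. Several small blocks give LLaDA-style semi-autoregressive text generation; setting the block size to the full sequence length recovers fully parallel image-style sampling. The schedule, names, and omission of classifier-free guidance are simplifications.

```python
import math
import torch

@torch.no_grad()
def blockwise_sample(model, prompt, num_tokens, block_size, steps_per_block, mask_id):
    """Blockwise confidence-based denoising; block_size == num_tokens is one block."""
    B, Lp = prompt.shape
    x = torch.full((B, num_tokens), mask_id, dtype=torch.long, device=prompt.device)
    for start in range(0, num_tokens, block_size):
        end = min(start + block_size, num_tokens)
        for s in range(steps_per_block):
            logits = model(torch.cat([prompt, x], dim=1))[:, Lp:]
            conf, pred = torch.softmax(logits, dim=-1).max(dim=-1)
            masked = x == mask_id
            masked[:, end:] = False                    # only fill the current block
            x = torch.where(masked, pred, x)
            # Cosine schedule: how many block tokens stay masked after this step.
            keep = int((end - start) * math.cos(math.pi / 2 * (s + 1) / steps_per_block))
            if keep > 0:
                # Never re-mask tokens committed in earlier steps of this block.
                blk = conf[:, start:end].masked_fill(~masked[:, start:end], float("inf"))
                idx = blk.topk(keep, dim=-1, largest=False).indices + start
                x.scatter_(1, idx, mask_id)            # re-mask lowest-confidence slots
    return x
```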
Experimental Results:
MMaDA-8B demonstrates strong generalization across tasks:
- Multimodal Understanding: Achieves comparable or superior performance to dedicated understanding models (LLaVA-v1.5, InstructBLIP, Qwen-VL-Chat) and consistently outperforms prior unified models (SEED-X, DreamLLM, Show-o, Janus) on benchmarks like POPE, VQAv2, GQA, MMMU, MMB, and SEED.
- Text-to-Image Generation: Outperforms both generation-only models (SDXL, LlamaGen) and unified models on CLIP Score, ImageReward, and various GenEval metrics, particularly excelling on world knowledge-aware generation (WISE Cultural benchmark).
- Textual Reasoning: Shows competitive performance with strong AR baselines (Qwen2-7B, LLaMA3-8B) on general LLM benchmarks (MMLU, ARC-C, TruthfulQA) and significantly outperforms LLaDA-8B on math benchmarks (GSM8K, MATH, GPQA), demonstrating the viability of diffusion models as general-purpose LLMs.
Analysis and Conclusion:
Ablation studies confirm the effectiveness of both Mixed Long-CoT finetuning (boosting reasoning) and UniGRPO (further improving understanding, reasoning, and generation metrics). Analysis of UniGRPO's design choices shows that the structured noising strategy and unmasked questions lead to better reward trends and training stability compared to alternative diffusion RL approaches like d1. The paper observes a clear synergy across the three tasks during joint training, where improvements in one area (e.g., textual reasoning) contribute to others (e.g., factual accuracy in image generation). Diffusion models like MMaDA also offer sampling efficiency advantages over AR models by enabling parallel generation steps. Furthermore, MMaDA naturally supports inpainting capabilities across text, visual Q&A, and images due to its mask token prediction objective, demonstrating flexibility and generalization.
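Because generation is just masked-token prediction, inpainting needs no special machinery: place `mask_id` at the positions to fill and denoise while the known context stays fixed. A minimal illustrative sketch, reusing the confidence-based remasking idea above (with a simple linear schedule in place of the paper's samplers):

```python
import torch

@torch.no_grad()
def inpaint(model, tokens, hole, mask_id, steps=8):
    """Regenerate only the positions flagged True in `hole`; the rest stay fixed.

    Works the same for a text span, a visual-Q&A answer, or an image region,
    since all are just token positions in the unified sequence.
    """
    x = torch.where(hole, torch.full_like(tokens, mask_id), tokens)
    for s in range(steps):
        logits = model(x)
        conf, pred = torch.softmax(logits, dim=-1).max(dim=-1)
        masked = x == mask_id
        x = torch.where(masked, pred, x)               # fill remaining hole slots
        # Linearly shrink the set of re-masked (lowest-confidence) hole positions.
        k = int(hole.float().sum(dim=1).min().item() * (1 - (s + 1) / steps))
        if k > 0:
            conf = conf.masked_fill(~masked, float("inf"))   # protect known tokens
            x.scatter_(1, conf.topk(k, dim=-1, largest=False).indices, mask_id)
    return x
```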
In conclusion, MMaDA is presented as a pioneering unified diffusion foundation model that successfully integrates multimodal understanding and generation with strong reasoning capabilities. The proposed post-training methods, Mixed Long-CoT finetuning and UniGRPO, are crucial for its performance, bridging the gap between pretraining and post-training for diffusion-based MLLMs. While the current model is 8B parameters, the results suggest the potential for further improvement with scaling.