This paper introduces a novel approach for autoregressive image generation that operates directly on continuous-valued tokens, eliminating the need for vector quantization (VQ). The core innovation is "Diffusion Loss," a method that models the per-token probability distribution using a denoising diffusion process. This allows autoregressive models, traditionally reliant on discrete token spaces and categorical cross-entropy loss, to leverage higher-quality continuous-valued tokenizers.
Key Contributions:
- Diffusion Loss: A new loss function that models the probability distribution for a continuous-valued token conditioned on a vector (produced by the autoregressive model). This is achieved by training a small denoising MLP jointly with the autoregressive model.
- Elimination of Vector Quantization: By using Diffusion Loss, the method sidesteps the challenges associated with training VQ tokenizers, such as gradient approximation issues and reconstruction quality limitations.
- Generalized Autoregressive Framework: The paper unifies standard autoregressive (AR) models and masked generative models into a Masked Autoregressive (MAR) framework. MAR predicts multiple tokens simultaneously in a randomized order, leveraging bidirectional attention for improved token communication.
- Strong Performance and Efficiency: The proposed MAR models with Diffusion Loss achieve state-of-the-art results on ImageNet 256x256, with a strong FID score (< 2.0) and fast generation speeds (< 0.3 seconds per image).
Rethinking Discrete-Valued Tokens
The paper argues that the critical aspect for autoregressive modeling is not the discrete nature of tokens, but the ability to:
- Define a loss function to measure the difference between estimated and true per-token distributions.
- Implement a sampler to draw samples from this distribution during inference.
While discrete tokens conveniently use categorical distributions and cross-entropy loss, Diffusion Loss provides an alternative for continuous tokens.
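For reference, the two requirements listed above (a loss and a sampler) can be written out for the familiar discrete case. This is a minimal sketch in plain Python rather than PyTorch, so it has no dependencies; the function names and the three-entry vocabulary are illustrative assumptions, not taken from the paper.

```python
import math
import random


def softmax(logits, temperature=1.0):
    """Convert logits to probabilities; temperature < 1 sharpens the distribution."""
    scaled = [value / temperature for value in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(value - peak) for value in scaled]
    total = sum(exps)
    return [value / total for value in exps]


def cross_entropy(logits, target_index):
    """Loss: negative log-probability assigned to the true token index."""
    return -math.log(softmax(logits)[target_index])


def sample_token(logits, temperature=1.0, rng=random):
    """Sampler: draw one token index from the tempered categorical distribution."""
    probs = softmax(logits, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


logits = [2.0, 0.5, -1.0]  # hypothetical scores over a 3-entry vocabulary
loss = cross_entropy(logits, target_index=0)
token = sample_token(logits, temperature=0.7, rng=random.Random(0))
```

Diffusion Loss fills the same two roles for continuous tokens: a training objective and a procedure for drawing samples.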
Diffusion Loss Mechanism
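For a continuous-valued ground-truth token $x \in \mathbb{R}^d$ and a conditioning vector $z \in \mathbb{R}^D$ from the autoregressive model, Diffusion Loss models $p(x \mid z)$ as follows:
- Loss Function: Based on denoising score matching, the loss is:
$$\mathcal{L}(z, x) = \mathbb{E}_{\varepsilon, t}\left[ \left\| \varepsilon - \varepsilon_\theta(x_t \mid t, z) \right\|^2 \right]$$
Where:
  - $\varepsilon \in \mathbb{R}^d$ is random noise sampled from $\mathcal{N}(0, I)$.
  - $t$ is a diffusion timestep.
  - $x_t = \sqrt{\bar{\alpha}_t}\, x + \sqrt{1 - \bar{\alpha}_t}\, \varepsilon$ is the noised token, where $\bar{\alpha}_t = \prod_{s \le t} \alpha_s$ is set by the noise schedule.
  - $\varepsilon_\theta$ is a small MLP (denoising network) that predicts the noise given $x_t$, $t$, and $z$. The gradient from this loss updates both the denoising MLP and the autoregressive model producing $z$. To improve loss utilization, $t$ is sampled multiple times (e.g., 4 times) for each $z$.
- Sampler: At inference, tokens are sampled using a reverse diffusion procedure:
$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}}\, \varepsilon_\theta(x_t \mid t, z) \right) + \sigma_t \delta$$
starting from $x_T \sim \mathcal{N}(0, I)$ to get $x_0 \sim p(x \mid z)$. Here $\delta \sim \mathcal{N}(0, I)$ and $\sigma_t$ is the noise level at timestep $t$.
- Temperature Sampling: To control sample diversity, a temperature $\tau$ is introduced by scaling the noise term $\sigma_t \delta$ in the sampler by $\tau$. This is analogous to temperature in categorical sampling.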
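The forward noising step, the reverse step, and temperature scaling described above can be sketched in plain Python. Several details here are assumptions for illustration only: the cosine schedule constants (`s=0.008`, betas clipped at 0.999), the choice of noise level as the square root of the step's beta, and `point_mass_oracle`, a toy stand-in for the trained denoising MLP that is exact only when the token distribution given z is a single point at z.

```python
import math
import random


def cosine_schedule(num_steps=1000, s=0.008, max_beta=0.999):
    """Return per-step betas and cumulative alpha-bars for a cosine schedule."""
    def f(u):
        return math.cos((u / num_steps + s) / (1 + s) * math.pi / 2) ** 2

    betas = [min(1 - f(t + 1) / f(t), max_beta) for t in range(num_steps)]
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1 - beta
        alpha_bars.append(running)
    return betas, alpha_bars


BETAS, ALPHA_BARS = cosine_schedule()


def q_sample(x0, t, eps):
    """Forward process: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps."""
    a = ALPHA_BARS[t]
    return [math.sqrt(a) * x + math.sqrt(1 - a) * e for x, e in zip(x0, eps)]


def p_sample(x_t, t, eps_pred, temperature=1.0, rng=random):
    """One reverse step; the added noise is scaled by the temperature."""
    beta, a_bar = BETAS[t], ALPHA_BARS[t]
    coef = beta / math.sqrt(1 - a_bar)
    mean = [(x - coef * e) / math.sqrt(1 - beta) for x, e in zip(x_t, eps_pred)]
    if t == 0:
        return mean  # no noise is added on the final step
    sigma = math.sqrt(beta)  # assumed choice of sigma_t
    return [m + temperature * sigma * rng.gauss(0.0, 1.0) for m in mean]


def sample_token(eps_model, z, dim, temperature=1.0, rng=random):
    """Run the full reverse chain from Gaussian noise down to a token."""
    x = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    for t in reversed(range(len(BETAS))):
        x = p_sample(x, t, eps_model(x, t, z), temperature, rng)
    return x


def point_mass_oracle(x_t, t, z):
    """Toy noise predictor: exact when p(x | z) is a point mass at z."""
    a = ALPHA_BARS[t]
    return [(x - math.sqrt(a) * m) / math.sqrt(1 - a) for x, m in zip(x_t, z)]


z = [0.5, -1.0]
token = sample_token(point_mass_oracle, z, dim=2, temperature=0.9,
                     rng=random.Random(0))
```

With this oracle the chain ends at `z` up to floating-point error, whatever the starting noise, which makes it a convenient sanity check before swapping in a learned network.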
Autoregressive Models with Diffusion Loss
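The standard autoregressive formulation $p(x^1, \dots, x^n) = \prod_{i=1}^{n} p(x^i \mid x^1, \dots, x^{i-1})$ is adapted.
- The autoregressive network (e.g., a Transformer) produces a conditioning vector $z^i = f(x^1, \dots, x^{i-1})$.
- Diffusion Loss is applied to model $p(x^i \mid z^i)$.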
Unifying Autoregressive and Masked Generative Models (MAR)
The paper proposes a Masked Autoregressive (MAR) model that generalizes AR and masked generation:
- Bidirectional Attention for Autoregression: Unlike traditional causal attention, bidirectional attention (similar to MAE) can be used. Known tokens attend to each other, and unknown tokens attend to all known tokens. Loss is computed only on unknown tokens.
- Random Order Autoregression: The model processes randomly permuted sequences, with positional embeddings informing the decoder about the original positions to predict.
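- Masked Autoregressive (MAR): This model predicts a set of tokens $X^k$ at each step $k$, based on previously known/generated sets $X^1, \dots, X^{k-1}$, so that $p(X^1, \dots, X^K) = \prod_{k=1}^{K} p(X^k \mid X^1, \dots, X^{k-1})$.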
MAR uses a fully randomized order for token prediction, unlike MaskGIT/MAGE which use confidence-based ordering.
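To make the set-by-set factorization concrete, the sketch below plans which token positions would be predicted at each step under a fully random order. It assumes 256 tokens (a 16x16 grid), 64 steps, and a cosine rule for how many tokens stay masked after each step; the function name and the rounding and clamping rules are illustrative choices, not the paper's code.

```python
import math
import random


def mar_generation_plan(num_tokens=256, num_steps=64, rng=random):
    """Split a random permutation of token positions into per-step sets."""
    if num_tokens < num_steps:
        raise ValueError("need at least one token per step")
    order = list(range(num_tokens))
    rng.shuffle(order)  # fully random order, not confidence-based

    plan, known = [], 0
    for step in range(num_steps):
        # Fraction of tokens that should still be masked after this step.
        ratio = math.cos(math.pi / 2 * (step + 1) / num_steps)
        still_masked = math.floor(num_tokens * ratio)
        remaining = num_tokens - known
        steps_left = num_steps - step - 1
        # Predict at least one token now, and leave at least one per later step.
        still_masked = max(steps_left, min(still_masked, remaining - 1))
        count = remaining - still_masked
        plan.append(order[known:known + count])
        known += count
    return plan


plan = mar_generation_plan(rng=random.Random(0))
tokens_per_step = [len(step_set) for step_set in plan]
```

Under this schedule the early steps reveal only a few tokens and the later steps reveal progressively more, while every position is predicted exactly once.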
Implementation Details
- Tokenizer: Publicly available continuous-valued tokenizers from LDM (e.g., KL-16, which uses KL divergence regularization) are primarily used. VQ-16 (a VQ-GAN) is used for comparison. The method can also adapt to tokenizers with mismatched strides by grouping tokens.
- Diffusion Process: Follows DDPM/IDDPM, with a cosine noise schedule (1000 steps at training, resampled to e.g., 100 at inference). The denoising network predicts the noise $\varepsilon$. An optional $\mathcal{L}_{\text{vlb}}$ term can be included. Classifier-Free Guidance (CFG) is supported.
- Denoising MLP: A small MLP with residual blocks (e.g., 3 blocks, 1024 channels wide). The conditioning vector $z$ is added to the time embedding of $t$ and used in AdaLN layers of the MLP.
- Transformer Architecture: Based on ViT. A default "Large" (L) size has 32 blocks, 1024 width (~400M parameters).
- AR Baseline: GPT-style causal attention with triangular masking and kv-caching.
- MAR Model:
  - Training: Randomly masks tokens (e.g., 70-100% unknown). Pads with 64 [cls] tokens for encoder stability. Encoder and decoder have equal blocks (e.g., 16 each for MAR-L).
  - Inference: Progressively reduces masking ratio from 1.0 to 0 over a cosine schedule (e.g., 64 steps). Temperature sampling is applied.
- Training: AdamW optimizer, batch size 2048, LR 8e-4 with 100-epoch warmup, then constant LR for Diffusion Loss models. EMA of model parameters is maintained.
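The optimizer settings in the last bullet can be sketched as two small helpers. The linear shape of the warmup and the EMA decay of 0.9999 are assumptions for illustration, since the summary above states only a 100-epoch warmup to 8e-4 followed by a constant rate and that an EMA is kept; parameters are shown as plain lists of floats rather than tensors.

```python
def learning_rate(epoch, base_lr=8e-4, warmup_epochs=100):
    """Assumed linear warmup to base_lr, then a constant learning rate."""
    if epoch < warmup_epochs:
        return base_lr * epoch / warmup_epochs
    return base_lr


def ema_update(ema_params, params, decay=0.9999):
    """Move each EMA parameter a small step toward the current parameter."""
    return [decay * e + (1 - decay) * p for e, p in zip(ema_params, params)]


# Fractional epochs are allowed, so the rate can be updated per iteration.
schedule_preview = [learning_rate(e) for e in (0, 50, 100, 399)]
ema = ema_update([1.0, 2.0], [0.0, 0.0])
```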
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DiffusionLoss(nn.Module):
    """Per-token diffusion loss and sampler.

    SimpleMLP and GaussianDiffusion are placeholders for a denoising MLP and a
    DDPM-style diffusion helper; they are assumed to be defined elsewhere.
    """

    def __init__(self, token_dim, denoising_mlp_depth, denoising_mlp_width):
        super().__init__()
        # Token dimension, needed to draw the initial noise at sampling time.
        self.token_dim = token_dim
        # Predicts noise given x_t, the timestep, and the condition z.
        self.denoising_mlp = SimpleMLP(denoising_mlp_depth, denoising_mlp_width)
        # Handles noising (q_sample) and the reverse step (p_sample);
        # num_timesteps, the noise schedule, etc. are configured here.
        self.diffusion_process = GaussianDiffusion()

    def forward(self, z_condition, x_true_token):
        # z_condition: output of the main autoregressive model.
        # x_true_token: the ground-truth continuous token.
        batch_size = x_true_token.size(0)

        # 1. Sample random noise.
        noise = torch.randn_like(x_true_token)

        # 2. Sample random timesteps.
        timesteps = torch.randint(
            0, self.diffusion_process.num_timesteps, (batch_size,),
            device=x_true_token.device,
        )

        # 3. Create the noised token x_t.
        x_t = self.diffusion_process.q_sample(
            x_start=x_true_token, t=timesteps, noise=noise
        )

        # 4. Predict the noise with the denoising MLP, conditioned on z.
        predicted_noise = self.denoising_mlp(x_t, timesteps, z_condition)

        # 5. L2 loss between actual and predicted noise.
        #    Optional: add L_vlb for improved sample quality.
        # The gradient flows back through z_condition, so this loss also
        # trains the autoregressive model that produced it.
        return F.mse_loss(predicted_noise, noise)

    @torch.no_grad()
    def sample(self, z_condition, num_inference_steps, temperature=1.0):
        # Start from pure noise with the token's shape. The token dimension
        # generally differs from the width of z_condition, so the noise is
        # not created with randn_like(z_condition).
        batch_size = z_condition.size(0)
        device = z_condition.device
        x_t = torch.randn(batch_size, self.token_dim, device=device)

        # Respace timesteps if inference uses fewer steps than training.
        inference_timesteps = self.diffusion_process.resample_timesteps(
            num_inference_steps
        )

        for t in reversed(inference_timesteps):
            # One reverse diffusion step; p_sample calls the denoising MLP
            # and scales the injected noise by the temperature.
            t_batch = torch.full(
                (batch_size,), int(t), device=device, dtype=torch.long
            )
            x_t = self.diffusion_process.p_sample(
                denoising_fn=self.denoising_mlp,
                x=x_t,
                t=t_batch,
                condition=z_condition,
                temperature=temperature,
            )
        return x_t  # the sampled token x_0
```
Experimental Results
- Diffusion Loss vs. Cross-Entropy: Diffusion Loss consistently outperforms cross-entropy with discrete VQ-tokens across AR and MAR variants. For MAR, it reduces FID by roughly 50-60% in relative terms (e.g., Table 1: MAR default w/ CFG, CrossEnt FID 3.69 vs. DiffLoss FID 1.98).
- Flexibility with Tokenizers:
  - Works with VQ tokenizers by using the continuous latent before quantization (VQ-16 + DiffLoss FID 7.82 w/o CFG vs. VQ-16 + CrossEnt FID 8.79).
  - Continuous KL-16 (rFID 1.43) is significantly better than VQ-16 (rFID 5.87), leading to better generation (FID 3.50 vs 7.82).
  - Handles mismatched strides (e.g., KL-8 tokenizer with 32x32 tokens, grouped to 16x16 for the generator, achieves 2.05 FID).
- Denoising MLP: A small MLP (21M params for 1024 width) is effective and adds little computational overhead (~5% params, ~10% inference time).
- Diffusion Sampling Steps: 100 diffusion steps at inference are generally sufficient for good quality.
- Temperature: Plays a crucial role in controlling diversity/fidelity, similar to discrete models.
- AR to MAR:
  - Random order AR > Raster order AR (FID 19.23 -> 13.07 w/o CFG).
  - Bidirectional attention MAR > Causal attention AR (FID 13.07 -> 3.43 w/o CFG).
  - Predicting multiple tokens per step in MAR (default 64 steps) offers a good speed-quality trade-off.
- Speed/Accuracy: MAR with Diffusion Loss shows a favorable trade-off compared to standard AR (even with kv-cache) and DiT models. Can generate an image in <0.3s with FID <2.0 on ImageNet 256x256 (MAR-L).
- System-Level Comparison (ImageNet 256x256, 800 epochs):
  - MAR-L (479M params): FID 1.78 (w/ CFG), IS 296.0.
  - MAR-H (943M params): FID 1.55 (w/ CFG), IS 303.7.
  - These results are competitive with or surpass leading diffusion and autoregressive models.
- ImageNet 512x512: MAR-L achieves FID 1.73 (w/ CFG), competitive with other SOTA models.
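The relative improvements behind the FID pairs quoted in this section can be checked with a few lines of arithmetic. The helper name and dictionary labels below are illustrative; the numbers are the ones listed above.

```python
def relative_reduction(baseline, improved):
    """Fractional drop in FID when moving from baseline to improved."""
    return (baseline - improved) / baseline


fid_pairs = {
    "CrossEnt -> DiffLoss, MAR default, w/ CFG": (3.69, 1.98),
    "VQ-16 + CrossEnt -> KL-16 + DiffLoss, w/o CFG": (8.79, 3.50),
    "Raster order AR -> random order AR, w/o CFG": (19.23, 13.07),
    "Causal AR -> bidirectional MAR, w/o CFG": (13.07, 3.43),
}

reductions = {
    label: relative_reduction(baseline, improved)
    for label, (baseline, improved) in fid_pairs.items()
}

for label, value in reductions.items():
    print(f"{label}: {value:.1%}")
```

These come out to about 46%, 60%, 32%, and 74% respectively, so the quoted with-CFG example sits slightly below the 50-60% range while the without-CFG comparison sits at its upper end.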
Practical Implications and Implementation Considerations
- Training Cost: Training a MAR-L model for 400 epochs took ~2.6 days on 16x8 V100 GPUs, reportedly faster than training DiT-XL/2 (4.6 days) or LDM-4 (9.5 days) for the same epochs on the same cluster.
- Inference Speed: MAR models are fast due to parallel token prediction. The diffusion process per token is handled by a small MLP, making it efficient.
- Tokenizer Choice: The quality of the continuous tokenizer is important. Using higher-fidelity continuous tokenizers (lower reconstruction FID) directly translates to better generative performance.
- CFG and Temperature: These are key hyperparameters to tune for optimal sample quality and diversity, similar to other generative models. With CFG, the autoregressive model is run twice per step, with and without the class condition, and the denoising MLP's noise predictions under the two resulting conditioning vectors are combined.
- Model Architecture:
  - The main autoregressive Transformer can be a standard ViT-like architecture.
  - The denoising MLP for Diffusion Loss is relatively small. Its size (depth, width) can be tuned; larger MLPs can improve quality slightly at a minor compute cost.
- Limitations:
  - Can still produce artifacts, common to many generative models.
  - Performance is dependent on the pre-trained tokenizer quality.
  - Primarily tested on ImageNet; further validation on diverse datasets is needed.
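As a small illustration of the CFG bullet above, the sketch below combines the denoising MLP's noise predictions under the conditional and unconditional conditioning vectors. It assumes the standard classifier-free guidance form, in which the unconditional prediction is pushed toward the conditional one by a guidance scale; the function name, the example vectors, and the scale value are hypothetical.

```python
def cfg_noise(eps_cond, eps_uncond, guidance_scale):
    """Guided noise: eps_uncond + scale * (eps_cond - eps_uncond), per element."""
    return [u + guidance_scale * (c - u) for c, u in zip(eps_cond, eps_uncond)]


eps_cond = [1.0, 2.0]    # hypothetical prediction given the class-conditioned z
eps_uncond = [0.0, 1.0]  # hypothetical prediction given the unconditional z
guided = cfg_noise(eps_cond, eps_uncond, guidance_scale=3.0)
```

A scale of 1 returns the conditional prediction unchanged and a scale of 0 returns the unconditional one; larger scales extrapolate past the conditional prediction, trading diversity for fidelity.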
Conclusion
The paper demonstrates that autoregressive image generation does not require discrete vector-quantized tokens. By modeling per-token distributions with a diffusion process (Diffusion Loss), models can operate on continuous token spaces, benefiting from better tokenizers and achieving strong generative performance with efficient inference. This opens avenues for applying autoregressive models in other continuous-valued domains.