Token Caching for Diffusion Transformer Acceleration
Introduction
Diffusion models have become a cornerstone of generative modeling, excelling at image, video, and text generation. More recently, Diffusion Transformers (DiTs) have emerged as an alternative to the U-Net backbones that dominate earlier diffusion models. Despite their strong generative capabilities, DiTs are expensive to run: inference requires many iterative denoising steps, and self-attention scales quadratically with the number of tokens. To address this bottleneck, the authors propose TokenCache, a post-training acceleration method for DiTs that caches intermediate tokens and selectively prunes redundant ones.
Methodology
TokenCache is designed to minimize redundant computations among tokens across inference steps in DiTs. It tackles three key questions:
- Which tokens should be pruned to eliminate redundancy?
- Which blocks should be targeted for efficient pruning?
- At which time steps should caching be applied to balance speed and quality?
To answer these questions, TokenCache introduces a Cache Predictor that assigns an importance score to each token, enabling selective pruning with minimal loss in generation quality. It also employs an adaptive block selection strategy that concentrates pruning on the blocks with the least impact on the network's output, and a Two-Phase Round-Robin (TPRR) scheduling policy that sets caching intervals across the denoising process to balance acceleration against quality loss.
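The three questions interact multiplicatively: pruning saves compute only for the tokens pruned, in the blocks selected, at the timesteps where caching is active. The back-of-the-envelope model below is our own simplification, not a formula from the paper. It assumes (hypothetically) that every block costs the same, that block cost is linear in the number of recomputed tokens, that the quadratic attention term can be ignored, and that the Cache Predictor's overhead is zero; the function names are illustrative.

```python
def estimated_compute_fraction(token_keep_ratio: float,
                               pruned_block_fraction: float,
                               cached_step_fraction: float) -> float:
    """Idealized fraction of full-inference compute that remains.

    Hypothetical first-order model: equal per-block cost, cost linear in the
    number of recomputed tokens, attention's quadratic growth ignored, and
    Cache Predictor overhead treated as zero.
    """
    for name, value in (("token_keep_ratio", token_keep_ratio),
                        ("pruned_block_fraction", pruned_block_fraction),
                        ("cached_step_fraction", cached_step_fraction)):
        if not 0.0 <= value <= 1.0:
            raise ValueError(f"{name} must be in [0, 1], got {value}")
    # Share of (timestep, block) pairs in which token pruning is active.
    pruned_share = cached_step_fraction * pruned_block_fraction
    # In those pairs only `token_keep_ratio` of the tokens are recomputed.
    return 1.0 - pruned_share * (1.0 - token_keep_ratio)


def estimated_speedup(token_keep_ratio: float,
                      pruned_block_fraction: float,
                      cached_step_fraction: float) -> float:
    """Idealized speedup: the inverse of the remaining compute fraction."""
    fraction = estimated_compute_fraction(
        token_keep_ratio, pruned_block_fraction, cached_step_fraction)
    if fraction == 0.0:
        raise ValueError("model predicts zero remaining compute; speedup undefined")
    return 1.0 / fraction


# Recompute half the tokens, in half the blocks, at half the timesteps.
print(estimated_speedup(0.5, 0.5, 0.5))  # about 1.14
```

Even this optimistic model shows why all three design questions matter: if pruning is confined to few blocks or few timesteps, the achievable speedup stays modest no matter how many tokens each pruning step removes.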
Key Components
Cache Predictor:
The Cache Predictor is a small, learnable network that predicts the importance of each token. Tokens with low predicted importance are treated as redundant and pruned: instead of being recomputed in subsequent inference steps, they keep their cached features. During training, a superposition approach interpolates between the pruned and non-pruned states, allowing the predictor to learn which tokens can be pruned safely while preserving generative quality.
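A minimal sketch of the two modes described above, under loudly stated assumptions: the predictor here is a hypothetical linear layer followed by a sigmoid, the "block" is a per-token stand-in that ignores attention, and the training-time superposition is assumed to take the form `s * fresh + (1 - s) * cached`. None of these names or formulas come from the paper; they only illustrate hard top-k caching at inference versus a soft blend during training.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]
Block = Callable[[Vector], Vector]  # hypothetical per-token stand-in for a DiT block


def predict_importance(tokens: Sequence[Vector], weights: Vector, bias: float) -> List[float]:
    """Hypothetical Cache Predictor: sigmoid(w . x + b) per token, in (0, 1)."""
    scores = []
    for tok in tokens:
        z = sum(w * x for w, x in zip(weights, tok)) + bias
        scores.append(1.0 / (1.0 + math.exp(-z)))
    return scores


def cached_block_forward(tokens: Sequence[Vector], cache: Sequence[Vector],
                         scores: Sequence[float], keep: int, block: Block) -> List[Vector]:
    """Inference: recompute the `keep` highest-scoring tokens, reuse cached features for the rest."""
    order = sorted(range(len(tokens)), key=lambda i: scores[i], reverse=True)
    recompute = set(order[:keep])
    return [block(tokens[i]) if i in recompute else list(cache[i])
            for i in range(len(tokens))]


def superposed_block_forward(tokens: Sequence[Vector], cache: Sequence[Vector],
                             scores: Sequence[float], block: Block) -> List[Vector]:
    """Training: blend fresh and cached outputs, s * fresh + (1 - s) * cached (assumed form)."""
    out = []
    for tok, cached, s in zip(tokens, cache, scores):
        fresh = block(tok)
        out.append([s * f + (1.0 - s) * c for f, c in zip(fresh, cached)])
    return out


def double(t: Vector) -> Vector:
    """Toy block used only for the demo."""
    return [2.0 * x for x in t]


tokens = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
cache = [[9.0, 9.0], [8.0, 8.0], [7.0, 7.0]]  # features from the previous step
scores = predict_importance(tokens, weights=[1.0, -1.0], bias=0.0)
print(cached_block_forward(tokens, cache, scores, keep=2, block=double))
# [[2.0, 0.0], [8.0, 8.0], [4.0, 4.0]]
```

Token 1 receives the lowest score, so its cached features are passed through unchanged. The hard top-k selection is piecewise constant in the scores, so it gives the predictor no gradient; a soft blend like the one assumed here is differentiable in `s`, which is presumably why a superposition is used during training.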
Adaptive Block Selection:
Rather than applying token pruning uniformly across all blocks, TokenCache adaptively selects the least important blocks for pruning, based on the aggregated importance scores of their tokens. Because this selection is guided by the Cache Predictor's token-importance predictions, the most salient features are retained, improving computational efficiency with little effect on output fidelity.
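Block selection can be sketched as ranking blocks by an aggregate of their token scores and pruning the lowest-ranked ones. The text does not say how scores are aggregated; this sketch assumes a simple mean, and `select_blocks_to_prune` is a hypothetical helper rather than part of any released code.

```python
from statistics import mean
from typing import List, Sequence


def select_blocks_to_prune(block_token_scores: Sequence[Sequence[float]],
                           num_blocks: int) -> List[int]:
    """Pick the `num_blocks` blocks whose tokens have the lowest mean importance.

    `block_token_scores[b]` holds the (hypothetical) Cache Predictor scores of
    the tokens entering block b. Mean aggregation is an assumption; the text
    only says the scores are "aggregated".
    """
    if not 0 <= num_blocks <= len(block_token_scores):
        raise ValueError("num_blocks out of range")
    block_scores = [mean(scores) for scores in block_token_scores]
    # Least important blocks first; sorted() is stable, so ties keep the earlier block.
    ranked = sorted(range(len(block_scores)), key=lambda b: block_scores[b])
    return sorted(ranked[:num_blocks])


scores_per_block = [
    [0.90, 0.80],  # block 0: salient, stays fully computed
    [0.10, 0.20],  # block 1
    [0.50, 0.40],  # block 2
    [0.05, 0.30],  # block 3
]
print(select_blocks_to_prune(scores_per_block, num_blocks=2))  # [1, 3]
```

Blocks 1 and 3 have the lowest mean scores (0.15 and 0.175), so they are the ones where tokens would be pruned, while block 0 keeps full computation.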
Two-Phase Round-Robin (TPRR) Scheduling:
To decide at which timesteps token caching is most effective, TPRR divides the inference process into two phases. The first phase uses a larger cache interval to exploit the high correlation of tokens across steps early in the denoising process, while the second phase uses a smaller cache interval to preserve fine details in later steps. This two-phase schedule adapts caching to how token importance evolves over the course of generation.
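One way to read "two-phase round-robin" is sketched below: within each phase, every `interval`-th step runs a full forward pass that refreshes the cache, and the steps in between run with token caching. The phase boundary, the interval values, and the function name are hypothetical; step 0 is taken to be the first (noisiest) denoising step.

```python
from typing import List


def tprr_schedule(num_steps: int, phase_split: int,
                  interval_early: int, interval_late: int) -> List[bool]:
    """Return one flag per denoising step: True = full compute, False = token-cached.

    Hypothetical reading of Two-Phase Round-Robin: steps [0, phase_split) form
    the early phase and the rest form the late phase; within each phase, every
    `interval`-th step refreshes the cache with a full forward pass and the
    steps in between reuse it.
    """
    if not 0 <= phase_split <= num_steps:
        raise ValueError("phase_split must lie in [0, num_steps]")
    if interval_early < 1 or interval_late < 1:
        raise ValueError("intervals must be >= 1")
    schedule = []
    for step in range(num_steps):
        if step < phase_split:
            offset, interval = step, interval_early
        else:
            offset, interval = step - phase_split, interval_late
        schedule.append(offset % interval == 0)
    return schedule


# Larger interval early (3), smaller interval late (2).
flags = tprr_schedule(num_steps=10, phase_split=6, interval_early=3, interval_late=2)
print("".join("F" if full else "c" for full in flags))  # FccFccFcFc
```

Here `F` marks a full step and `c` a cached step: the early phase refreshes the cache every third step, the late phase every second step, so more full computations are spent where the text says details matter most.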
Experimental Results
The authors validate TokenCache through extensive experiments on the DiT and MDT architectures, evaluating FID, sFID, IS, precision, and recall. The results indicate that TokenCache consistently achieves a favorable trade-off between generation quality and inference speed. For instance, on MDT at 256×256 resolution, TokenCache achieves a 1.51× speedup with image-quality metrics comparable or superior to those of the full-inference baseline.
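To put the reported figure in perspective: if the 1.51× speedup is a ratio of inference times (an assumption; the text does not say whether it refers to wall-clock latency or FLOPs), it corresponds to cutting inference time by roughly a third:

```latex
S = \frac{T_{\text{full}}}{T_{\text{TokenCache}}} = 1.51
\quad\Rightarrow\quad
1 - \frac{T_{\text{TokenCache}}}{T_{\text{full}}} = 1 - \frac{1}{S} = 1 - \frac{1}{1.51} \approx 0.338
```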
Ablation Studies
The paper provides thorough ablation studies to validate the design choices in TokenCache. These include comparisons of different token pruning strategies, assessments of block selection methods, and evaluations of various timestep schedules. The findings show that:
- Grid-based token pruning via the Cache Predictor significantly outperforms both random and global learnable pruning strategies.
- Adaptive block selection is crucial for maintaining generative quality and consistently outperforms random block selection.
- The TPRR schedule, particularly when emphasizing the later phase of the diffusion process, provides a balanced approach to caching without compromising image fidelity.
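The first finding contrasts grid-based and global selection. The text does not define "grid-based"; the sketch below assumes it means ranking tokens within fixed spatial windows of the token grid rather than across the whole image, so that recomputed tokens stay spread over the image. This is our interpretation, not the paper's definition, and the helper names are hypothetical.

```python
from typing import List, Sequence, Set, Tuple

Coord = Tuple[int, int]


def _top(coords: List[Coord], scores: Sequence[Sequence[float]], k: int) -> List[Coord]:
    """Return the k coordinates with the highest score."""
    return sorted(coords, key=lambda rc: scores[rc[0]][rc[1]], reverse=True)[:k]


def global_keep(scores: Sequence[Sequence[float]], k: int) -> Set[Coord]:
    """Global selection: recompute the k highest-scoring tokens anywhere on the grid."""
    coords = [(r, c) for r in range(len(scores)) for c in range(len(scores[0]))]
    return set(_top(coords, scores, k))


def grid_keep(scores: Sequence[Sequence[float]], cell: int, per_cell: int) -> Set[Coord]:
    """Grid selection (assumed form): recompute the top `per_cell` tokens in every cell x cell window."""
    h, w = len(scores), len(scores[0])
    if h % cell or w % cell:
        raise ValueError("grid size must be divisible by cell")
    kept: Set[Coord] = set()
    for r0 in range(0, h, cell):
        for c0 in range(0, w, cell):
            window = [(r, c) for r in range(r0, r0 + cell) for c in range(c0, c0 + cell)]
            kept.update(_top(window, scores, per_cell))
    return kept


scores = [
    [0.9, 0.8, 0.1, 0.2],
    [0.7, 0.6, 0.3, 0.1],
    [0.2, 0.1, 0.4, 0.3],
    [0.1, 0.3, 0.2, 0.5],
]
print(sorted(global_keep(scores, k=4)))               # [(0, 0), (0, 1), (1, 0), (1, 1)]
print(sorted(grid_keep(scores, cell=2, per_cell=1)))  # [(0, 0), (1, 2), (3, 1), (3, 3)]
```

With the same budget of four recomputed tokens, global top-k concentrates all of them in the top-left corner, while the grid variant recomputes one token in every 2×2 window.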
Implications and Future Directions
TokenCache offers a practical way to accelerate DiTs and highlights the value of fine-grained, token-level caching in diffusion models. Beyond the immediate computational savings, it opens the door to more sophisticated caching mechanisms and adaptive inference techniques. Future research could explore fine-tuning the Cache Predictor with parameter-efficient techniques such as LoRA to improve its adaptability across generative tasks.
Conclusion
The introduction of TokenCache marks a significant step forward in optimizing diffusion transformers. By integrating token-level pruning, adaptive block selection, and dynamic timestep scheduling, TokenCache offers substantial computational savings while preserving high generative quality. This work not only lays the groundwork for more efficient diffusion models but also inspires new avenues for enhancing the scalability and applicability of generative AI systems.