Accelerating Diffusion Models with Parallel Sampling: Inference at Sub-Linear Time Complexity
(2405.15986v1)
Published 24 May 2024 in cs.LG, cs.DC, cs.NA, math.NA, and stat.ML
Abstract: Diffusion models have become a leading method for generative modeling of both image and scientific data. As these models are costly to train and evaluate, reducing the inference cost for diffusion models remains a major goal. Inspired by the recent empirical success in accelerating diffusion models via the parallel sampling technique (Shih et al., 2024), we propose to divide the sampling process into $\mathcal{O}(1)$ blocks with parallelizable Picard iterations within each block. Rigorous theoretical analysis reveals that our algorithm achieves $\widetilde{\mathcal{O}}(\mathrm{poly} \log d)$ overall time complexity, marking the first implementation with provable sub-linear complexity w.r.t. the data dimension $d$. Our analysis is based on a generalized version of Girsanov's theorem and is compatible with both the SDE and probability flow ODE implementations. Our results shed light on the potential of fast and efficient sampling of high-dimensional data on fast-evolving modern large-memory GPU clusters.
The paper presents novel parallel sampling algorithms using Picard iterations to achieve sub-linear, poly-logarithmic-in-d, inference time in diffusion models.
It details two methods—PIADM-SDE and PIADM-ODE—that leverage simultaneous NN evaluations to break the traditional sequential computation barrier.
The research reveals a trade-off between time and space complexity, requiring substantial memory to enable efficient high-dimensional parallel sampling.
This paper, "Accelerating Diffusion Models with Parallel Sampling: Inference at Sub-Linear Time Complexity" (Chen et al., 24 May 2024), introduces novel algorithms, PIADM-SDE and PIADM-ODE, designed to significantly reduce the inference time of diffusion models for high-dimensional data. The core problem addressed is the computational expense of generating samples from trained diffusion models, which traditionally requires a large number of sequential evaluations of a neural network-based score function. Previous theoretical work established bounds on this complexity that scale polynomially with the data dimension d, such as O(d) for SDE-based samplers or O(√d) for ODE-based ones. This paper aims to break this polynomial barrier by leveraging parallel computation.
The key idea is to divide the total sampling process, spanning a time horizon T, into a small, O(1), number of blocks. Within each block, the evolution of the diffusion process is written in integral form as an SDE or ODE and approximated using Picard iterations. A crucial property of Picard iterations is that the (k+1)-th iterate depends only on the k-th iterate evaluated at all time points within the block. This allows the score-function evaluations and state updates for all discretized time steps within a block to be computed in parallel.
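As a toy illustration of why Picard iterations parallelize, the sketch below (not the paper's algorithm; a plain ODE x'(t) = f(x(t), t) with a left-Riemann quadrature) rebuilds the whole trajectory of one block from the previous iterate, so the drift evaluations inside each iteration are mutually independent:

```python
import numpy as np

# Minimal Picard-iteration sketch for x'(t) = f(x(t), t) on one block
# [0, T_blk], discretized at M+1 grid points. Each iteration uses ONLY the
# previous iterate, so the M+1 evaluations of f could run in parallel.
def picard_block(f, x0, T_blk, M, K):
    ts = np.linspace(0.0, T_blk, M + 1)
    dt = T_blk / M
    x = np.tile(x0, (M + 1, 1)).astype(float)          # iterate k = 0: constant guess
    for _ in range(K):                                 # K sequential Picard sweeps
        drift = np.array([f(x[m], ts[m]) for m in range(M + 1)])  # parallelizable
        # x^(k+1)(t_m) = x0 + sum_{j<m} f(x^(k)(t_j), t_j) * dt
        integral = np.vstack([np.zeros_like(x0, dtype=float),
                              np.cumsum(drift[:-1], axis=0) * dt])
        x = x0 + integral
    return x

# Example: x' = -x, x(0) = 1; the exact solution is exp(-t).
traj = picard_block(lambda x, t: -x, np.array([1.0]), 0.5, 64, 20)
print(float(traj[-1]))  # close to exp(-0.5) ≈ 0.6065
```

Note that the sequential unit of work here is one Picard sweep, not one time step: within a sweep, every grid point's drift can be dispatched simultaneously.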
The paper presents two main algorithms:
PIADM-SDE: This algorithm applies the parallel sampling strategy directly to the backward Stochastic Differential Equation (SDE) formulation of diffusion models.
Implementation: The time horizon is divided into N blocks, each further discretized into M steps. Within each block, K Picard iterations are performed: iteration k computes the state x^(k+1) at all M time steps of the block from the state x^(k) at the corresponding discrete time points, so the computation across the M steps can be done in parallel. The algorithm uses an exponential integrator for numerical stability and a step size that shrinks toward the data end to handle the singularity of the score function.
Theoretical Guarantee: Under standard assumptions on the data distribution and the learned score function, PIADM-SDE achieves an approximate time complexity of O(poly log d) to produce samples within a desired accuracy (measured in KL divergence). This is a significant improvement over the O(d) complexity of prior SDE-based methods.
Implementation Cost: The parallelization over M steps within each block requires storing intermediate states and performing NN evaluations for all M steps simultaneously. The analysis shows M needs to be O(d), leading to a space complexity of O(d^2).
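To make the parallel structure concrete, here is a heavily simplified sketch of one block update in the spirit of PIADM-SDE. It substitutes a plain Euler-Maruyama-style increment for the paper's exponential integrator and uses a toy score function; the point is only that iteration k+1 reads iteration k everywhere, so all M score evaluations can be issued as a single batched network call:

```python
import numpy as np

# Schematic single-block Picard update for a backward SDE (simplified stand-in
# for PIADM-SDE; the real algorithm uses an exponential integrator and
# shrinking step sizes). score_fn takes a batch of states and times at once.
def block_picard_sde(score_fn, x_init, ts, K, noise):
    """x_init: (d,) block entry state; ts: (M+1,) time grid; noise: (M, d)."""
    M = len(ts) - 1
    x = np.tile(x_init, (M + 1, 1)).astype(float)      # iterate k = 0
    for _ in range(K):
        s = score_fn(x[:-1], ts[:-1])                  # ONE batched call: (M, d)
        dts = np.diff(ts)[:, None]
        # OU-reversal drift x + 2*score, Euler-Maruyama increments
        incr = (x[:-1] + 2.0 * s) * dts + np.sqrt(2.0 * dts) * noise
        x_new = np.tile(x_init, (M + 1, 1)).astype(float)
        x_new[1:] = x_init + np.cumsum(incr, axis=0)   # prefix sum: a parallel scan
        x = x_new
    return x

# Toy run with a stand-in Gaussian score (the score of N(0, I) is -x):
rng = np.random.default_rng(0)
out = block_picard_sde(lambda x, t: -x, np.zeros(4), np.linspace(0, 1, 9),
                       K=8, noise=rng.normal(size=(8, 4)))
print(out.shape)  # (9, 4)
```

The memory cost in the summary above is visible here: all M intermediate states (and their score evaluations) must be held at once.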
PIADM-ODE: This algorithm applies the parallel sampling strategy to the Probability Flow Ordinary Differential Equation (ODE) formulation. This is a deterministic process, which can offer computational advantages.
Implementation: Similar to PIADM-SDE, the time horizon is blocked and Picard iterations are used within blocks (this constitutes the "predictor" step). A key difference is the inclusion of a "corrector" step after each predictor block. This corrector step uses parallelized Underdamped Langevin Monte Carlo (ULMC) dynamics for a fixed, short duration. The ULMC is also parallelized using Picard-like iterations. This predictor-corrector structure aims to improve the quality of the samples and handle potential issues with the deterministic flow.
Theoretical Guarantee: PIADM-ODE also achieves an approximate time complexity of O(poly log d) (measured in total variation distance).
Implementation Cost: The ODE formulation and the structure of the corrector step allow for a smaller discretization parameter M than in PIADM-SDE. The analysis shows that M (and the corresponding parameter in the corrector step) can be taken to be O(√d), reducing the space complexity to O(d^1.5), an improvement over PIADM-SDE's O(d^2) requirement.
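The corrector's main ingredient can be illustrated with a minimal, sequential underdamped Langevin (ULMC) sketch; the paper additionally parallelizes these steps with Picard-type iterations, which is omitted here. The Euler discretization and all parameter values below are illustrative assumptions, not the paper's scheme:

```python
import numpy as np

# Toy underdamped Langevin Monte Carlo corrector (sequential, Euler-Maruyama).
# The target's score function drives the position x via an auxiliary velocity v:
#   dx = v dt,   dv = (score(x) - gamma * v) dt + sqrt(2 * gamma) dB.
def ulmc_corrector(score_fn, x, steps, h, gamma=2.0, rng=None):
    rng = rng or np.random.default_rng()
    v = np.zeros_like(x)                         # auxiliary velocity variable
    for _ in range(steps):
        x = x + h * v
        v = (v + h * score_fn(x)                 # drift toward high density
               - h * gamma * v                   # friction
               + np.sqrt(2.0 * gamma * h) * rng.normal(size=v.shape))
    return x

# Toy check: with the N(0, 1) score s(x) = -x, long runs should equilibrate
# near unit variance (up to O(h) discretization bias).
rng = np.random.default_rng(1)
xs = ulmc_corrector(lambda x: -x, np.zeros(2000), steps=400, h=0.05, rng=rng)
print(round(float(np.var(xs)), 2))  # close to 1.0
```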
The "approximate time complexity" is defined in the paper as the number of unparallelizable evaluations of the neural network-based score function. In the parallel setting, this is the number of sequential Picard iterations across blocks (N×K), plus, in the ODE case, the number of sequential steps in the corrector (N†×K†), where N is the number of outer blocks and K, N†, K† are iteration depths/block counts. The analysis shows that choosing the number of blocks and the iteration depths appropriately yields poly-logarithmic complexity in d.
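A back-of-envelope calculation shows how this count stays small: with a constant number of blocks and logarithmic iteration depths (the specific constants below are invented purely for illustration), the sequential depth grows only logarithmically even as d spans many orders of magnitude:

```python
import math

# Illustrative sequential-depth count, NOT the paper's exact parameter choices:
# assume N and N_dag are O(1) and the Picard depths scale like log d.
def sequential_depth(d, N=4, N_dag=4):
    K = math.ceil(math.log(d))        # assumed predictor iteration depth
    K_dag = math.ceil(math.log(d))    # assumed corrector iteration depth
    return N * K + N_dag * K_dag      # unparallelizable score evaluations

for d in (10**3, 10**6, 10**9):
    print(d, sequential_depth(d))     # depth grows ~log d, not ~d
```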
The theoretical analysis relies on advanced concepts from stochastic calculus, including Girsanov's theorem, to carefully track the difference between the true backward process and the algorithm's output distribution under various metrics (KL divergence, Total Variation distance, 2-Wasserstein distance). The proofs involve bounding errors introduced by time discretization, the difference between the learned and true score functions, and the approximation introduced by the truncated Picard iterations. The exponential convergence property of Picard iterations (when the step size is small enough relative to Lipschitz constants) is key to bounding the error accumulated over iterations within a block.
Practical Implications:
The research provides a theoretical foundation for parallel sampling techniques, validating recent empirical successes like the ParaDiGMS algorithm.
The O(poly log d) time complexity suggests that, in theory, the sampling time of diffusion models can become almost independent of the data dimension, which is crucial for applications dealing with very high-dimensional data (e.g., high-resolution images, complex simulations).
The algorithms highlight a trade-off between time and space complexity. Achieving sub-linear time requires significant memory to perform computations in parallel across many steps/states simultaneously. This makes the algorithms particularly well-suited for modern GPU clusters with large memory capacities and high memory bandwidth.
The PIADM-ODE algorithm offers a better space complexity (O(d^1.5)) than PIADM-SDE (O(d^2)), a practical consideration for deploying these models.
Implementation Considerations:
Implementing these algorithms requires managing parallel execution of NN evaluations and state updates across many time points within a block. This necessitates efficient parallel programming techniques and potentially specialized hardware or libraries.
The memory requirement, especially for PIADM-SDE scaling as O(d^2), could be a bottleneck for extremely high dimensions, potentially requiring distributed computing strategies or careful memory management.
The constants hidden within the O notation are important for practical performance. The theoretical analysis provides insights into how parameters like block size, step size, and iteration depth should be chosen, but fine-tuning would likely be necessary for real-world deployment.
The assumptions on the smoothness and accuracy of the learned score function (Assumptions 3.1' and 3.3 in the paper) are theoretical requirements. Ensuring these hold in practice for complex data distributions and large neural networks is an ongoing challenge in diffusion model training.
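The memory trade-off discussed above can be sanity-checked with simple arithmetic: storing on the order of d intermediate states of dimension d in float32 (the O(d^2) regime of PIADM-SDE) already reaches tens of GiB at d = 10^5. The function below is an illustrative estimate, not a measurement, and ignores activations and optimizer state:

```python
# Back-of-envelope memory estimate (illustrative; float32, states only):
# PIADM-SDE stores roughly M ~ d states of dimension d, i.e. ~d^2 floats.
def block_memory_gib(d, bytes_per_float=4):
    return d * d * bytes_per_float / 2**30

for d in (10_000, 100_000):
    print(d, round(block_memory_gib(d), 1))
```

At d = 10^4 this is well under a GiB, while at d = 10^5 it is roughly 37 GiB, which is why the summary points to large-memory GPU clusters and to PIADM-ODE's smaller O(d^1.5) footprint.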
In conclusion, this paper offers a significant theoretical advancement by demonstrating how parallel sampling can achieve sub-linear inference time complexity for diffusion models, paving the way for potentially much faster sampling of high-dimensional data on capable hardware.