Einsum Networks: Scalable Probabilistic Circuits
- Einsum Networks are efficient probabilistic circuits restructured into layered, vectorized einsum operations to support scalable learning in high-dimensional domains.
- They replace numerous small sum/product operations with batched high-dimensional tensor contractions, yielding significant computational speedups and memory savings.
- EiNets leverage automatic differentiation for streamlined EM updates, achieving 10x–100x training and inference acceleration over traditional probabilistic circuit implementations.
Einsum Networks (EiNets) are an efficient architectural and algorithmic design for tractable probabilistic circuits (PCs), leveraging large-scale tensor contraction operations (“einsum”) to achieve significant computational and memory savings. Their core contribution lies in restructuring the inherently sparse computation graphs of PCs into layered, vectorized einsum operations optimized for modern hardware, thus enabling scalable learning and tractable inference in high-dimensional domains such as real-world image modeling (Peharz et al., 2020).
1. Architecture of Probabilistic Circuits and EiNet Layerization
Probabilistic circuits are directed acyclic graphs encoding a density function over random variables through a combination of sum nodes (representing convex mixtures) and product nodes (factoring densities across disjoint subsets), with leaves parameterizing tractable probability distributions (e.g., Gaussians or categoricals). Two critical constraints—decomposability (products act on disjoint variable subsets) and smoothness (sums act over children with identical variable scopes)—guarantee exact linear-time inference.
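These structural conditions can be made concrete with a toy circuit. The following is a minimal sketch (a hypothetical two-variable structure, not the paper's architecture) of a smooth, decomposable PC with categorical leaves, whose density can be evaluated and normalized in linear time:

```python
import numpy as np

# A minimal smooth, decomposable PC over two binary variables X1, X2
# (toy structure for illustration): one sum node mixing two product
# nodes, each product factoring over categorical leaves.
leaf_x1 = np.array([[0.8, 0.2],   # P(X1 | component 0)
                    [0.3, 0.7]])  # P(X1 | component 1)
leaf_x2 = np.array([[0.6, 0.4],
                    [0.1, 0.9]])
mix = np.array([0.5, 0.5])        # sum-node weights (convex mixture)

def density(x1, x2):
    # Product nodes factor over disjoint scopes {X1} and {X2} (decomposability);
    # the sum node mixes children with identical scope {X1, X2} (smoothness).
    return np.dot(mix, leaf_x1[:, x1] * leaf_x2[:, x2])

# Exact inference: the modeled distribution sums to 1 over all states.
total = sum(density(a, b) for a in (0, 1) for b in (0, 1))
```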
Conventional PC implementations generate highly sparse and branched computation graphs. Deep learning backends struggle with the overheads from launching myriad small sum and product kernels, particularly given the prevalence of log-domain arithmetic operations. EiNets mitigate this bottleneck by vectorizing every sum and leaf node with parallel components and flattening computational graph depth into layers—alternating between vectorized products and sums. The essence of the EiNet approach is to replace sequences of small sum/product operations with batched high-dimensional einsum contractions, allowing efficient utilization of GPU/CPU resources.
A topological layer sorting algorithm groups all leaf, sum, and product nodes into ordered alternating layers, enabling each sum layer to be computed via a single einsum tensor contraction.
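The payoff of this layerization can be sketched in numpy (shapes and variable names are assumptions, not the official EiNet API): a sum layer of `L` vectorized sum nodes, each mixing two `K`-dimensional child vectors, collapses into a single einsum call instead of `L * K` small kernel launches.

```python
import numpy as np

# Sketch of one sum layer after layerization (assumed shapes):
# L vectorized sum nodes, each contracting the outer product of two
# K-dimensional child vectors against a normalized weight tensor.
rng = np.random.default_rng(0)
L, K = 4, 3

W = rng.random((L, K, K, K))
W /= W.sum(axis=(2, 3), keepdims=True)   # normalize over child dimensions

left = rng.random((L, K))    # outputs of the left child layer
right = rng.random((L, K))   # outputs of the right child layer

# One batched contraction replaces many small sum/product operations:
# S[l, k] = sum_{i,j} W[l, k, i, j] * left[l, i] * right[l, j]
S = np.einsum('lkij,li,lj->lk', W, left, right)

# Reference: the explicit per-node loop a conventional PC backend
# would effectively launch as separate small kernels.
S_ref = np.stack([
    np.stack([(W[l, k] * np.outer(left[l], right[l])).sum() for k in range(K)])
    for l in range(L)
])
```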
2. Mathematical Foundations of the Einsum Operation
At its core, EiNet’s einsum operation reshapes node computations into high-dimensional tensor arithmetic:
- For a sum node of vector size $K$, with a product child aggregating two $K$-vector children $\mathbf{N}'$ and $\mathbf{N}''$, the basic computation (pre-log) is:

$$S_k = \sum_{i,j} W_{k,i,j}\, N'_i\, N''_j$$

Here, $W$ is a non-negative tensor with normalized sums over its last two dimensions, i.e. $\sum_{i,j} W_{k,i,j} = 1$ for each $k$.
- All evaluation is performed in the log-domain to preserve numerical stability. The output of the einsum layer applies the log-einsum-exp trick:

$$\log S_k = a' + a'' + \log \sum_{i,j} W_{k,i,j}\, \exp(\log N'_i - a')\, \exp(\log N''_j - a'')$$

where $a' = \max_i \log N'_i$ and $a'' = \max_j \log N''_j$.
- For an entire sum layer over $L$ sums, this generalizes to:

$$S_{l,k} = \sum_{i,j} W_{l,k,i,j}\, N'_{l,i}\, N''_{l,j}$$

with the subsequent log-einsum-exp applied row-wise.
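The stability argument can be demonstrated directly. The sketch below (assumed variable names) applies the log-einsum-exp trick to a single vectorized sum node whose children carry log-densities far below the range where `exp()` is representable; the naive linear-domain evaluation underflows to zero, while the shifted version stays finite:

```python
import numpy as np

# log-einsum-exp for one vectorized sum node (illustrative sketch).
rng = np.random.default_rng(1)
K = 4
W = rng.random((K, K, K))
W /= W.sum(axis=(1, 2), keepdims=True)

logN1 = rng.normal(size=K) - 800.0   # log N'  (deep in underflow territory)
logN2 = rng.normal(size=K) - 800.0   # log N''

def log_einsum_exp(W, logN1, logN2):
    a1, a2 = logN1.max(), logN2.max()          # per-child max shifts
    inner = np.einsum('kij,i,j->k', W, np.exp(logN1 - a1), np.exp(logN2 - a2))
    return a1 + a2 + np.log(inner)

stable = log_einsum_exp(W, logN1, logN2)

# Naive linear-domain evaluation: exp(-800) underflows to exactly 0.0.
naive = np.einsum('kij,i,j->k', W, np.exp(logN1), np.exp(logN2))
```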
Sums with more than two children are decomposed into chains of “simple sum” nodes with one child each, followed by an element-wise mixing sum—both realized as einsum contractions.
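The final element-wise mixing step of that decomposition is itself a single contraction. A minimal sketch, assuming `C` product children already evaluated into `K`-vectors and a hypothetical per-output mixing matrix:

```python
import numpy as np

# Element-wise mixing sum over C > 2 product children (assumed shapes):
# each of the K outputs mixes the matching component of every child.
rng = np.random.default_rng(2)
K, C = 3, 4
children = rng.random((C, K))          # outputs of C product children
mix = rng.random((K, C))
mix /= mix.sum(axis=1, keepdims=True)  # mixing weights, normalized per output

# One einsum realizes the mixing sum: S[k] = sum_c mix[k, c] * children[c, k]
S = np.einsum('kc,ck->k', mix, children)

# Reference loop for the same computation.
S_ref = np.stack([sum(mix[k, c] * children[c, k] for c in range(C))
                  for k in range(K)])
```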
3. Expectation-Maximization by Automatic Differentiation
Classical EM for PCs involves two phases:
- E-step: Compute expected sufficient statistics and latent counts; for a sum node $S$ with child $C$ and weight $W_{S,C}$:

$$n_{S,C} = \sum_{\mathbf{x} \in \mathcal{X}} \frac{\partial \log p(\mathbf{x})}{\partial \log S(\mathbf{x})} \cdot \frac{W_{S,C}\, C(\mathbf{x})}{S(\mathbf{x})}$$

- M-step: Update sum-node weights (and, analogously, leaf parameters via their expected sufficient statistics):

$$W_{S,C} \leftarrow \frac{n_{S,C}}{\sum_{C'} n_{S,C'}}$$
EiNets leverage automatic differentiation on the globally constructed log-likelihood: a single backward pass yields all necessary E-step statistics for sum and leaf nodes without custom gradient code. M-step updates are performed by accumulating expected statistics across mini-batches; with step size $\lambda \in (0, 1]$, stochastic online EM proceeds as:

$$n \leftarrow (1 - \lambda)\, n + \lambda\, n_{\text{batch}}, \qquad W_{S,C} \leftarrow \frac{n_{S,C}}{\sum_{C'} n_{S,C'}}$$

This stochastic protocol is equivalent to SGD with implicit natural-gradient updates over the joint distribution $p(\mathbf{X}, \mathbf{Z})$ of observed variables and latent mixture assignments.
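The update step itself is a few lines of array code. A minimal sketch, assuming the mini-batch statistics `n_batch` have already been obtained from the backward pass and using an interpolation form consistent with the step-size description above:

```python
import numpy as np

# Stochastic online EM update for one sum node's weight tensor (sketch;
# variable names and shapes are assumptions for illustration).
rng = np.random.default_rng(3)
K = 3
lam = 0.5                                  # EM step size

n = np.ones((K, K, K))                     # running expected counts
n_batch = rng.random((K, K, K)) * 10.0     # mini-batch E-step statistics

# Convex interpolation of sufficient statistics, then renormalization
# over the child dimensions to recover valid mixture weights.
n = (1.0 - lam) * n + lam * n_batch
W = n / n.sum(axis=(1, 2), keepdims=True)  # M-step: normalized weights
```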
4. Computational and Memory Complexity Analysis
EiNets optimize both computational and memory complexity:
- Per sum-node einsum: $2K$ exponentials and $K$ logarithms for the log-einsum-exp shifts, plus $O(K^3)$ multiply and add operations inside the contraction. Only $3K$ floats need storing ($\log \mathbf{N}'$, $\log \mathbf{N}''$, $\log \mathbf{S}$), with no explicit intermediate product allocation.
- Layer-wise einsum: for $L$ sum nodes, the total $O(L K^3)$ operations occur within one BLAS-style tensor contraction.
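The "no explicit product allocation" point can be made concrete (sizes and names are illustrative): the fused contraction never asks user code to materialize the $K \times K$ outer product of the child vectors, whereas node-by-node PC implementations effectively build it first.

```python
import numpy as np

# Fused vs. unfused evaluation of S_k = sum_{ij} W[k,i,j] * N1[i] * N2[j].
K = 64
rng = np.random.default_rng(4)
W = rng.random((K, K, K))
N1, N2 = rng.random(K), rng.random(K)

# Fused: one einsum call; no explicit outer-product array in user code.
S_fused = np.einsum('kij,i,j->k', W, N1, N2, optimize=True)

# Unfused: materializes a K x K product matrix first, as small-kernel
# node-by-node implementations effectively do.
prod = np.outer(N1, N2)                    # explicit K*K allocation
S_unfused = np.tensordot(W, prod, axes=([1, 2], [0, 1]))
```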
Contrasting prior implementations, which require additional storage and more small-kernel launches, EiNet’s reduction to large einsum operations achieves:
- Training speedups: one to two orders of magnitude (roughly $10\times$–$100\times$) faster than LibSPN/SPFlow across randomized binary tree benchmarks.
- Memory savings: $5\times$ or more reduction in peak usage for large $K$ by avoiding explicit product allocations.
- Inference speedups: acceleration of a comparable order of magnitude for forward passes per sample.
5. Empirical Performance on Image Modeling
EiNets enable tractable generative modeling on large-scale datasets previously inaccessible to PCs:
- SVHN ($32 \times 32$ RGB digits): 581,000 train, 23,000 validation, 26,000 test images. Data are clustered via $k$-means, with each cluster modeled by a partition-decomposition (PD) EiNet tree.
- CelebA (RGB face images): 183,000 train, 10,000 validation, 10,000 test images, modeled with a similar PD mixture approach.
For each cluster, EiNets use a shared vector size $K$ for sums and leaves, with Gaussian leaf parameters and the diagonal covariances clamped to a fixed admissible range. Stochastic EM is run for 25 epochs per cluster (mini-batch size 500, EM step size $0.5$), with total training time on the order of hours per dataset on an NVIDIA Tesla P100.
- Outputs: SVHN sample generations are sharp and plausible; CelebA generations yield recognizable albeit over-smoothed faces. Conditional reconstructions (inpainting) are visually coherent. Log-likelihood steadily improves with stable convergence, though GAN-based models remain qualitatively superior in sample fidelity.
6. Advantages, Limitations, and Extensions
Advantages:
- Achieve up to two orders-of-magnitude faster training and inference compared to prior PC libraries.
- Realize $5\times$ or greater memory savings for large vectorization sizes $K$.
- Automatic EM implementation via standard autodiff eliminates the need for custom backward code.
- Tractable inference is preserved for arbitrary marginals and conditionals, and scalability extends to full image datasets such as SVHN and CelebA, domains infeasible for legacy exact PC methods.
Limitations:
- Sample quality underperforms deep generative models (e.g., GANs, Flows).
- Partition-decomposition structures can introduce visible artifacts ("stripy" effects) in samples.
- All sums and leaves use a fixed global vector size $K$ (though heterogeneous sizes are plausible).
- Demonstrations restricted to axis-aligned partitions; richer decomposition strategies untested.
Possible Extensions:
- Structure learning with data-driven or graph-based partitions.
- Hybridizing with deterministic/selective nodes for MPE and constraint propagation.
- Incorporation of non-axis-aligned or adaptive splits for complex data.
- Extension to richer exponential family leaves, including mixed continuous-discrete families.
- Integration with modern generative pipelines (normalizing flows, autoregressive models).
- Enhanced GPU kernel exploitation; lower precision arithmetic.
- Theoretical expansion of log-einsum-exp for deterministic structured inference.
In summary, Einsum Networks reconstitute tractable probabilistic circuits as layered, monolithic tensor contraction engines that integrate seamlessly with deep learning hardware and autodiff frameworks, marking a decisive advance in the scalable learning of expressive, inference-tractable generative models (Peharz et al., 2020).