- The paper introduces SPNs as a novel deep probabilistic framework that achieves tractable, exact inference by representing complex distributions using sum and product nodes.
- The paper demonstrates an efficient two-pass algorithm for computing marginal probabilities and MPE, significantly reducing computational complexity compared to traditional graphical models.
- The paper validates SPNs through image completion experiments, showing that integrating hard EM and pruning techniques leads to faster training and improved accuracy.
This paper introduces Sum-Product Networks (SPNs), a novel deep probabilistic architecture designed to overcome the limitations of traditional graphical models, particularly the computational complexity of inference and learning associated with the partition function. SPNs offer a way to represent complex probability distributions where inference remains tractable.
Core Concepts of SPNs:
- Structure: An SPN is a rooted Directed Acyclic Graph (DAG). The leaves of the graph represent indicator variables for the states of the random variables (e.g., $x_i$ for $X_i = 1$ and $\bar{x}_i$ for $X_i = 0$ in the Boolean case). Internal nodes are either weighted sum nodes or product nodes.
- Evaluation: The value of a product node is the product of the values of its children. The value of a sum node $i$ is the weighted sum $\sum_{j \in Ch(i)} w_{ij} v_j$, where $Ch(i)$ are the children of $i$, $v_j$ is the value of child $j$, and $w_{ij} \ge 0$ is the weight of the edge from $i$ to $j$. The value of the SPN is the value computed at its root node.
- Network Polynomial: SPNs compute a polynomial function of the indicator variables. This polynomial, when evaluated under specific indicator settings, can yield probabilities or marginals.
- Validity: An SPN is considered valid if evaluating it with evidence $e$ (setting the corresponding indicators to 1 and the others to 0) directly yields the unnormalized probability $\Phi_S(e) = \sum_{x \sim e} S(x)$. The paper provides sufficient conditions for validity:
- Completeness: All children of the same sum node must have the same scope (i.e., involve the same set of variables).
- Consistency: A variable cannot appear negated as an indicator leaf input to one child of a product node and non-negated as input to another child of the same product node.
- Tractability: If an SPN is valid, the partition function $Z_S$ is simply the value of the SPN when all indicators are set to 1, i.e., $S(*)$. Computing $Z_S$ or any marginal probability $P(e) = S(e)/S(*)$ takes time linear in the size (number of edges) of the SPN; a minimal evaluation sketch appears after this list. Theorem 2 states that if a distribution is representable by a polynomial-sized valid SPN, its partition function is tractable.
- Decomposability: A stricter condition where the children of a product node have disjoint scopes. SPNs only require consistency, making them more general than models requiring decomposability (like arithmetic circuits, PCFGs, thin junction trees).
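As a minimal illustration of these definitions (not the paper's code; the node classes, weights, and two-variable structure are made up), the sketch below evaluates a small complete and decomposable SPN bottom-up, computes $Z_S = S(*)$ by setting all indicators to 1, and reads off a marginal:

```python
# Hypothetical node classes and made-up weights, for illustration only.

class Leaf:
    """Indicator leaf: the value of x_i (positive) or x̄_i (negated)."""
    def __init__(self, var, negated=False):
        self.var, self.negated = var, negated
    def value(self, indicators):
        pos, neg = indicators[self.var]   # (value of x_var, value of x̄_var)
        return neg if self.negated else pos

class Sum:
    def __init__(self, children, weights):
        self.children, self.weights = children, weights
    def value(self, indicators):
        return sum(w * c.value(indicators)
                   for w, c in zip(self.weights, self.children))

class Product:
    def __init__(self, children):
        self.children = children
    def value(self, indicators):
        v = 1.0
        for c in self.children:
            v *= c.value(indicators)
        return v

# Two Boolean variables X1, X2; the SPN below is complete and decomposable.
x1, nx1, x2, nx2 = Leaf(1), Leaf(1, True), Leaf(2), Leaf(2, True)
s1 = Sum([x1, nx1], [0.6, 0.4])   # scope {X1}
s2 = Sum([x1, nx1], [0.9, 0.1])   # scope {X1}
s3 = Sum([x2, nx2], [0.3, 0.7])   # scope {X2}
s4 = Sum([x2, nx2], [0.2, 0.8])   # scope {X2}
root = Sum([Product([s1, s3]), Product([s2, s4])], [0.5, 0.5])

# Partition function Z_S: evaluate with every indicator set to 1.
Z = root.value({1: (1, 1), 2: (1, 1)})   # = 1.0 here (weights are normalized)
# Unnormalized evidence value S(e) for evidence X1 = 1, X2 marginalized out.
Se = root.value({1: (1, 0), 2: (1, 1)})
print(Z, Se / Z)                          # P(X1 = 1) = 0.75
```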
SPNs vs. Other Models:
- Graphical Models: SPNs can represent some distributions (like uniform over even parity states) more compactly than traditional graphical models or mixture models. They naturally capture context-specific independence.
- Deep Architectures (DBNs, DBMs): Unlike DBNs/DBMs which typically rely on approximate inference (like Gibbs sampling), valid SPNs allow for exact and efficient inference. SPNs explicitly model sums (mixtures) and products (features), whereas DBNs/DBMs often focus on feature hierarchies and approximate the sums.
- Convolutional Networks: SPNs can be seen as a probabilistic generalization, with sum operations analogous to average-pooling and max operations (for MPE) analogous to max-pooling.
- Arithmetic Circuits/AND-OR Graphs: SPNs add model semantics and learning algorithms to these related inference compilation structures.
Inference in SPNs:
- Marginal Probabilities: Can be computed efficiently using a two-pass algorithm (similar to backpropagation). An upward pass computes the value $S_i(e)$ of each node; a downward pass computes the derivatives $\partial S(e) / \partial S_i(e)$. Marginals for indicator variables, $P(X_i = t \mid e)$, and for latent mixture variables, $P(Y_k = j \mid e)$, can be derived from these values. Time complexity is linear in SPN size.
- Most Probable Explanation (MPE): Can be computed by replacing the sums with maximizations in the upward pass and tracing back the maximizing choices in a downward pass (see the sketch after this list). This is exact for decomposable SPNs and extends to consistent SPNs.
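A hedged sketch of MPE inference along these lines, reusing the toy SPN from the earlier example; the tuple encoding of nodes and the un-memoized recursion are illustrative choices, not the paper's implementation:

```python
# Nodes are encoded as tuples: ('sum', [(w, child), ...]), ('prod', [children]),
# ('leaf', var, negated). This encoding is an assumption for the sketch.

def max_value(node, ind):
    """Upward pass with sums replaced by maximizations."""
    kind = node[0]
    if kind == 'leaf':
        _, var, negated = node
        pos, neg = ind[var]
        return neg if negated else pos
    if kind == 'prod':
        v = 1.0
        for child in node[1]:
            v *= max_value(child, ind)
        return v
    return max(w * max_value(child, ind) for w, child in node[1])

def backtrack(node, ind, assignment):
    """Downward pass: follow the maximizing child of every sum node."""
    kind = node[0]
    if kind == 'leaf':
        _, var, negated = node
        pos, neg = ind[var]
        if pos == neg == 1:                      # unobserved variable: fix its state
            assignment[var] = 0 if negated else 1
        return
    if kind == 'prod':
        for child in node[1]:
            backtrack(child, ind, assignment)
        return
    _, best = max(node[1], key=lambda wc: wc[0] * max_value(wc[1], ind))
    backtrack(best, ind, assignment)

# The same toy SPN as in the earlier sketch, in tuple form.
leaf = lambda var, negated=False: ('leaf', var, negated)
s1 = ('sum', [(0.6, leaf(1)), (0.4, leaf(1, True))])
s2 = ('sum', [(0.9, leaf(1)), (0.1, leaf(1, True))])
s3 = ('sum', [(0.3, leaf(2)), (0.7, leaf(2, True))])
s4 = ('sum', [(0.2, leaf(2)), (0.8, leaf(2, True))])
root = ('sum', [(0.5, ('prod', [s1, s3])), (0.5, ('prod', [s2, s4]))])

ind = {1: (1, 0), 2: (1, 1)}      # evidence X1 = 1, X2 unknown
mpe = {}
backtrack(root, ind, mpe)
print(max_value(root, ind), mpe)  # 0.36 {2: 0}: most likely completion sets X2 = 0
```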
Learning SPNs:
The paper proposes learning both structure and parameters, often starting with a dense, valid initial structure and then refining it.
- Structure Initialization (GenerateDenseSPN): Create an initial valid SPN. One strategy is to define nodes corresponding to subsets of the variables (e.g., rectangular regions in an image) and create sum/product nodes based on ways of partitioning these subsets. Random selection of subsets/partitions is also possible.
- Weight Learning (UpdateWeights):
- Gradient Descent: Use the efficient derivative computation (from marginal inference) to perform gradient ascent on the log-likelihood (equivalently, gradient descent on its negative). Requires projection or renormalization to keep the weights valid (summing to 1 for the children of each sum node, if normalized weights are desired). Prone to vanishing gradients in deep networks.
- EM (Soft EM): Treat sum nodes as latent variables. The E-step computes posterior probabilities of latent variables (marginals P(Yk=j∣e)) using the inference algorithm. The M-step updates weights based on expected counts. Also suffers from diffusion issues.
- Hard EM: Proposed as a solution to vanishing gradients/diffusion. Uses MPE inference instead of marginal inference in the E-step to find the single most likely configuration of the latent variables, updating counts only for the "winning" child of each sum node; the M-step normalizes the counts to obtain weights (a minimal update sketch follows this list). This allows much deeper SPNs to be learned effectively.
- Pruning: After learning, edges with zero weight are pruned, simplifying the network.
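A minimal sketch of the hard EM update, assuming a hypothetical `mpe_parse(spn, example)` helper that returns the winning child of each sum node reached in the MPE downward pass; the paper itself uses online, mini-batch updates, so this batch version only illustrates the count-and-normalize idea:

```python
# `mpe_parse`, `num_children`, and the `weights` attribute on sum nodes are
# assumptions for this sketch, not names from the paper.
from collections import defaultdict

def hard_em(spn, data, mpe_parse, num_children, epochs=10, smooth=1.0):
    counts = defaultdict(lambda: defaultdict(float))   # counts[sum_node][child]
    for _ in range(epochs):
        for example in data:
            # Hard E-step: MPE inference picks one winning child per sum node.
            for sum_node, winner in mpe_parse(spn, example).items():
                counts[sum_node][winner] += 1.0
        # M-step: normalize the accumulated counts (with add-one style smoothing).
        for sum_node, child_counts in counts.items():
            k = num_children(sum_node)
            total = sum(child_counts.values()) + smooth * k
            sum_node.weights = [(child_counts[j] + smooth) / total
                                for j in range(k)]
    return spn
```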
Implementation Considerations & Experiments:
- Architecture Example (Images): Use nodes for rectangular regions, decompose regions into subregions (potentially multi-resolution).
- Continuous Variables: Handled by replacing the leaf indicators with the outputs of probability density functions (e.g., univariate Gaussians or mixtures of them); sums over the values of a continuous variable become integrals. In practice, when a variable is observed as $X = x$ the leaf's value is $p(x)$, and when it is unobserved its value is 1 (see the Gaussian-leaf sketch after this list). The experiments used a Gaussian mixture per pixel.
- Learning Details: Used online hard EM with mini-batches, add-one smoothing, and an L0 prior (possible with hard EM) for sparsity.
- Task: Image completion on Caltech-101 and Olivetti faces (occluding half the image).
- Comparison: Compared against DBNs, DBMs, PCA, and Nearest Neighbor.
- Results: SPNs significantly outperformed alternatives in terms of Mean Squared Error (MSE) on the completion task. They were also substantially faster to train (hours vs. days/weeks for DBNs/DBMs) and perform inference (sub-second exact inference vs. slow approximate inference). SPNs required less hyperparameter tuning and data preprocessing. The learned SPNs were very deep (36 layers). SPNs also showed strong performance on preliminary classification tasks.
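For the continuous-variable handling mentioned above, a hedged sketch of a Gaussian leaf (class and argument names are assumptions; the experiments' per-pixel models are Gaussian mixtures, i.e., sum nodes over such leaves):

```python
# Hypothetical continuous leaf: density value when observed, 1 when summed out.
import math

class GaussianLeaf:
    def __init__(self, var, mean, std):
        self.var, self.mean, self.std = var, mean, std

    def value(self, evidence):
        """`evidence` maps variable -> observed value; missing = marginalized."""
        if self.var not in evidence:
            return 1.0                    # the density integrates to 1
        z = (evidence[self.var] - self.mean) / self.std
        return math.exp(-0.5 * z * z) / (self.std * math.sqrt(2.0 * math.pi))

# A per-pixel Gaussian mixture is then just a sum node whose children are
# GaussianLeaf nodes for the same pixel, with weights summing to 1.
```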
Conclusion:
SPNs provide a powerful and tractable deep probabilistic modeling framework. Their key advantage lies in enabling efficient exact inference, which in turn facilitates more effective and faster learning compared to contemporary deep generative models like DBNs and DBMs, especially for deep structures using the proposed hard EM algorithm. The experiments demonstrated significant practical advantages in terms of speed, accuracy, and ease of use for tasks like image completion.