Hierarchical Group-Gumbel-Max Decomposition
- The paper introduces a recursive extension of the classical Gumbel-Max trick to enable exact sampling of structured discrete objects.
- It employs group-wise decomposition and reweighting strategies to partition combinatorial domains, facilitating hierarchical subset selection.
- The method offers unbiased score-function gradient estimation with advanced variance reduction techniques for training deep architectures.
Hierarchical Group-Gumbel-Max Decomposition is a probabilistic and algorithmic scheme that generalizes the classical Gumbel-Max trick to efficiently sample structured discrete objects from complex domains via a recursive, group-wise selection procedure. It provides an efficient mechanism for exact sampling, unbiased score-function gradient estimation, and hierarchical subset selection, notably in structured latent variable models and geometric deep learning architectures (Struminsky et al., 2021, Yang et al., 2019).
1. Classical Gumbel-Max and Exponential-Min Trick
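The foundational mechanism is the Gumbel-Max (or equivalently, Exponential-Min) trick. Given a finite set of "keys" $K$ and associated positive rates $\lambda_k$, $k \in K$, independently sampling $E_k \sim \mathrm{Exp}(\lambda_k)$ and selecting
$$T = \arg\min_{k \in K} E_k$$
produces $T$ as a draw from the categorical distribution
$$P(T = k) = \frac{\lambda_k}{\sum_{j \in K} \lambda_j}.$$
Equivalently, for $G_k = -\log E_k$, each $G_k$ is distributed as $\mathrm{Gumbel}(\log \lambda_k)$, and
$$T = \arg\max_{k \in K} G_k.$$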
This establishes a randomized selection mechanism, pivotal for categorical sampling, and underlies subset and structure sampling in more complex domains (Struminsky et al., 2021).
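The equivalence can be checked numerically. The sketch below (Python standard library only; the helper names `exp_min_sample` and `gumbel_max_from_draws` and the example rates are illustrative assumptions) draws $E_k \sim \mathrm{Exp}(\lambda_k)$, takes the arg-min, confirms that the arg-max of $G_k = -\log E_k$ picks the same key, and compares empirical frequencies with $\lambda_k / \sum_j \lambda_j$.

```python
import math
import random
from collections import Counter


def exp_min_sample(rates, rng):
    """Draw E_k ~ Exp(lambda_k) independently; return (arg-min key, draws)."""
    draws = {k: rng.expovariate(lam) for k, lam in rates.items()}
    return min(draws, key=draws.get), draws


def gumbel_max_from_draws(draws):
    """Gumbel-Max view of the same sample: G_k = -log E_k ~ Gumbel(log lambda_k)."""
    scores = {k: -math.log(e) for k, e in draws.items()}
    return max(scores, key=scores.get)


rng = random.Random(0)
rates = {"a": 1.0, "b": 2.0, "c": 5.0}  # illustrative rates
n = 50_000
counts = Counter()
for _ in range(n):
    t, draws = exp_min_sample(rates, rng)
    # The arg-min of E and the arg-max of -log E always coincide.
    assert gumbel_max_from_draws(draws) == t
    counts[t] += 1

total = sum(rates.values())
empirical = {k: counts[k] / n for k in rates}
expected = {k: lam / total for k, lam in rates.items()}
print(empirical)
print(expected)
```

Using exponentials directly avoids drawing separate Gumbel noise; the Gumbel form is only a monotone transform of the same sample.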
2. Recursive Extension and Stochastic Invariant
Hierarchical Group-Gumbel-Max Decomposition generalizes the above mechanism by proceeding recursively over subdivided groups of variables. At each recursion step:
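- The current key set $K$ is partitioned into disjoint groups $K_1, \dots, K_m$ (so that $K = K_1 \cup \dots \cup K_m$).
- Within each group $K_i$, the arg-min $T_i = \arg\min_{k \in K_i} E_k$ is drawn, and the minimum $E_{T_i}$ is subtracted from each $E_k$ in $K_i$.
- The surviving keys and the auxiliary state are updated, and the process recurses on them.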
A crucial property, the "stochastic invariant," guarantees that, conditioned on the selection trace, the shifted exponentials passed to the next recursive call are again independent and exponentially distributed with their original rates $\lambda_k$, restricted to the surviving keys. This enables exact likelihood and gradient computations throughout the recursion (Struminsky et al., 2021).
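One group-wise step, and a recursion built from it, can be sketched as follows. This is an illustrative sketch, not the reference implementation: `group_step`, `recursive_trace`, and the pairing rule `pairs` are hypothetical names, and the grouping rule is an arbitrary choice. The empirical check illustrates the invariant for one key: given that key 0 wins the group $\{0, 1\}$, the shifted value $E_1 - E_0$ should again behave like $\mathrm{Exp}(\lambda_1)$ with mean $1/\lambda_1 = 1/3$.

```python
import random


def group_step(draws, groups):
    """One group-wise Exponential-Min step (illustrative sketch).

    draws:  dict mapping each current key k to its exponential value E_k
    groups: disjoint key lists that partition the current key set
    Returns the winners T_i (this step's trace) and the shifted values
    E_k - E_{T_i} for every non-winning key k of group i.
    """
    winners, shifted = [], {}
    for group in groups:
        t = min(group, key=draws.get)  # T_i = arg-min of E_k within the group
        m = draws[t]
        winners.append(t)
        for k in group:
            if k != t:
                # Memorylessness: given the trace, E_k - E_{T_i} ~ Exp(lambda_k) again.
                shifted[k] = draws[k] - m
    return winners, shifted


def recursive_trace(draws, partition):
    """Recurse until no keys survive; `partition` is a hypothetical grouping rule."""
    trace = []
    while draws:
        winners, draws = group_step(draws, partition(sorted(draws)))
        trace.append(winners)
    return trace


def pairs(keys):
    # Hypothetical rule: group consecutive sorted keys in pairs.
    return [keys[i:i + 2] for i in range(0, len(keys), 2)]


rng = random.Random(1)
rates = {0: 1.0, 1: 3.0, 2: 0.5, 3: 2.0}

# Empirical check of the invariant for key 1 in group {0, 1}.
n, wins0, shifted1 = 40_000, 0, []
for _ in range(n):
    draws = {k: rng.expovariate(lam) for k, lam in rates.items()}
    winners, shifted = group_step(draws, [[0, 1], [2, 3]])
    if winners[0] == 0:
        wins0 += 1
        shifted1.append(shifted[1])

p_win0 = wins0 / n                           # theory: 1 / (1 + 3) = 0.25
mean_shift1 = sum(shifted1) / len(shifted1)  # theory: 1 / 3, the mean of Exp(3)

draws = {k: rng.expovariate(lam) for k, lam in rates.items()}
trace = recursive_trace(draws, pairs)
print(p_win0, mean_shift1, trace)
```

Every step removes at least one key per group, so the recursion terminates and each key appears in the trace exactly once.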
3. Hierarchical, Group-wise Decomposition in Structured Spaces
The hierarchical scheme partitions the combinatorial domain by recursively grouping, selecting, and reweighting:
- The keys are partitioned into coarse groups $K_1, \dots, K_m$, and the "best" element (the arg-min of the exponentials) of each group is selected.
- The survivors $K'$ define a reduced subproblem at a finer scale.
- At each level, selection within each group is a group-wise arg-min (Exponential-Min), and the final structure $X$ is assembled from all selections by a combining function.
This architecture captures distributions over complex objects, such as top-$k$ subsets, permutations (Plackett–Luce), spanning trees and arborescences (Kruskal or Chu–Liu–Edmonds), and binary trees, with the property that the joint density of the selection process decomposes over recursion steps and groups (Struminsky et al., 2021, Yang et al., 2019).
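As a concrete instance, top-$k$ subset sampling can be sketched with a single group per recursion step. In this minimal sketch (the function `sample_top_k` and the order-forgetting combining function are illustrative assumptions), the ordered trace is a Plackett–Luce draw, so for rates $(1, 2, 3, 4)$ the trace $[d, c]$ should appear with probability $\tfrac{4}{10} \cdot \tfrac{3}{6} = 0.2$.

```python
import random


def sample_top_k(rates, k, rng):
    """Recursive Exponential-Min sketch of top-k sampling.

    Each step uses a single group (all surviving keys): take the arg-min,
    drop it, and subtract its value from the remaining exponentials.
    """
    draws = {key: rng.expovariate(lam) for key, lam in rates.items()}
    original = dict(draws)
    trace = []
    for _ in range(k):
        t = min(draws, key=draws.get)
        m = draws.pop(t)
        trace.append(t)
        draws = {key: e - m for key, e in draws.items()}
    # Combining function (an illustrative choice): forget the order.
    return trace, frozenset(trace), original


rng = random.Random(2)
rates = {"a": 1.0, "b": 2.0, "c": 3.0, "d": 4.0}

trace, subset, original = sample_top_k(rates, 2, rng)
# Subtracting a common minimum never reorders keys, so the trace is
# simply the k smallest original exponentials in increasing order.
assert trace == sorted(original, key=original.get)[:2]

# The ordered trace follows Plackett-Luce: P([d, c]) = 4/10 * 3/6 = 0.2.
n = 30_000
hits = sum(sample_top_k(rates, 2, rng)[0] == ["d", "c"] for _ in range(n))
print(subset, hits / n)
```

The subtraction step does not change which keys are selected here; it matters because it keeps the inputs of each recursive call distributed as in the stochastic invariant.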
4. Gumbel-Softmax Relaxation and Training
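Hard arg-max selection is piecewise constant in the logits, so differentiating through it yields no useful gradient. During training, the Gumbel-Softmax relaxation replaces it with a continuous approximation. For subset selection, parallel Gumbel-Softmax draws are performed: to pick $N_{i+1}$ elements from $N_i$ items with features $X_i \in \mathbb{R}^{N_i \times c}$, a learnable linear layer $W \in \mathbb{R}^{N_{i+1} \times c}$ produces logits $W X_i^\top \in \mathbb{R}^{N_{i+1} \times N_i}$, to which Gumbel noise $G$ is added:
$$X_{i+1} = \mathrm{softmax}\!\left(\frac{W X_i^\top + G}{\tau}\right) X_i,$$
where the softmax is taken row-wise with temperature $\tau$. The output $X_{i+1} \in \mathbb{R}^{N_{i+1} \times c}$ enables standard backpropagation, since every operation is differentiable (Yang et al., 2019).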
At inference, discrete samples are produced by direct Gumbel-Max top-1 selection on each row.
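The training and inference modes of one subset-sampling layer can be sketched in plain Python lists. This is an illustrative sketch, not the authors' code: `gumbel_subset_sample` and its arguments are hypothetical names, and `W` is a fixed list here rather than a learned parameter. In soft mode each row of the selection matrix sums to 1; in hard mode each row is one-hot, so every output row is exactly one of the input points.

```python
import math
import random


def gumbel(rng):
    """Standard Gumbel draw: -log(E) with E ~ Exp(1)."""
    return -math.log(rng.expovariate(1.0))


def softmax(row, tau):
    """Temperature softmax, shifted by the max for numerical stability."""
    m = max(row)
    exps = [math.exp((v - m) / tau) for v in row]
    s = sum(exps)
    return [e / s for e in exps]


def gumbel_subset_sample(X, W, tau, rng, hard=False):
    """Sketch of one Gumbel Subset Sampling layer in plain Python lists.

    X: N_i x c point features.  W: N_{i+1} x c weights (learned in practice).
    Returns the N_{i+1} x N_i selection matrix S and X_next = S X.
    """
    # Logits W X^T: entry [r][i] scores point i for output slot r.
    logits = [[sum(w * x for w, x in zip(w_row, point)) for point in X] for w_row in W]
    S = []
    for row in logits:
        noisy = [z + gumbel(rng) for z in row]
        if hard:  # inference: one-hot Gumbel-Max per row
            j = max(range(len(noisy)), key=noisy.__getitem__)
            S.append([1.0 if i == j else 0.0 for i in range(len(noisy))])
        else:  # training: relaxed Gumbel-Softmax per row
            S.append(softmax(noisy, tau))
    c = len(X[0])
    X_next = [[sum(s * X[i][d] for i, s in enumerate(s_row)) for d in range(c)] for s_row in S]
    return S, X_next


rng = random.Random(0)
X = [[0.0, 1.0, 2.0], [1.0, 0.0, -1.0], [2.0, 2.0, 0.5], [-1.0, 0.5, 1.5], [0.3, -0.7, 0.9]]
W = [[0.5, -0.2, 0.1], [-0.3, 0.8, 0.4]]  # N_{i+1} = 2 picks from N_i = 5 points

S_soft, X_soft = gumbel_subset_sample(X, W, tau=0.5, rng=rng)
S_hard, X_hard = gumbel_subset_sample(X, W, tau=0.5, rng=rng, hard=True)
print(X_soft)
print(X_hard)  # each row is one of the input points
```

Lowering `tau` pushes the soft rows toward one-hot vectors, which narrows the gap between training and inference behavior at the cost of higher-variance gradients.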
5. Score-Function Gradient Estimation and Variance Reduction
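Gradient estimation leverages the factorized structure and the recursive trace. Writing $\lambda = \lambda(\theta)$ for the parameterized rates and $\mathcal{L}(X)$ for the downstream objective, three unbiased REINFORCE-type estimators of $\nabla_\theta \mathbb{E}[\mathcal{L}(X)]$ are possible:
- $E$-REINFORCE, $\mathcal{L}(X)\,\nabla_\theta \log p(E; \lambda)$: uses the full exponential sample $E$, and has high variance.
- $T$-REINFORCE, $\mathcal{L}(X)\,\nabla_\theta \log p(T; \lambda)$: marginalizes out the exponentials and keeps only the trace $T$, with reduced variance.
- $X$-REINFORCE, $\mathcal{L}(X)\,\nabla_\theta \log p(X; \lambda)$: marginalizes down to the output $X$, yielding further variance reduction, although computing $p(X; \lambda)$ is often intractable.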
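Variance is further reduced with:
- Conditional reparameterization control variates (RELAX-type), where $c$ is a learned control variate, $E$ is reparameterized, and $\tilde{E} \sim p(E \mid T)$ is drawn by conditional reparameterization:
$$\hat{g} = \big[\mathcal{L}(X) - c(\tilde{E})\big]\,\nabla_\theta \log p(T; \lambda) + \nabla_\theta c(E) - \nabla_\theta c(\tilde{E})$$
- Multi-sample leave-one-out baselines, using $S \ge 2$ independent traces $T^{(1)}, \dots, T^{(S)}$ with outputs $X^{(1)}, \dots, X^{(S)}$:
$$\hat{g} = \frac{1}{S} \sum_{s=1}^{S} \Big[\mathcal{L}(X^{(s)}) - \frac{1}{S-1} \sum_{j \neq s} \mathcal{L}(X^{(j)})\Big]\, \nabla_\theta \log p(T^{(s)}; \lambda)$$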
These strategies are unbiased and exploit the Markovian structure for practical, low-variance gradients (Struminsky et al., 2021).
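The leave-one-out variant of $T$-REINFORCE can be sketched for sequential top-$k$ traces with $\lambda_j = e^{\theta_j}$; the RELAX-type control variate is not covered. Everything here is an illustrative assumption (the function names, the toy objective `loss`, and the sample sizes). The sketch averages many leave-one-out estimates and compares them with the exact gradient obtained by enumerating all ordered traces. The score $\nabla_\theta \log p(T)$ is, per step, the one-hot winner minus the softmax over the surviving keys.

```python
import itertools
import math
import random


def sample_trace(theta, k, rng):
    """Length-k Exponential-Min trace with rates lambda_j = exp(theta_j).

    Repeated arg-min with subtraction returns the k smallest draws in order.
    """
    draws = {j: rng.expovariate(math.exp(t)) for j, t in enumerate(theta)}
    return sorted(draws, key=draws.get)[:k]


def log_prob(theta, trace):
    """log p(T) = sum_t [theta_{T_t} - log sum_{j in K_t} exp(theta_j)]."""
    remaining, lp = set(range(len(theta))), 0.0
    for t in trace:
        lp += theta[t] - math.log(sum(math.exp(theta[j]) for j in remaining))
        remaining.remove(t)
    return lp


def score(theta, trace):
    """grad_theta log p(T): one-hot winner minus softmax over surviving keys."""
    grad, remaining = [0.0] * len(theta), set(range(len(theta)))
    for t in trace:
        z = sum(math.exp(theta[j]) for j in remaining)
        for j in remaining:
            grad[j] -= math.exp(theta[j]) / z
        grad[t] += 1.0
        remaining.remove(t)
    return grad


def loo_estimate(theta, k, loss, S, rng):
    """T-REINFORCE with a leave-one-out baseline over S >= 2 traces."""
    traces = [sample_trace(theta, k, rng) for _ in range(S)]
    losses = [loss(tr) for tr in traces]
    total, grad = sum(losses), [0.0] * len(theta)
    for tr, l in zip(traces, losses):
        baseline = (total - l) / (S - 1)  # mean loss of the other samples
        for j, g in enumerate(score(theta, tr)):
            grad[j] += (l - baseline) * g / S
    return grad


def exact_gradient(theta, k, loss):
    """Enumerate ordered traces: sum_T p(T) L(T) grad log p(T)."""
    grad = [0.0] * len(theta)
    for tr in itertools.permutations(range(len(theta)), k):
        w = math.exp(log_prob(theta, tr)) * loss(tr)
        for j, g in enumerate(score(theta, tr)):
            grad[j] += w * g
    return grad


def loss(trace):
    # Hypothetical objective that depends only on the selected set.
    return sum(trace) / 5.0


theta = [0.0, 0.5, -0.5, 1.0]
rng = random.Random(0)
B = 4000
est = [0.0] * len(theta)
for _ in range(B):
    for j, g in enumerate(loo_estimate(theta, 2, loss, S=4, rng=rng)):
        est[j] += g / B
exact = exact_gradient(theta, 2, loss)
print([round(v, 3) for v in est])
print([round(v, 3) for v in exact])
```

The baseline for sample $s$ uses only the other samples, so it is independent of $T^{(s)}$ and leaves the estimator unbiased.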
6. Hierarchical Applications and Architectural Integration
The group-wise, hierarchical Gumbel-based schemes pervade several domains:
- Point set and geometric data: Gumbel Subset Sampling (GSS) applies a hierarchical sequence of subset samplers, downsampling points and refining representations in transformer-based networks for point clouds. Each sampling layer applies a Group-Gumbel-Max subset selection, with hierarchical stages interleaved with permutation-equivariant attention modules. At test time, hard subset selection is realized via the Gumbel-Max trick (Yang et al., 2019).
- Combinatorial structures: Hierarchical group-wise decomposition enables direct sampling and model training on permutations, matchings, trees, and other structures, with each step recapitulating a combinatorial construction (e.g., Kruskal's algorithm for spanning trees or Chu–Liu–Edmonds for arborescences) (Struminsky et al., 2021).
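For spanning trees, the Kruskal-style construction can be sketched by running Kruskal's algorithm on exponential edge weights $E_e \sim \mathrm{Exp}(\lambda_e)$, using union-find to track components. This is a minimal sketch under assumed names (`sample_spanning_tree`, `find`) and illustrative rates $\lambda_{uv} = 1 + u + v$ on the complete graph; it makes no claim about the exact tree distribution that results.

```python
import random


def find(parent, x):
    # Union-find lookup with path halving.
    while parent[x] != x:
        parent[x] = parent[parent[x]]
        x = parent[x]
    return x


def sample_spanning_tree(n, edge_rates, rng):
    """Kruskal on exponential edge weights E_e ~ Exp(lambda_e) (a sketch).

    Each accepted edge is the arg-min among the remaining edges that join two
    different components, in the spirit of one Exponential-Min step.
    """
    weights = {e: rng.expovariate(lam) for e, lam in edge_rates.items()}
    parent = list(range(n))
    tree = []
    for u, v in sorted(weights, key=weights.get):
        ru, rv = find(parent, u), find(parent, v)
        if ru != rv:  # edge joins two components: accept it and merge them
            parent[ru] = rv
            tree.append((u, v))
            if len(tree) == n - 1:
                break
    return tree


n = 5
edge_rates = {(u, v): 1.0 + u + v for u in range(n) for v in range(u + 1, n)}
rng = random.Random(3)
tree = sample_spanning_tree(n, edge_rates, rng)
print(tree)
```

Union-find keeps each component check nearly constant time, so the cost is dominated by sorting the edge weights.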
7. Computational Cost and Theoretical Guarantees
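The cost of each recursion step is dominated by the group-wise arg-min computations, which take time linear in the number of remaining keys when done naively; specialized data structures (e.g., union-find for trees) can reduce this cost in special cases. The log-probability of the sampling trace is explicitly computable as a sum over recursion steps $t$ and groups $i$:
$$\log p(T; \lambda) = \sum_{t} \sum_{i} \Big[\log \lambda_{T_{t,i}} - \log \sum_{k \in K_{t,i}} \lambda_k\Big].$$
All estimator variants discussed are unbiased for $\nabla_\theta \mathbb{E}[\mathcal{L}(X)]$, and the variances of the $E$-, $T$-, and $X$-REINFORCE estimators satisfy
$$\mathrm{Var}\big[\hat{g}_X\big] \le \mathrm{Var}\big[\hat{g}_T\big] \le \mathrm{Var}\big[\hat{g}_E\big]$$
by the Rao–Blackwell theorem and Jensen's inequality (Struminsky et al., 2021).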
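The trace log-probability is a sum of per-group categorical terms, which makes it easy to sanity-check. In the sketch below, `trace_log_prob` and its `steps` format (a list of steps, each a list of `(group, winner)` pairs) are illustrative assumptions. It verifies that the probabilities of all possible traces sum to 1, both for sequential top-2 selection and for one group-wise step, and evaluates one Plackett–Luce trace.

```python
import itertools
import math


def trace_log_prob(rates, steps):
    """log p(T) = sum over steps and groups of log(lambda_winner / group rate sum).

    steps: list of recursion steps; each step is a list of (group, winner)
    pairs, where `group` is the key list K_{t,i} and `winner` is T_{t,i}.
    """
    lp = 0.0
    for step in steps:
        for group, winner in step:
            lp += math.log(rates[winner]) - math.log(sum(rates[k] for k in group))
    return lp


rates = {"a": 1.0, "b": 2.0, "c": 3.0, "d": 4.0}

# Sequential top-2 (one group per step): all ordered traces sum to 1.
total_seq = 0.0
for first, second in itertools.permutations(rates, 2):
    rest = [k for k in rates if k != first]
    steps = [[(list(rates), first)], [(rest, second)]]
    total_seq += math.exp(trace_log_prob(rates, steps))

# One group-wise step with groups {a, b} and {c, d}: winner pairs sum to 1.
total_grp = 0.0
for w1, w2 in itertools.product(["a", "b"], ["c", "d"]):
    steps = [[(["a", "b"], w1), (["c", "d"], w2)]]
    total_grp += math.exp(trace_log_prob(rates, steps))

print(round(total_seq, 12), round(total_grp, 12))  # 1.0 1.0
```

Because each term only involves the rates of one group, the same function covers any grouping rule, provided the groups at each step are determined by the earlier part of the trace.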
This decomposition and its relaxations allow efficient, unbiased, and variance-reduced learning of models with structured discrete latent variables, without introducing additional constraints on model smoothness, and extend to hierarchical, structure-preserving deep architectures for sets and combinatorial objects (Struminsky et al., 2021, Yang et al., 2019).