Memoised Wake-Sleep (MWS)
- Memoised Wake-Sleep (MWS) is an approximate inference algorithm that reuses high-probability latent assignments to accelerate learning in complex generative models.
- It builds on the canonical Wake-Sleep framework by integrating a memory mechanism that preserves effective latent configurations and reduces variance in inference.
- The hybrid extension, HMWS, leverages importance sampling to handle continuous latent variables, yielding state-of-the-art results in neuro-symbolic domains.
Memoised Wake-Sleep (MWS) is an approximate inference and learning algorithm designed to enhance the training of probabilistic generative models, particularly those involving complex structured or program-like discrete latent spaces. Unlike classical Wake-Sleep or its reweighted variants, MWS accelerates learning and improves inference quality by explicitly memoising high-probability latent variable configurations for each data point and reusing them across training iterations. Hybrid Memoised Wake-Sleep (HMWS) further extends this paradigm to models with both discrete and continuous latent variables by leveraging importance sampling for the continuous component. The method has demonstrated state-of-the-art results in varied neuro-symbolic domains such as program induction, structured kernel discovery, and 3D scene understanding (Le et al., 2021, Hewitt et al., 2020).
1. Wake-Sleep Framework and the Motivation for Memoisation
The canonical Wake-Sleep (WS) algorithm and its successors, such as Reweighted Wake-Sleep (RWS) and VIMCO, approximate posterior inference in latent-variable models by alternating between learning a recognition (inference) network and fitting the generative model. In practice, for models with large or combinatorially structured discrete latent spaces (e.g., programs, segmentations, grammars), inference via sampling from a recognition network may be extremely inefficient: high-probability configurations are rarely proposed, and expensive computation is duplicated across iterations (Hewitt et al., 2020). MWS addresses this by maintaining, for each data point, a fixed-size memory of the highest-probability latent assignments discovered so far, thus preserving and reusing valuable "particles" across the course of learning.
2. Mathematical Formulation of Memoised Wake-Sleep
Consider a generative model $p_\theta(z, x)$ with parameters $\theta$ and discrete latent variables $z$. For each data point $x$, MWS maintains a memory of $M$ unique discrete latent configurations $z_1, \dots, z_M$. The corresponding variational posterior is defined as:
$$q_{\text{mem}}(z \mid x) = \sum_{m=1}^{M} \omega_m \, \delta_{z_m}(z),$$
where $\delta_{z_m}$ is a point mass at $z_m$, and $\omega_m = p_\theta(z_m, x) / \sum_{m'=1}^{M} p_\theta(z_{m'}, x)$. This forms a mixture over memorised latent assignments, focusing the variational support on the top-scoring regions of posterior mass.
At each Wake phase, new proposals are drawn from the recognition model $q_\phi(z \mid x)$, merged into the candidate set, and the top $M$ retained by their joint probability $p_\theta(z, x)$. The memory weights $\omega_m$ are then normalised. In the Sleep phase, the generative model $p_\theta$ and recognition model $q_\phi$ are updated using gradients computed over the memorised support, supplying efficient, high-quality samples for both inference and learning (Le et al., 2021, Hewitt et al., 2020).
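The Wake-phase memory update described above can be sketched in a few lines. This is a minimal illustration, not the authors' implementation: the helper names `update_memory` and `memory_weights`, the memory size `M`, and the toy table of joint probabilities are all assumptions made for the example, and latents are assumed to be hashable so that duplicates can be removed.

```python
import math

def update_memory(memory, proposals, log_joint, M):
    """Merge proposals into the memory and keep the M unique latents
    with the highest joint log-probability log p(z, x)."""
    # Memory entries must be unique, so merge through a set.
    candidates = set(memory) | set(proposals)
    # Highest joint first; repr() breaks ties so the result is deterministic.
    ranked = sorted(candidates, key=lambda z: (-log_joint(z), repr(z)))
    return ranked[:M]

def memory_weights(memory, log_joint):
    """Weights proportional to p(z_m, x), normalised stably in log space."""
    logs = [log_joint(z) for z in memory]
    top = max(logs)
    unnorm = [math.exp(l - top) for l in logs]
    total = sum(unnorm)
    return [u / total for u in unnorm]

# Hypothetical joint probabilities p(z, x) for a single data point x.
toy_joint = {"a": 0.02, "b": 0.10, "c": 0.30, "d": 0.05}

def log_joint(z):
    return math.log(toy_joint[z])

# Memory currently holds "a" and "b"; the recognition model proposed c, d, b.
memory = update_memory(["a", "b"], ["c", "d", "b"], log_joint, M=2)
weights = memory_weights(memory, log_joint)
# memory == ["c", "b"]; weights are approximately [0.75, 0.25]
```

In this toy run the low-scoring entry "a" is evicted in favour of the newly proposed "c", while "b" survives across iterations, which is the reuse that distinguishes memoisation from resampling afresh.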
3. Extension to Hybrid Discrete–Continuous Models (HMWS)
Pure MWS is inapplicable to continuous latent variables since every continuous sample is almost surely unique and lacks reusability. Hybrid Memoised Wake-Sleep (HMWS) overcomes this by factorising the recognition model:
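$$q_\phi(z_d, z_c \mid x) = q_\phi(z_d \mid x)\, q_\phi(z_c \mid z_d, x),$$
where $z_d$ is discrete (memoised) and $z_c$ is continuous. To evaluate the marginal joint $p_\theta(z_d, x) = \int p_\theta(z_d, z_c, x)\, dz_c$, HMWS employs importance sampling. For each $z_d$, it draws $K$ samples $z_c^{(k)} \sim q_\phi(z_c \mid z_d, x)$ and computes weights
$$w_k = \frac{p_\theta(z_d, z_c^{(k)}, x)}{q_\phi(z_c^{(k)} \mid z_d, x)}, \quad k = 1, \dots, K.$$
The empirical mean $\hat{p}_\theta(z_d, x) = \frac{1}{K} \sum_{k=1}^{K} w_k$ serves as an unbiased estimator of $p_\theta(z_d, x)$ for ranking and memory updates.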
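The importance-sampling estimate of the marginal joint can be illustrated on a toy continuous model. The sketch below is an assumption-laden example rather than the HMWS code: the discrete latent is held fixed and dropped from the notation, the model is a hypothetical Gaussian pair (latent drawn from N(0, 1), observation drawn from a unit-variance Gaussian centred on the latent), and the proposal is deliberately chosen to be the exact posterior so that every weight equals the true marginal and the result can be checked in closed form. All function names, and `K` for the number of samples, are illustrative.

```python
import math
import random

def log_normal_pdf(v, mean, var):
    """Log density of a univariate Gaussian with the given mean and variance."""
    return -0.5 * (math.log(2 * math.pi * var) + (v - mean) ** 2 / var)

def log_mean_exp(values):
    """Numerically stable log of the arithmetic mean of exp(values)."""
    top = max(values)
    return top + math.log(sum(math.exp(v - top) for v in values) / len(values))

def estimate_log_marginal(x, sample_q, log_q, log_joint, K, rng):
    """Log of the empirical mean of K importance weights p(z_c, x) / q(z_c | x)."""
    log_w = []
    for _ in range(K):
        z_c = sample_q(x, rng)                       # z_c drawn from the proposal
        log_w.append(log_joint(z_c, x) - log_q(z_c, x))
    return log_mean_exp(log_w)

# Assumed toy model: z_c ~ N(0, 1), x | z_c ~ N(z_c, 1), hence p(x) = N(x; 0, 2).
def log_joint(z_c, x):
    return log_normal_pdf(z_c, 0.0, 1.0) + log_normal_pdf(x, z_c, 1.0)

# Proposal set to the exact posterior N(x / 2, 1 / 2) for this toy model.
def sample_q(x, rng):
    return rng.gauss(x / 2.0, math.sqrt(0.5))

def log_q(z_c, x):
    return log_normal_pdf(z_c, x / 2.0, 0.5)

rng = random.Random(0)
x = 1.3
estimate = estimate_log_marginal(x, sample_q, log_q, log_joint, K=10, rng=rng)
exact = log_normal_pdf(x, 0.0, 2.0)  # closed-form log marginal for comparison
```

With an imperfect proposal the weights would vary from sample to sample; their mean remains unbiased for the marginal, although its logarithm is not an unbiased estimate of the log marginal.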
Learning in HMWS comprises four signals: (a) memory update via IS, (b) generative model gradient via replay over memory and IS samples, (c) discrete recognition update using KL terms over the memory, and (d) continuous recognition update via IS-weighted KL minimization. An optional Sleep-Fantasy step, sampling from the model itself, may be added, weighted by a mixing factor (Le et al., 2021).
4. Algorithmic Procedure and Complexity
The workflow for a data point $x$ in HMWS is summarised as:
- Wake: Propose $N$ new discrete candidates $z_d$; for each, draw $K$ continuous samples $z_c$, compute importance weights, and retain the top $M$ values of $z_d$ based on the estimated marginal joint $\hat{p}_\theta(z_d, x)$.
- Sleep-Replay: Update generative, discrete recognition, and continuous recognition parameters using gradients over the memory and their IS samples, as defined above.
- (Optional) Sleep-Fantasy: Sample from the model joint, further training the recognition model.
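Per iteration, the computational complexity is dominated by $O((N + M)K)$ joint evaluations (for proposals and replay samples), where $N$ is the number of proposals, $M$ the memory size, and $K$ the number of IS samples per discrete latent; all other costs (sorting, gradient steps) are lower order (Le et al., 2021).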
| Phase | Main Computation | Dominant Cost per Data Point |
|---|---|---|
| Wake | $N$ proposals × $K$ IS samples | $O(NK)$ |
| Sleep-Replay | $M$ memory entries × $K$ IS samples | $O(MK)$ |
| Total | – | $O((N + M)K)$ |
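As a worked example of the table above, the following sketch counts joint evaluations per data point for one iteration. The function name and the particular settings (5 proposals, 5 memory entries, 10 IS samples) are hypothetical values chosen for illustration, not figures from the cited papers.

```python
def joint_evaluations(num_proposals, memory_size, num_is_samples):
    """Joint evaluations per data point per iteration, one per
    (discrete latent, IS sample) pair in each phase."""
    wake = num_proposals * num_is_samples        # proposals x IS samples
    sleep_replay = memory_size * num_is_samples  # memory entries x IS samples
    return {"wake": wake, "sleep_replay": sleep_replay, "total": wake + sleep_replay}

# Hypothetical settings for illustration only.
costs = joint_evaluations(num_proposals=5, memory_size=5, num_is_samples=10)
# costs == {"wake": 50, "sleep_replay": 50, "total": 100}
```

Such a count makes it straightforward to match budgets when comparing against a baseline that spends the same number of joint evaluations on fresh samples.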
5. Empirical Performance in Neuro-Symbolic Domains
Memoised Wake-Sleep and its hybrid extension have been validated across complex generative tasks:
- Structured GP-Kernel Learning: Grammar-driven compositions of base kernels (including SE, WN, and Per) with LSTM priors (discrete) and kernel hyperparameters (continuous). HMWS outperforms RWS and VIMCO, converging faster and achieving higher IWAE-100 log marginal likelihood (Le et al., 2021).
- 3D Compositional Scene Understanding: Placement of blocks with discrete primitive types and continuous offsets in grid-based scenes, rendered with differentiable rasterization. HMWS learns faster, attains higher likelihood, and produces interpretable posteriors capturing occlusion uncertainty relative to baselines (Le et al., 2021).
- Program Induction Domains (MWS): Programmatic clustering, handwritten character decomposition, few-shot string concept learning, and symbolic cellular automata inference. MWS consistently demonstrates sharper convergence and superior inference/likelihood metrics compared to RWS and VIMCO, even with fewer proposals per iteration (Hewitt et al., 2020).
6. Relation to Prior Inference Methods and Distinctive Advantages
In classical Wake-Sleep, each iteration begins 'tabula rasa': the recognition model $q_\phi$ proposes candidates from scratch, which is ineffective when $q_\phi$ is immature or posterior mass is highly concentrated. RWS and VIMCO address proposal weakness by drawing multiple samples per iteration, but they do not reuse successful proposals, so computational demands increase without any cross-iteration amortization.
MWS's core innovation is instance-wise "top-M memory": high-probability latents per datum are preserved and recycled, both improving the variational bound and providing reliable replay samples for update steps. This results in:
- Lower variance and higher-quality inference.
- Reduced number of recognition network (proposal) calls.
- Accelerated convergence in marginal likelihood and posterior mass.
- Directly optimised tight finite-support bounds, with memory size and computational tradeoffs fully transparent by design (Hewitt et al., 2020, Le et al., 2021).
A plausible implication is that memory-based mechanisms generalise naturally to hybrid settings and could benefit other classes of amortised inference, particularly in models where structure discovery and compositionality are central.
7. Summary and Prospects
Memoised Wake-Sleep redefines Wake-Sleep learning for challenging discrete and hybrid latent-variable models by preserving high-quality latent assignments per-datum and combining them with efficient recognition network proposals. The hybrid extension enables scalable, accurate inference where both structure and continuous variation matter. Across all tested neuro-symbolic domains, MWS/HMWS achieves significant gains in empirical learning speed, likelihood, and posterior quality over prior state-of-the-art methods, validating its value for future research on compositional and programmatic generative modelling (Le et al., 2021, Hewitt et al., 2020).