Memoised Wake-Sleep (MWS)

Updated 13 March 2026
  • Memoised Wake-Sleep (MWS) is an approximate inference algorithm that reuses high-probability latent assignments to accelerate learning in complex generative models.
  • It builds on the canonical Wake-Sleep framework by integrating a memory mechanism that preserves effective latent configurations and reduces variance in inference.
  • The hybrid extension, HMWS, leverages importance sampling to handle continuous latent variables, yielding state-of-the-art results in neuro-symbolic domains.

Memoised Wake-Sleep (MWS) is an approximate inference and learning algorithm designed to enhance the training of probabilistic generative models, particularly those involving complex structured or program-like discrete latent spaces. Unlike classical Wake-Sleep or its reweighted variants, MWS accelerates learning and improves inference quality by explicitly memoising high-probability latent variable configurations for each data point and reusing them across training iterations. Hybrid Memoised Wake-Sleep (HMWS) further extends this paradigm to models with both discrete and continuous latent variables by leveraging importance sampling for the continuous component. The method has demonstrated state-of-the-art results in varied neuro-symbolic domains such as program induction, structured kernel discovery, and 3D scene understanding (Le et al., 2021, Hewitt et al., 2020).

1. Wake-Sleep Framework and the Motivation for Memoisation

The canonical Wake-Sleep (WS) algorithm, its successor Reweighted Wake-Sleep (RWS), and related multi-sample methods such as VIMCO approximate posterior inference in latent-variable models by training a recognition (inference) network alongside the generative model. In practice, for models with large or combinatorially structured discrete latent spaces (e.g., programs, segmentations, grammars), inference by sampling from a recognition network can be extremely inefficient: high-probability configurations are rarely proposed, and expensive computation is duplicated across iterations (Hewitt et al., 2020). MWS addresses this by maintaining, for each data point, a fixed-size memory of the highest-probability latent assignments discovered so far, thus preserving and reusing valuable "particles" throughout learning.

2. Mathematical Formulation of Memoised Wake-Sleep

Consider a generative model with parameters $\theta$ and discrete latent variables $z_d$. For each data point $x$, MWS maintains a memory $\{z_d^m\}_{m=1}^M$ of $M$ unique discrete latent configurations. The corresponding variational posterior is defined as:

$$q_{\rm mem}(z_d \mid x) = \sum_{m=1}^M \omega_m \, \delta_{z_d^m}(z_d),$$

where $\delta_{z_d^m}$ is a point mass at $z_d^m$ and $\omega_m \propto p_\theta(z_d^m, x)$, normalised so that $\sum_{m=1}^M \omega_m = 1$. This forms a mixture over memoised latent assignments, concentrating the variational support on the highest-scoring regions of posterior mass.
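A minimal sketch of the weight normalisation above, assuming $\log p_\theta(z_d^m, x)$ is available as a plain float for each memory entry; the helper name `memory_weights` is hypothetical rather than taken from either paper. Working in log space matters because joint probabilities of structured latents are typically far too small to represent directly.

```python
import math


def memory_weights(log_joints: list[float]) -> list[float]:
    """Return omega_m proportional to p_theta(z_d^m, x), given log p_theta(z_d^m, x).

    Subtracting the maximum before exponentiating (the log-sum-exp trick)
    maps the largest term to exp(0) = 1, so the normaliser never underflows to zero.
    """
    if not log_joints:
        raise ValueError("memory must contain at least one entry")
    max_lj = max(log_joints)
    unnormalised = [math.exp(lj - max_lj) for lj in log_joints]
    total = sum(unnormalised)
    return [u / total for u in unnormalised]


# Three memoised latents whose joint probabilities are in ratio 1 : 2 : 1,
# each far too small to represent directly as a float.
weights = memory_weights([-1000.0, -1000.0 + math.log(2.0), -1000.0])
print([round(w, 6) for w in weights])  # [0.25, 0.5, 0.25]
```

Without the shift, `math.exp(-1000.0)` underflows to `0.0` and the naive normalisation divides by zero; the shifted form recovers the intended 1 : 2 : 1 ratio.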

At each Wake phase, $N$ new proposals $z_d' \sim q_\phi(z_d \mid x)$ are drawn and merged with the current memory, and the top $M$ unique candidates are retained, ranked by their joint probability $p_\theta(z_d, x)$. The memory weights $\omega_m$ are then normalised. In the Sleep phase, the generative model $p_\theta$ and recognition model $q_\phi$ are updated using gradients computed over the memoised support, which supplies efficient, high-quality samples for both inference and learning (Le et al., 2021, Hewitt et al., 2020).
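The Wake-phase memory update for one data point can be sketched as follows, assuming latents are hashable values (e.g. program strings) and that a callback returning $\log p_\theta(z_d, x)$ exists; `update_memory` and `log_joint` are hypothetical names. This sketch rescores old memory entries along with the new proposals, since $\theta$ may have changed since they were last scored.

```python
import math
from collections.abc import Callable, Hashable


def update_memory(
    memory: list[Hashable],
    proposals: list[Hashable],
    log_joint: Callable[[Hashable], float],
    M: int,
) -> tuple[list[Hashable], list[float]]:
    """One Wake-phase memory update for a single data point x.

    memory:    current memoised latents z_d^1..z_d^M
    proposals: N fresh samples z_d' ~ q_phi(z_d | x)
    log_joint: maps z_d to log p_theta(z_d, x) for this x (hypothetical callback)
    Returns the new top-M memory and its normalised weights omega_m.
    """
    # Merge and deduplicate (memory entries must be unique), keeping first-seen order.
    candidates = list(dict.fromkeys([*memory, *proposals]))
    # Rescore every candidate, old memory entries included.
    scored = [(log_joint(z), z) for z in candidates]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    top = scored[:M]
    # Normalise omega_m proportional to p_theta(z_d^m, x) in log space.
    max_lj = top[0][0]
    unnormalised = [math.exp(lj - max_lj) for lj, _ in top]
    total = sum(unnormalised)
    return [z for _, z in top], [u / total for u in unnormalised]


toy_scores = {"a": -1.0, "b": -3.0, "c": -0.5, "d": -2.0}  # hypothetical log joints
new_memory, new_weights = update_memory(["a", "b"], ["c", "a", "d"], toy_scores.__getitem__, M=2)
print(new_memory)  # ['c', 'a']
```

The duplicate proposal `"a"` is scored once, and the weaker old entry `"b"` is evicted in favour of the new proposal `"c"`.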

3. Extension to Hybrid Discrete–Continuous Models (HMWS)

Pure MWS is inapplicable to continuous latent variables: any two continuous samples are almost surely distinct, so a memoised value is never proposed again and cannot be usefully reused. Hybrid Memoised Wake-Sleep (HMWS) overcomes this by factorising the recognition model:

$$q_\phi(z_d, z_c \mid x) = q_\phi(z_d \mid x)\, q_\phi(z_c \mid z_d, x),$$

where $z_d$ is discrete (memoised) and $z_c$ is continuous. To evaluate the marginal joint $p_\theta(z_d, x) = \int p_\theta(z_d, z_c, x)\, dz_c$, HMWS employs importance sampling (IS). For each $z_d$, it draws $K$ samples $z_c^k \sim q_\phi(z_c \mid z_d, x)$ and computes weights

$$w_k = \frac{p_\theta(z_d, z_c^k, x)}{q_\phi(z_c^k \mid z_d, x)}, \qquad k = 1, \dots, K.$$

The empirical mean $\hat{p}_\theta(z_d, x) = \frac{1}{K} \sum_{k=1}^K w_k$ is an unbiased estimator of $p_\theta(z_d, x)$ and is used to rank candidates and update the memory.
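The estimator above can be sketched for one fixed $z_d$ on a hypothetical one-dimensional Gaussian toy model (not one of the models in the papers); `log_marginal_is` is an illustrative name. Working with log weights and a log-mean-exp keeps the computation stable, and because the logarithm is monotone, ranking candidates by $\log \hat{p}_\theta(z_d, x)$ gives the same order as ranking by $\hat{p}_\theta(z_d, x)$. The log of the estimate is itself biased downwards (Jensen's inequality), even though $\hat{p}_\theta$ is unbiased.

```python
import math
import random
from collections.abc import Callable


def log_normal(v: float, mean: float, var: float) -> float:
    """Log density of N(mean, var) at v."""
    return -0.5 * (math.log(2.0 * math.pi * var) + (v - mean) ** 2 / var)


def log_marginal_is(
    log_joint: Callable[[float], float],
    sample_q: Callable[[], float],
    log_q: Callable[[float], float],
    K: int,
) -> float:
    """Return log((1/K) * sum_k w_k), with w_k = p(z_d, z_c^k, x) / q(z_c^k | z_d, x)."""
    log_w = []
    for _ in range(K):
        zc = sample_q()                          # z_c^k ~ q_phi(z_c | z_d, x)
        log_w.append(log_joint(zc) - log_q(zc))  # log importance weight
    m = max(log_w)
    return m + math.log(sum(math.exp(lw - m) for lw in log_w) / K)


# Hypothetical toy model for one fixed z_d: p(z_d) = 0.3,
# z_c ~ N(0, 1), x | z_c ~ N(z_c, 1), observed x = 1.2.
x, log_prior_zd = 1.2, math.log(0.3)


def toy_log_joint(zc: float) -> float:
    return log_prior_zd + log_normal(zc, 0.0, 1.0) + log_normal(x, zc, 1.0)


rng = random.Random(0)
# Proposal: the exact conditional posterior N(x/2, 1/2) of this toy model.
estimate = log_marginal_is(
    toy_log_joint,
    lambda: rng.gauss(x / 2, math.sqrt(0.5)),
    lambda zc: log_normal(zc, x / 2, 0.5),
    K=8,
)
exact = log_prior_zd + log_normal(x, 0.0, 2.0)  # log p(z_d) + log N(x; 0, 2)
print(abs(estimate - exact) < 1e-9)  # True
```

With the exact conditional posterior as proposal, every weight equals $p_\theta(z_d, x)$, so the estimate is exact for any $K$. A cruder proposal such as the prior $\mathcal{N}(0, 1)$ would remain unbiased but would need more samples.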

Learning in HMWS combines four signals: (a) memory updates via importance sampling, (b) a generative-model gradient computed by replaying the memory and its IS samples, (c) a discrete recognition update using KL terms over the memory, and (d) a continuous recognition update via IS-weighted KL minimization. An optional Sleep-Fantasy step, which trains the recognition model on samples drawn from the generative model itself, may be added with a mixing factor (Le et al., 2021).
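Signal (c) can be illustrated under a simplifying assumption: a hypothetical categorical recognition model $q_\phi(z_d \mid x) = \mathrm{softmax}(\text{logits})$ over a small, enumerable latent space indexed by integers, trained to minimise $\mathrm{KL}(q_{\rm mem} \,\|\, q_\phi)$. Up to a constant, this KL equals $-\sum_m \omega_m \log q_\phi(z_d^m \mid x)$, whose gradient with respect to the logits is $q_\phi - q_{\rm mem}$. Program-like latent spaces cannot be enumerated like this, so a practical $q_\phi$ would be a neural network; the sketch only shows the shape of the update.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def discrete_recognition_step(
    logits: list[float], memory: list[int], weights: list[float], lr: float = 0.5
) -> tuple[list[float], float]:
    """One gradient step on KL(q_mem || q_phi) for a categorical q_phi(z_d | x).

    memory holds integer indices z_d^m, weights the matching omega_m (summing to 1).
    Returns the updated logits and the cross-entropy loss before the step.
    """
    probs = softmax(logits)
    target = [0.0] * len(logits)  # q_mem as a dense vector
    for z, w in zip(memory, weights):
        target[z] += w
    loss = -sum(w * math.log(probs[z]) for z, w in zip(memory, weights))
    # d loss / d logits = q_phi - q_mem for a softmax-parameterised categorical.
    new_logits = [v - lr * (p - t) for v, p, t in zip(logits, probs, target)]
    return new_logits, loss


logits = [0.0] * 5                    # q_phi starts uniform over 5 toy latents
memory, weights = [1, 3], [0.7, 0.3]  # hypothetical memory and its omega_m
losses = []
for _ in range(50):
    logits, loss = discrete_recognition_step(logits, memory, weights)
    losses.append(loss)
print(losses[0] > losses[-1])  # True
```

The loss starts at $\log 5$ (uniform $q_\phi$) and decreases towards the entropy of $q_{\rm mem}$, its minimum.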

4. Algorithmic Procedure and Complexity

The workflow for a data point $x$ in HMWS is summarised as:

  1. Wake: Propose $N$ new $z_d$ candidates from $q_\phi(z_d \mid x)$; for each, draw $K$ importance samples $z_c \sim q_\phi(z_c \mid z_d, x)$, compute importance weights, and retain the top $M$ candidates $z_d$ based on their estimated marginal joint $p_\theta(z_d, x)$.
  2. Sleep-Replay: Update generative, discrete recognition, and continuous recognition parameters using gradients over the memory and their IS samples, as defined above.
  3. (Optional) Sleep-Fantasy: Sample from the model joint, further training the recognition model.
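The workflow above can be sketched end to end on a hypothetical toy hybrid model (not from the papers): $z_d \in \{0, \dots, 4\}$ with a uniform prior, $z_c \mid z_d \sim \mathcal{N}(z_d, 1)$ and $x \mid z_c \sim \mathcal{N}(z_c, 1)$. Several simplifying assumptions are labelled in the code: the discrete recognition model is replaced by a uniform stand-in, the continuous proposal is the toy model's exact conditional posterior $\mathcal{N}((z_d + x)/2, 1/2)$, and the Sleep-Replay gradient steps are omitted, so replay only refreshes the memory's estimates and weights. A counter on the joint density makes the per-iteration cost in joint evaluations explicit.

```python
import math
import random


def log_normal(v: float, mean: float, var: float) -> float:
    return -0.5 * (math.log(2.0 * math.pi * var) + (v - mean) ** 2 / var)


class ToyModel:
    """Hypothetical hybrid model; counts calls to log_joint, the unit of cost."""

    def __init__(self) -> None:
        self.calls = 0

    def log_joint(self, zd: int, zc: float, x: float) -> float:
        self.calls += 1
        # log p(z_d) + log p(z_c | z_d) + log p(x | z_c), with p(z_d) uniform over 5 values
        return math.log(0.2) + log_normal(zc, zd, 1.0) + log_normal(x, zc, 1.0)


def estimate(model: ToyModel, zd: int, x: float, K: int, rng: random.Random) -> float:
    """log p_hat(z_d, x) from K importance samples (K calls to log_joint)."""
    mean, var = (zd + x) / 2, 0.5  # exact q(z_c | z_d, x) for this toy model
    log_w = []
    for _ in range(K):
        zc = rng.gauss(mean, math.sqrt(var))
        log_w.append(model.log_joint(zd, zc, x) - log_normal(zc, mean, var))
    m = max(log_w)
    return m + math.log(sum(math.exp(v - m) for v in log_w) / K)


def hmws_iteration(
    model: ToyModel, memory: dict[int, float], x: float,
    N: int, M: int, K: int, rng: random.Random,
) -> tuple[dict[int, float], dict[int, float]]:
    """One iteration for one x; memory maps each z_d to its latest log p_hat(z_d, x)."""
    # 1. Wake: N proposals (uniform stand-in for q_phi(z_d | x)), K IS samples each.
    candidates = dict(memory)
    for _ in range(N):
        zd = rng.randrange(5)
        candidates[zd] = estimate(model, zd, x, K, rng)
    top = sorted(candidates, key=candidates.get, reverse=True)[:M]
    # 2. Sleep-Replay: K fresh IS samples per memory entry. Gradient steps for theta
    #    and phi would use these samples; this sketch only refreshes the estimates.
    new_memory = {zd: estimate(model, zd, x, K, rng) for zd in top}
    m = max(new_memory.values())
    total = sum(math.exp(v - m) for v in new_memory.values())
    weights = {zd: math.exp(v - m) / total for zd, v in new_memory.items()}
    # 3. The optional Sleep-Fantasy step is omitted in this sketch.
    return new_memory, weights


rng, model = random.Random(0), ToyModel()
x, N, M, K = 3.1, 4, 2, 16
memory = {zd: estimate(model, zd, x, K, rng) for zd in (0, 1)}  # hypothetical initial memory
model.calls = 0
for _ in range(20):
    memory, weights = hmws_iteration(model, memory, x, N, M, K, rng)
print(model.calls == 20 * (N + M) * K)  # True
print(sorted(memory))
```

Because the memory always holds $M$ entries, each iteration makes exactly $NK$ Wake calls plus $MK$ replay calls. After enough proposals, the memory settles on the two values of $z_d$ closest to $x$.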

Per iteration, the computational cost is dominated by $O((N + M)K)$ joint evaluations (for proposals and replay samples), with all other costs (sorting, gradient steps) being lower order (Le et al., 2021).

| Phase | Main computation | Dominant cost per data point |
|---|---|---|
| Wake | $N$ proposals × $K$ IS samples | $O(NK)$ |
| Sleep-Replay | $M$ memory entries × $K$ IS samples | $O(MK)$ |
| Total | | $O((N + M)K)$ |
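As a worked instance of the per-phase costs, take illustrative values $N = 10$, $M = 5$ and $K = 20$ (chosen only for the arithmetic, not taken from the papers). One iteration on one data point then costs

$$\underbrace{NK}_{\text{Wake}} + \underbrace{MK}_{\text{Sleep-Replay}} = (N + M)K = (10 + 5) \cdot 20 = 300$$

joint evaluations, whereas sorting the $N + M = 15$ candidates takes roughly $(N + M)\log_2(N + M) \approx 15 \times 3.9 \approx 59$ comparisons, which are far cheaper operations.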

5. Empirical Performance in Neuro-Symbolic Domains

Memoised Wake-Sleep and its hybrid extension have been validated across complex generative tasks:

  • Structured GP-Kernel Learning: Grammar-driven kernel compositions built from primitives such as SE, WN and periodic (Per) kernels, with the discrete kernel structure under an LSTM prior and continuous kernel hyperparameters. HMWS outperforms RWS and VIMCO, converging faster and achieving a higher log marginal likelihood as estimated by IWAE with 100 samples (IWAE-100) (Le et al., 2021).
  • 3D Compositional Scene Understanding: Placement of blocks with discrete primitive types and continuous offsets in grid-based scenes, rendered with differentiable rasterization. Relative to baselines, HMWS learns faster, attains higher likelihood, and produces interpretable posteriors that capture uncertainty due to occlusion (Le et al., 2021).
  • Program Induction Domains (MWS): Programmatic clustering, handwritten character decomposition, few-shot string concept learning, and symbolic cellular automata inference. MWS consistently converges faster and achieves better inference and likelihood metrics than RWS and VIMCO, even with fewer proposals per iteration (Hewitt et al., 2020).

6. Relation to Prior Inference Methods and Distinctive Advantages

In classical Wake-Sleep, each iteration begins 'tabula rasa': the recognition model $q_\phi$ proposes candidates afresh, which is ineffective when $q_\phi$ is still poorly trained or when posterior mass is concentrated on a few configurations. RWS and VIMCO mitigate weak proposals by drawing multiple samples per iteration, but they do not reuse successful proposals, which increases computational demands without any amortization across iterations.

MWS's core innovation is instance-wise "top-M memory": high-probability latents per datum are preserved and recycled, both improving the variational bound and providing reliable replay samples for update steps. This results in:

  • Lower variance and higher-quality inference.
  • Reduced number of recognition network (proposal) calls.
  • Accelerated convergence in marginal likelihood and posterior mass.
  • Direct optimisation of a tight bound under a finite-support posterior, where the memory size $M$ makes the trade-off between approximation quality and computation explicit (Hewitt et al., 2020, Le et al., 2021).

A plausible implication is that memory-based mechanisms generalise naturally to hybrid settings and could benefit other classes of amortised inference, particularly in models where structure discovery and compositionality are central.

7. Summary and Prospects

Memoised Wake-Sleep redefines Wake-Sleep learning for challenging discrete and hybrid latent-variable models by preserving high-quality latent assignments per-datum and combining them with efficient recognition network proposals. The hybrid extension enables scalable, accurate inference where both structure and continuous variation matter. Across all tested neuro-symbolic domains, MWS/HMWS achieves significant gains in empirical learning speed, likelihood, and posterior quality over prior state-of-the-art methods, validating its value for future research on compositional and programmatic generative modelling (Le et al., 2021, Hewitt et al., 2020).
