DeepDG Codebase for Discrete Data Modeling
- DeepDG Codebase is a modular framework implementing the Deep Directed Generative Autoencoder to model discrete data like binarized MNIST with deterministic encoders and probabilistic decoders.
- The architecture employs likelihood decomposition and a straight-through estimator to efficiently optimize binary encoding despite non-differentiability.
- It features greedy stacking of shallow DGAs and robust training protocols, yielding competitive log-likelihoods and quality generative samples.
The DeepDG codebase implements the Deep Directed Generative Autoencoder (DGA) framework for modeling discrete data, such as binarized MNIST digits, using stacked deterministic encoders and probabilistic decoders. The key methodological innovation is an exact rewriting of the likelihood for discrete data as P(x) = P(x | h = f(x)) · P(h = f(x)), where f is a deterministic discrete function realized by a deep neural network. This design enables learning encoders that map data to binary codes whose distribution is much simpler and more concentrated, "flattening the manifold" relative to the original input space. The DeepDG architecture further incorporates training protocols robust to the non-differentiability induced by hard binary encoding, yielding competitive log-likelihood results and plausible samples on standard datasets (Ozair et al., 2014).
1. Probabilistic Objective and Likelihood Decomposition
For discrete input data x (e.g., 784-dimensional binary vectors from MNIST), DeepDG defines a deterministic encoder f, with h = f(x) ∈ {0, 1}^d, and a decoder P(x | h), with a simple prior P(h) applied to the codes. The log-likelihood of any datum x is rewritten exactly as:

log P(x) = log P(x | h = f(x)) + log P(h = f(x))
- Encoder f: A deep neural network that produces pre-activations a(x), discretized as h_i = 1[a_i(x) > 0].
- Decoder P(x | h): A factorial Bernoulli distribution where each parameter is the sigmoid output of a deep network fed by h, yielding per-dimension probabilities p_i(h).
- Prior P(h): A factorized Bernoulli prior with per-bit parameters, typically estimated via frequency counting or running averages over observed codes.
The corresponding training objective maximizes:

Σ_x [ log P(x | h = f(x)) + β · log P(h = f(x)) ]
Both terms are implemented in practice as cross-entropy losses—one for data reconstruction, one as a regularization over the encoded activations.
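As a concrete illustration of this decomposition (toy numbers; the helper name is illustrative), the two cross-entropy terms sum exactly to −log P(x) for a discrete datum:

```python
import math

def bernoulli_nll(bits, probs):
    """Negative log-likelihood of binary vector `bits` under
    independent Bernoulli parameters `probs` (in nats)."""
    return -sum(b * math.log(p) + (1 - b) * math.log(1 - p)
                for b, p in zip(bits, probs))

# Toy datum x, its deterministic code h = f(x), decoder output, and prior
x       = [1, 0, 1, 1]
p_hat   = [0.9, 0.2, 0.8, 0.7]   # decoder probabilities P(x_i = 1 | h)
h       = [1, 0]                 # binary code f(x)
p_prior = [0.6, 0.4]             # fitted Bernoulli prior over code bits

recon = bernoulli_nll(x, p_hat)      # -log P(x | h = f(x))
prior = bernoulli_nll(h, p_prior)    # -log P(h = f(x))
neg_log_px = recon + prior           # exact -log P(x) for discrete data
```

Because the encoder is deterministic and the data are discrete, the sum is the exact negative log-likelihood, not a variational bound.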
2. Architectural Specifications
Shallow DGA
A shallow DGA consists of a single encoder-decoder-prior triplet:
- Encoder: Input layer of 784 units (for MNIST), followed by 1–3 fully connected tanh-activated hidden layers. Example layerings include:
- 1-hidden: 784 → 500 (tanh) → code pre-activations a(x)
- 2-hidden: 784 → 1000 (tanh) → 500 (tanh) → a(x)
- 3-hidden: 784 → 2000 → 1000 → 500 (all tanh), outputting a(x)
- Discretization: h_i = 1[a_i(x) > 0], yielding a binary code h.
- Decoder: Mirrors the encoder, with symmetric hidden sizes and tanh activations, mapping the binary code to a 784-dimensional output via logits followed by a sigmoid. Optional injection of 1% salt-and-pepper noise (each bit flipped with probability 0.01) during training can improve robustness.
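A minimal PyTorch sketch of the 1-hidden shallow-DGA pair described above; the class names mirror the codebase description, but the code dimension of 100 and the exact layer layout are assumptions:

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """784 -> 500 (tanh) -> code pre-activations a(x)."""
    def __init__(self, in_dim=784, hidden=500, code_dim=100):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(in_dim, hidden), nn.Tanh(),
                                 nn.Linear(hidden, code_dim))

    def forward(self, x):
        return self.net(x)                  # pre-activations, not yet binary

class Decoder(nn.Module):
    """Mirror of the encoder; sigmoid outputs are Bernoulli parameters."""
    def __init__(self, code_dim=100, hidden=500, out_dim=784):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(code_dim, hidden), nn.Tanh(),
                                 nn.Linear(hidden, out_dim))

    def forward(self, h):
        return torch.sigmoid(self.net(h))   # per-pixel probabilities

enc, dec = Encoder(), Decoder()
x = torch.rand(8, 784).round()              # toy binarized batch
h = (enc(x) > 0).float()                    # hard binary codes
p = dec(h)                                  # reconstruction probabilities
```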
Deep DGA
A deep DGA stacks shallow DGAs, trained greedily:
- Stage 1 maps x → h^(1).
- Stage 2 maps h^(1) → h^(2).
- … up to Stage L, mapping h^(L−1) → h^(L).
Each stage ℓ uses its own encoder, decoder, and prior, with its own code dimension d_ℓ.
3. Training Procedures and Optimization
Loss Function and Annealing
For a given stage ℓ with input z, the loss is:

L_ℓ(z) = −log P_ℓ(z | h = f_ℓ(z)) − β · log P_ℓ(h = f_ℓ(z))

The hyperparameter β governs prior regularization and is annealed from 0 (pure autoencoder) to 1 over epochs, preventing premature collapse of the encoder output.
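One plausible realization of such an annealing schedule (the linear ramp and the 20-epoch length are assumptions, not values from the source):

```python
def beta_schedule(epoch, anneal_epochs=20):
    """Linearly anneal beta from 0 (pure autoencoder) to 1
    (full prior regularization) over `anneal_epochs` epochs."""
    return min(1.0, epoch / anneal_epochs)
```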
Straight-Through Estimator
The binary nature of h precludes exact gradient computation through the discretization. Gradients are instead approximated by:
- Computing ∂L/∂h as if h were continuous.
- Assigning this pseudo-gradient directly to the pre-activation: ∂L/∂a := ∂L/∂h.
- Backpropagating this surrogate through the encoder network, enabling effective optimization.
Prior Fitting and Hyperparameters
Priors are updated via moving average or empirical frequency after each minibatch or epoch. Default hyperparameters include batch size 100, a learning-rate grid search (with halving on validation-loss deterioration), 50–100 epochs per stage, Xavier/Glorot initialization, and zero biases. No momentum or L1/L2 regularization is applied.
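A running-average prior update might look like the following pure-Python sketch (the momentum value of 0.99 is an assumption):

```python
def update_prior(p_prior, codes, momentum=0.99):
    """Exponential running average of per-bit code frequencies.
    `codes` is a list of binary code vectors from one minibatch."""
    batch_mean = [sum(col) / len(col) for col in zip(*codes)]
    return [momentum * p + (1 - momentum) * m
            for p, m in zip(p_prior, batch_mean)]

p = [0.5, 0.5, 0.5]                          # initial Bernoulli parameters
batch = [[1, 0, 1], [1, 1, 0], [1, 0, 0]]    # three 3-bit codes
p = update_prior(p, batch)                   # nudged toward batch frequencies
```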
Greedy Stacking and Fine-Tuning
- Train the Stage 1 DGA with β annealed from 0 to 1.
- Freeze its encoder and compute the codes h^(1) = f_1(x) for the whole dataset.
- Train Stage 2 on h^(1), with its own β schedule.
- Repeat for all stages.
- Optionally fine-tune the entire stack end-to-end with β = 1 for every stage.
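The greedy schedule above can be sketched end-to-end as follows; single linear layers stand in for the deeper tanh MLPs, and the stage dimensions, learning rate, and epoch count are illustrative assumptions:

```python
import torch
import torch.nn as nn

def make_stage(in_dim, code_dim):
    enc = nn.Linear(in_dim, code_dim)   # stands in for a deep tanh encoder
    dec = nn.Linear(code_dim, in_dim)   # stands in for the mirrored decoder
    return enc, dec

def train_stage(enc, dec, data, epochs=1):
    opt = torch.optim.SGD(list(enc.parameters()) + list(dec.parameters()),
                          lr=0.1)
    for _ in range(epochs):
        a = enc(data)
        h = (a > 0).float() + a - a.detach()     # straight-through binarization
        loss = nn.functional.binary_cross_entropy_with_logits(dec(h), data)
        opt.zero_grad()
        loss.backward()
        opt.step()

torch.manual_seed(0)
dims = [784, 100, 50]                   # per-stage code sizes (assumptions)
data = torch.rand(32, 784).round()      # toy binarized dataset
stages = []
for in_dim, code_dim in zip(dims, dims[1:]):
    enc, dec = make_stage(in_dim, code_dim)
    train_stage(enc, dec, data)
    with torch.no_grad():               # freeze: re-encode data for next stage
        data = (enc(data) > 0).float()
    stages.append((enc, dec))
```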
4. Algorithmic Structure and Sampling
Below is a high-level summary of the principal algorithmic steps for one stage, following the canonical pseudocode from (Ozair et al., 2014):
```python
for epoch in range(E):
    beta = beta_schedule(epoch)
    for batch in data_loader:
        # Stage input: raw data for stage 1, previous-stage codes otherwise
        z = prev_encoder(batch) if stage > 1 else batch

        a = encoder(z)                        # pre-activations
        h_hard = (a > 0).float()              # binary code
        # Straight-through: hard values forward, identity gradient backward
        h = h_hard + a - a.detach()

        update_prior(p_prior, h_hard)         # running-average Bernoulli prior

        p_hat = decoder(h)                    # reconstruction probabilities
        recon_loss = cross_entropy(p_hat, z)          # -log P(z | f(z))
        prior_loss = cross_entropy(p_prior, h_hard)   # -log P(f(z))
        loss = recon_loss + beta * prior_loss

        # One backward pass updates the decoder and, via the
        # straight-through path, the encoder as well
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if validation_cost_increased():
        reduce_learning_rate()
```
Ancestral Sampling: Sampling from a trained, stacked Deep DGA proceeds top-down:
- Sample the top code h^(L) ~ P_L(h).
- For ℓ = L, …, 2: decode one stage down by sampling h^(ℓ−1) ~ P_ℓ(h^(ℓ−1) | h^(ℓ)).
- Finally, sample x ~ P_1(x | h^(1)); x is the generated sample.
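A hedged sketch of this top-down procedure, using toy callables in place of trained stage decoders:

```python
import random

def sample_bernoulli(probs, rng=random):
    """Draw one bit per independent Bernoulli parameter."""
    return [1 if rng.random() < p else 0 for p in probs]

def ancestral_sample(prior_top, decoders, rng=random):
    """Top-down sampling: draw the top code from its Bernoulli prior,
    then push it through each stage's decoder in turn. Each decoder is
    a callable mapping a code to Bernoulli probabilities one level down."""
    h = sample_bernoulli(prior_top, rng)
    for decode in decoders:          # ordered from top stage down to stage 1
        h = sample_bernoulli(decode(h), rng)
    return h                         # final sample in data space

# Toy two-stage example with hand-made "decoders" (illustrative only)
rng = random.Random(0)
prior_top = [0.5] * 4
decoders = [lambda h: [0.9 if sum(h) > 2 else 0.1] * 8,
            lambda h: [0.5] * 16]
x = ancestral_sample(prior_top, decoders, rng)
```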
5. Codebase Organization and Software Components
The DeepDG reference codebase is modularized as follows:
| File | Content / Classes | Purpose |
|---|---|---|
| config.py | Hyperparameters, dimensions, schedules | Central configuration |
| models.py | Encoder, Decoder, ShallowDGA, DeepDGA (PyTorch nn.Modules) | Architecture definitions |
| trainer.py | Trainer class for pretrain/fine-tuning loops | Training orchestration |
| sampling.py | Ancestral sampling, log-likelihood estimation | Sampling, evaluation utilities |
| data.py | MNIST loading, binarization, batching | Data pipeline |
| utils.py | Running-average prior, checkpointing, logging | Auxiliary utilities |
| main.py | CLI for train/sample/eval | Entry point, script management |
Principal Classes and Responsibilities:
- Encoder(nn.Module): Layered MLP with tanh, outputs pre-activations.
- Decoder(nn.Module): MLP mapping code to output logits/sigmoid.
- ShallowDGA: Bundles encoder, decoder, and prior vector; supplies forward and update routines.
- DeepDGA: Manages a list of ShallowDGA stages, methods for stacked training, sampling, and likelihood estimation.
Dependencies: Python 3.6+, PyTorch 1.6+ (or TensorFlow 2.x), torchvision or tf-datasets, numpy, scipy, matplotlib, optional tqdm.
6. Empirical Performance, Usage, and Evaluation
To reproduce reported results on binarized MNIST:
- Prepare a configuration with target dimensions, learning rate, and β schedule.
- Instantiate and pretrain each ShallowDGA sequentially, saving checkpoints and codes per stage.
- Optionally perform joint fine-tuning with all priors enabled (β = 1 for every stage).
- Generate samples via ancestral sampling; evaluate log-likelihoods through importance-sampling estimates of the marginal likelihood.
Adhering to the prescribed training protocols, including the straight-through estimator, β-annealing, and greedy stacking, yields competitive test log-likelihoods (measured in nats per digit) on binarized MNIST and visually plausible generative samples (Ozair et al., 2014).
7. Context and Significance within Generative Modeling
The DeepDG model demonstrates an exact likelihood decomposition for discrete data and an effective, tractable path to learning generative autoencoders for high-dimensional and structured inputs. The stacking of shallow DGAs via greedy pretraining transforms the data distribution into forms more amenable to estimation with simple parametric priors, leveraging the “manifold flattening” property. The use of deterministic binary encodings and straight-through gradient estimators distinguishes DeepDG from most variational autoencoder approaches, enabling coherent ancestral generative sampling and empirical log-likelihood computation on benchmark datasets such as MNIST (Ozair et al., 2014).