Multi-Head cVAE for Disentangled Generation
- Multi-Head cVAE is a generative model that decomposes the latent space into label-relevant and label-irrelevant codes for precise control.
- The architecture uses SPADE for spatial (label-specific) information and AdaIN for unsupervised style modulation in image generation.
- It achieves effective disentanglement, demonstrated by quantitative metrics on datasets like 3D-Chair and FaceScrub, enhancing controllability in tasks such as identity swapping.
A Multi-Head Conditional Variational Autoencoder (cVAE) is a generative model architecture designed to decompose the latent representation into orthogonal components: a label-relevant code capturing structured, controllable information, and a label-irrelevant code capturing complementary, unsupervised variation. In the context of "Disentangling the Spatial Structure and Style in Conditional VAE" (Zhang et al., 2019), this is achieved via a dual-headed latent space where one head encodes spatial structure or style associated with class labels, and the other head encodes class-independent factors. Each head is injected into the decoder with dedicated adaptive normalization mechanisms, enabling effective disentanglement of spatial structure and style in image generation.
1. Model Architecture
The multi-head cVAE consists of three major modules: a label-condition mapping network $M$ that generates a label-relevant code $z_y$, an encoder $E_\phi$ that produces a label-irrelevant latent code $z_u$, and a decoder $G_\theta$ conditioned on both $z_y$ and $z_u$ at every upsampling layer.
- Label-Condition Mapping ($M$):
- Input: a one-hot label $y$.
- Architecture: a multi-layer perceptron (3–4 fully-connected layers, width ≈ 512).
- Output: embedding $z_y = M(y)$, which may be shaped as a spatial map $z_y \in \mathbb{R}^{C \times H \times W}$ (if $y$ carries spatial information, e.g., pose) or a vector $z_y \in \mathbb{R}^{d}$ (for categorical labels). The values of $C$, $H$, $W$, and $d$ are set relative to the image resolution.
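- Encoder ($E_\phi$):
- Input: image $x$ (concatenated with label maps if necessary).
- Architecture: strided convolutional blocks downsampling to either a vector $z_u \in \mathbb{R}^{d}$ (style posterior) or a spatial map $z_u \in \mathbb{R}^{C \times H \times W}$ (structure posterior).
- Latent outputs: mean $\mu_\phi(x)$ and standard deviation $\sigma_\phi(x)$ parameterizing $q_\phi(z_u \mid x) = \mathcal{N}\big(\mu_\phi(x), \operatorname{diag}(\sigma_\phi^2(x))\big)$.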
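- Decoder ($G_\theta$):
- Input: the label-relevant code $z_y$ and the label-irrelevant code $z_u$.
- Architecture: upsampling convolutional blocks; at every block, $z_y$ modulates the activations through SPADE and $z_u$ through AdaIN (Section 3).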
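To make the two possible output shapes of the label-condition mapping concrete, here is a minimal NumPy sketch of an MLP that maps one-hot labels to a code that is either a vector or, after a reshape, a spatial map. It assumes ReLU activations and random weights; `init_mlp`, `mapping_network`, and the example sizes (10 classes, a 16×4×4 map) are hypothetical illustrations, not values from the paper. Only the 3-layer depth and width of 512 follow the description above.

```python
import numpy as np

def init_mlp(in_dim, hidden, out_dim, n_layers, rng):
    """Random (weight, bias) pairs for an n_layers fully-connected MLP; illustrative init only."""
    dims = [in_dim] + [hidden] * (n_layers - 1) + [out_dim]
    return [(rng.standard_normal((a, b)) / np.sqrt(a), np.zeros(b))
            for a, b in zip(dims[:-1], dims[1:])]

def mapping_network(y_onehot, params, out_shape):
    """Hypothetical mapping M: one-hot labels (N, K) -> label-relevant code z_y.

    out_shape is (d,) for a vector code or (C, H, W) for a spatial map;
    the last layer's width must equal the product of out_shape.
    """
    h = y_onehot
    for i, (weight, bias) in enumerate(params):
        h = h @ weight + bias
        if i < len(params) - 1:  # ReLU on hidden layers only (assumed activation)
            h = np.maximum(h, 0.0)
    return h.reshape((y_onehot.shape[0],) + tuple(out_shape))

rng = np.random.default_rng(0)
n_classes = 10
y = np.eye(n_classes)[[3, 7]]  # two one-hot labels, shape (2, 10)

vec_params = init_mlp(n_classes, 512, 256, n_layers=3, rng=rng)
z_vec = mapping_network(y, vec_params, (256,))  # vector code, shape (2, 256)

map_params = init_mlp(n_classes, 512, 16 * 4 * 4, n_layers=3, rng=rng)
z_map = mapping_network(y, map_params, (16, 4, 4))  # spatial code, shape (2, 16, 4, 4)
```

Because the mapping has no noise input, the same label always yields the same $z_y$, which is what makes this head deterministic.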
2. Probabilistic Framework
Let $z_y$ denote the deterministic, label-relevant ("label head") code, and $z_u$ the stochastic, label-irrelevant ("uncorrelated head") code.
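- Priors:
- $p(z_u) = \mathcal{N}(0, I)$ (isotropic Gaussian).
- $p(z_y \mid y) = \delta\big(z_y - M(y)\big)$ (deterministic).
- Posteriors:
- $q_\phi(z_u \mid x) = \mathcal{N}\big(\mu_\phi(x), \operatorname{diag}(\sigma_\phi^2(x))\big)$
- $q(z_y \mid y) = \delta\big(z_y - M(y)\big)$
- Sampling:
- $\epsilon \sim \mathcal{N}(0, I)$, $z_u = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon$ (reparameterization)
- $z_y = M(y)$ (deterministic)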
- ELBO Objective:
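$$\log p_\theta(x \mid y) \ge \mathbb{E}_{q_\phi(z_u \mid x)}\big[\log p_\theta(x \mid z_y, z_u)\big] - D_{\mathrm{KL}}\big(q_\phi(z_u \mid x) \,\|\, p(z_u)\big) - D_{\mathrm{KL}}\big(q(z_y \mid y) \,\|\, p(z_y \mid y)\big)$$
Given $q(z_y \mid y) = p(z_y \mid y)$, the last term vanishes. The likelihood term is implemented as an $\ell_1$ or $\ell_2$ image reconstruction loss.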
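The reparameterized sampling and the reduced objective can be sketched as follows. This assumes the encoder outputs a log-variance rather than $\sigma_\phi(x)$ directly (a common convention, not confirmed by the source); `reparameterize`, `kl_to_standard_normal`, and `negative_elbo` are hypothetical helper names. The loss is the negative ELBO, with the $\ell_1$/$\ell_2$ reconstruction term standing in for $-\log p_\theta(x \mid z_y, z_u)$ and the $z_y$ KL term dropped because it is zero.

```python
import numpy as np

def reparameterize(mu, log_var, rng):
    """Sample z_u = mu + sigma * eps with eps ~ N(0, I) and sigma = exp(0.5 * log_var)."""
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * log_var) * eps

def kl_to_standard_normal(mu, log_var):
    """Closed-form KL(N(mu, diag(sigma^2)) || N(0, I)), summed over latent dims per sample."""
    axes = tuple(range(1, mu.ndim))
    return 0.5 * np.sum(mu**2 + np.exp(log_var) - 1.0 - log_var, axis=axes)

def negative_elbo(x, x_rec, mu, log_var, recon="l1"):
    """Batch-mean negative ELBO: l1/l2 reconstruction loss plus the z_u KL term.

    The z_y KL term is omitted because q(z_y | y) = p(z_y | y) makes it zero.
    """
    axes = tuple(range(1, x.ndim))
    diff = x - x_rec
    if recon == "l1":
        rec = np.sum(np.abs(diff), axis=axes)
    else:  # "l2"
        rec = np.sum(diff**2, axis=axes)
    return float(np.mean(rec + kl_to_standard_normal(mu, log_var)))

rng = np.random.default_rng(0)
mu, log_var = np.zeros((4, 8)), np.zeros((4, 8))  # encoder outputs for a batch of 4
z_u = reparameterize(mu, log_var, rng)            # shape (4, 8)
```

Sampling through `mu + sigma * eps` keeps the randomness in `eps`, so gradients can flow to the encoder's mean and variance outputs in an autodiff framework.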
- Adversarial Learning:
- Uses a cGAN-style hinge loss to sharpen outputs. A discriminator $D$ distinguishes between real and generated data, including cases with permuted labels or random $z_u$.
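A minimal sketch of the standard hinge formulation is shown below, operating on precomputed discriminator scores. How the negatives are assembled (decoded samples with random $z_u$, pairs with permuted labels) and the equal weighting across negative types are assumptions; `d_hinge_loss`, `g_hinge_loss`, and `permuted_labels` are hypothetical names.

```python
import numpy as np

def d_hinge_loss(real_scores, negative_scores):
    """Discriminator hinge loss.

    real_scores: D outputs on real (image, matching label) pairs.
    negative_scores: list of arrays, one per negative type (e.g. decoded samples
    with random z_u, pairs with permuted labels); equal weighting is an assumption.
    """
    loss_real = np.mean(np.maximum(0.0, 1.0 - real_scores))
    loss_neg = np.mean([np.mean(np.maximum(0.0, 1.0 + s)) for s in negative_scores])
    return float(loss_real + loss_neg)

def g_hinge_loss(fake_scores):
    """Generator-side hinge loss: raise D's scores on generated samples."""
    return float(-np.mean(fake_scores))

def permuted_labels(y_onehot, rng):
    """Shuffle labels across the batch to form mismatched (image, label) pairs.

    A random permutation can leave some pairs matched; this sketch does not filter them.
    """
    return y_onehot[rng.permutation(len(y_onehot))]
```

The hinge margins stop pushing on samples the discriminator already classifies with a margin of 1, which tends to keep its updates bounded compared with a plain logistic loss.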
3. Adaptive Normalization in Decoding
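At each decoder layer $\ell$, the spatial and style codes modulate the activations via two normalization modules:
- SPADE (label-relevant $z_y$):
- Produces $\gamma^{\ell}_{\mathrm{S}}$ and $\beta^{\ell}_{\mathrm{S}}$ with spatial dimensions, via a small convolutional network that upsamples $z_y$ to the layer resolution $H_\ell \times W_\ell$.
- AdaIN (label-irrelevant $z_u$):
- Produces channel-wise $\gamma^{\ell}_{\mathrm{A}}$ and $\beta^{\ell}_{\mathrm{A}}$ using an MLP applied to $z_u$, broadcast spatially.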
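Given the pre-activation $h^{\ell}$, normalization proceeds as
$$h^{\ell}_{\mathrm{S}} = \gamma^{\ell}_{\mathrm{S}}(z_y) \odot \mathrm{BN}(h^{\ell}) + \beta^{\ell}_{\mathrm{S}}(z_y), \qquad h^{\ell}_{\mathrm{A}} = \gamma^{\ell}_{\mathrm{A}}(z_u) \odot \mathrm{IN}(h^{\ell}) + \beta^{\ell}_{\mathrm{A}}(z_u),$$
where $\mathrm{BN}$ and $\mathrm{IN}$ are parameter-free batch and instance normalization, as in standard SPADE and AdaIN. The outputs $h^{\ell}_{\mathrm{S}}$ and $h^{\ell}_{\mathrm{A}}$ are concatenated along the channel dimension and projected by a convolution back to the original channel count.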
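The fused normalization step can be sketched in NumPy as below. This is a simplified illustration under stated assumptions: the SPADE branch uses nearest-neighbour upsampling followed by a single $1 \times 1$ convolution in place of the small convolutional network, the AdaIN MLP is reduced to one linear layer, and the final projection is assumed to be a $1 \times 1$ convolution. All function names and parameter keys are hypothetical.

```python
import numpy as np

EPS = 1e-5

def batch_norm(h):
    """Parameter-free batch norm: statistics per channel over (N, H, W)."""
    mu = h.mean(axis=(0, 2, 3), keepdims=True)
    var = h.var(axis=(0, 2, 3), keepdims=True)
    return (h - mu) / np.sqrt(var + EPS)

def instance_norm(h):
    """Parameter-free instance norm: statistics per sample and channel over (H, W)."""
    mu = h.mean(axis=(2, 3), keepdims=True)
    var = h.var(axis=(2, 3), keepdims=True)
    return (h - mu) / np.sqrt(var + EPS)

def conv1x1(x, weight, bias):
    """1x1 convolution: x (N, Cin, H, W), weight (Cout, Cin), bias (Cout,) -> (N, Cout, H, W)."""
    return np.einsum("oc,nchw->nohw", weight, x) + bias[None, :, None, None]

def upsample_nearest(z, height, width):
    """Nearest-neighbour upsampling by integer factors to (height, width)."""
    fh, fw = height // z.shape[2], width // z.shape[3]
    return z.repeat(fh, axis=2).repeat(fw, axis=3)

def spade(h, z_y, p):
    """Spatially varying gamma/beta predicted from the label-relevant map z_y."""
    z = upsample_nearest(z_y, h.shape[2], h.shape[3])
    gamma = conv1x1(z, p["Wg"], p["bg"])
    beta = conv1x1(z, p["Wb"], p["bb"])
    return gamma * batch_norm(h) + beta

def adain(h, z_u, p):
    """Channel-wise gamma/beta predicted from the vector z_u, broadcast over (H, W)."""
    gamma = z_u @ p["Wg"] + p["bg"]  # (N, C)
    beta = z_u @ p["Wb"] + p["bb"]   # (N, C)
    return gamma[:, :, None, None] * instance_norm(h) + beta[:, :, None, None]

def fused_norm(h, z_y, z_u, p):
    """Concatenate SPADE and AdaIN outputs on channels, then project back to C channels."""
    both = np.concatenate([spade(h, z_y, p["spade"]), adain(h, z_u, p["adain"])], axis=1)
    return conv1x1(both, p["proj_W"], p["proj_b"])

rng = np.random.default_rng(0)
N, C, H, W, Cy, d = 2, 8, 16, 16, 4, 32
h = rng.standard_normal((N, C, H, W))
z_y = rng.standard_normal((N, Cy, 4, 4))  # spatial label-relevant code
z_u = rng.standard_normal((N, d))         # vector label-irrelevant code
p = {
    "spade": {"Wg": rng.standard_normal((C, Cy)), "bg": np.ones(C),
              "Wb": rng.standard_normal((C, Cy)), "bb": np.zeros(C)},
    "adain": {"Wg": rng.standard_normal((d, C)), "bg": np.ones(C),
              "Wb": rng.standard_normal((d, C)), "bb": np.zeros(C)},
    "proj_W": rng.standard_normal((C, 2 * C)), "proj_b": np.zeros(C),
}
out = fused_norm(h, z_y, z_u, p)  # shape (N, C, H, W)
```

The key contrast is in the shapes: SPADE's `gamma` and `beta` are full `(N, C, H, W)` maps, while AdaIN's are `(N, C)` vectors broadcast over every spatial position.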
4. Implementation Configurations and Ablations
Key practical choices and architectural variants:
| Variant | Label-relevant code ($z_y$) injection | Label-irrelevant code ($z_u$) injection |
|---|---|---|
| S1 | AdaIN | concat-input |
| S2 | SPADE | concat-input |
| S3 | AdaIN | AdaIN |
| S4 | AdaIN | SPADE |
| Proposed | SPADE | AdaIN |
- Both $z_y$ and $z_u$ are dimensioned to 256: when the label-relevant code carries structure, $z_y$ is a 256-channel spatial map (its spatial size depends on the image resolution) and $z_u$ is a 256-dimensional vector; the roles are reversed when the label-relevant code carries style.
- Encoder and decoder convolutional blocks follow a fixed channel progression across resolutions.
- Default datasets: 3D-Chair and FaceScrub.
- Optimizer: Adam. The learning rate and batch size are not fixed in the paper; typical settings are used.
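The ablation table and the 256-dimensional code setting can be captured in a small configuration sketch. `VARIANTS`, `CodeConfig`, and `code_shapes` are hypothetical names, and the spatial size passed to `code_shapes` is a free parameter because the exact value depends on the image resolution.

```python
from dataclasses import dataclass

# (z_y injection, z_u injection) for each variant in the ablation table.
VARIANTS = {
    "S1": ("AdaIN", "concat-input"),
    "S2": ("SPADE", "concat-input"),
    "S3": ("AdaIN", "AdaIN"),
    "S4": ("AdaIN", "SPADE"),
    "Proposed": ("SPADE", "AdaIN"),
}

@dataclass(frozen=True)
class CodeConfig:
    code_dim: int = 256  # both codes are dimensioned to 256

    def code_shapes(self, label_is_structure, h, w):
        """Return (z_y shape, z_u shape).

        The structure code is a (code_dim, h, w) map and the style code a
        (code_dim,) vector; h and w are hypothetical, resolution-dependent sizes.
        """
        spatial = (self.code_dim, h, w)
        vector = (self.code_dim,)
        return (spatial, vector) if label_is_structure else (vector, spatial)

shapes_struct = CodeConfig().code_shapes(label_is_structure=True, h=8, w=8)
```

Keeping the variant table as data makes it straightforward to build each ablation from one decoder definition by looking up which module handles which code.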
5. Quantitative and Qualitative Performance
The proposed disentangling design is evaluated on 3D-Chair (the label encodes azimuth/viewpoint) and FaceScrub (the label encodes identity):
- 3D-Chair:
- Mutual information (lower is better, indicating improved disentanglement).
- ResNet-50 classification accuracy of generated images at the target azimuth.
- FaceScrub:
- Identity-classification accuracy.
- Fréchet Inception Distance (FID; lower is better).
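For reference, FID fits a Gaussian to feature embeddings of real and generated images and computes $\lVert \mu_r - \mu_g \rVert^2 + \operatorname{Tr}\big(\Sigma_r + \Sigma_g - 2(\Sigma_r \Sigma_g)^{1/2}\big)$. The NumPy sketch below implements that formula on precomputed feature matrices; extracting Inception features is outside its scope, and `frechet_distance` is a hypothetical helper name rather than a library API.

```python
import numpy as np

def _sqrtm_psd(a):
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""
    w, v = np.linalg.eigh(a)
    return (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T

def frechet_distance(feats_real, feats_gen):
    """FID between two feature sets (rows = samples, columns = feature dims).

    Uses Tr((S_r S_g)^{1/2}) = Tr((S_r^{1/2} S_g S_r^{1/2})^{1/2}), whose inner
    matrix is symmetric PSD, so only symmetric eigensolvers are needed.
    """
    mu_r, mu_g = feats_real.mean(axis=0), feats_gen.mean(axis=0)
    s_r = np.cov(feats_real, rowvar=False)
    s_g = np.cov(feats_gen, rowvar=False)
    root_r = _sqrtm_psd(s_r)
    middle = root_r @ s_g @ root_r
    eig = np.linalg.eigvalsh((middle + middle.T) / 2)  # symmetrize against round-off
    tr_sqrt = np.sum(np.sqrt(np.clip(eig, 0.0, None)))
    return float(np.sum((mu_r - mu_g) ** 2) + np.trace(s_r) + np.trace(s_g) - 2.0 * tr_sqrt)
```

Identical feature sets give a distance of zero, and shifting every feature by a constant changes only the mean term.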
Qualitative results show that reconstructed or generated samples can enforce a target identity or viewpoint while preserving complementary factors such as style, pose, or expression.
6. Significance and Context
This design cleanly separates label-associated (structured) factors from unsupervised (residual) variation. By applying SPADE and AdaIN at every decoder stage, fed by the label-relevant and label-irrelevant codes respectively, cVAE generation becomes modular, controllable, and suited to tasks that demand disentanglement. The approach enables, for example, faithful identity swapping in faces or view manipulation in 3D objects, where label information may or may not carry spatial meaning. This separation of signal pathways, together with the adversarial sharpness constraint, is shown to outperform simpler approaches in both disentanglement metrics and visual fidelity (Zhang et al., 2019).