Wan-VAE: Compact 3D Causal Video VAE

Updated 4 March 2026
  • The paper introduces Wan-VAE, a compact 3D causal VAE that achieves high-fidelity video reconstruction by leveraging a three-term loss and strict temporal causality.
  • It relies on a feature-cache mechanism, RMSNorm-based causal convolutions, and 2D-to-3D weight inflation for efficient training and inference.
  • The model serves as the video representation backbone of the Wan foundation models. Its 2D-to-3D initialization accelerates convergence, and its open-source release supports reproducible industrial and academic work.

Wan-VAE is the compact, highly efficient 3D causal variational autoencoder (VAE) at the core of Wan, a suite of video foundation models. It provides state-of-the-art video representation and reconstruction and integrates directly with large-scale diffusion-transformer backbones. Key innovations include strict temporal causality, a feature-cache mechanism for efficient streaming, a three-term training objective, and scalable pretraining via “inflate-and-fine-tune” from 2D to 3D. With 127 million parameters, Wan-VAE leads speed and fidelity benchmarks for high-resolution, long-clip video reconstruction, and its full open-source release supports both industrial and academic applications (Wan et al., 26 Mar 2025).

1. Architectural Design

Wan-VAE employs an encoder–decoder topology specifically optimized for high-fidelity video compression and efficient integration with diffusion-transformer (DiT) models. The core architectural elements are as follows:

  • Input/Output Representation: Video tensor $V \in \mathbb{R}^{(1+T) \times H \times W \times 3}$, supporting variable spatial ($H \times W$) and temporal ($1+T$) dimensions.
  • Latent Space: $z \in \mathbb{R}^{(1+T/4) \times (H/8) \times (W/8) \times 16}$, i.e., 4× temporal compression of all frames after the first (which is encoded on its own) and 8×8 spatial compression. A channel count of $C=16$ keeps the latent compact while retaining enough capacity for high-fidelity reconstruction.
  • Encoder: A hierarchy of 3D causal convolutional residual blocks:
    • 3×3×3 causal convolutions prevent information “leakage” from future frames.
    • RMSNorm replaces GroupNorm, supporting strict autoregressivity and the feature‐cache.
    • SiLU activations.
    • Downsampling: Strided convolutions downsample by 2× spatially (and, in selected blocks, temporally), giving cumulative factors of 4× (temporal) and 8×8 (spatial) across four stages.
  • Decoder: A mirror-symmetric stack of 3D transposed convolutions (or upsample + conv) that reverses the compression, with residual connections inside each block. The decoder receives only the latent and does not draw on encoder features, since during generation the latents come from the DiT rather than from an encoder pass.
  • Diffusion Transformer Interface: The latent $x = z$ is patchified by a 3D convolution with kernel (1, 2, 2) and flattened into a token sequence of length $L = (1+T/4) \cdot (H/16) \cdot (W/16)$ for the DiT. The DiT output is unpatchified and passed through the Wan-VAE decoder back to the pixel grid.
| Component | Structure | Key Features |
|---|---|---|
| Encoder | 3D causal conv residual blocks | RMSNorm, SiLU, strided downsampling |
| Decoder | 3D transposed conv or upsample + conv | Causal, RMSNorm, skip connections |
| Latent space | $(1+T/4)\times(H/8)\times(W/8)\times16$ | Small channel count; 4× temporal / 8×8 spatial compression |
| Parameters | ~127 million | Compact architecture |
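The shape bookkeeping above can be checked with simple arithmetic. The Python sketch below is illustrative only: the function names are hypothetical, not taken from the Wan codebase, and it assumes $T$ is a multiple of 4 and that $H$ and $W$ are multiples of 16, so both the 8×8 VAE compression and the (1, 2, 2) patchify divide evenly.

```python
from typing import Tuple


def wan_vae_latent_shape(num_frames: int, height: int, width: int,
                         channels: int = 16) -> Tuple[int, int, int, int]:
    """Latent shape (frames, h, w, c) for a video with num_frames = 1 + T frames."""
    t = num_frames - 1
    if num_frames < 1 or t % 4 != 0:
        raise ValueError("num_frames must have the form 1 + 4k")
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError("height and width must be divisible by 8")
    # The first frame gets its own latent frame; the remaining T frames are compressed 4x.
    return (1 + t // 4, height // 8, width // 8, channels)


def dit_sequence_length(num_frames: int, height: int, width: int) -> int:
    """Number of DiT tokens after patchifying the latent with a (1, 2, 2) kernel."""
    f, h, w, _ = wan_vae_latent_shape(num_frames, height, width)
    if h % 2 != 0 or w % 2 != 0:
        raise ValueError("height and width must be divisible by 16")
    # 8x8 VAE compression times a 2x2 patch = one token per 16x16 pixel block per latent frame.
    return f * (h // 2) * (w // 2)


print(wan_vae_latent_shape(25, 720, 720))  # (7, 90, 90, 16)
print(dit_sequence_length(25, 720, 720))   # 14175
```

For the 720×720, 25-frame clips used in fine-tuning, this gives 7 latent frames and 14,175 DiT tokens per clip; a single image (1 frame) maps to one latent frame.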

2. Mathematical Formulation and Objective

Training employs a β-VAE style objective, augmented with perceptual and adversarial losses to enhance perceptual quality and realism. The detailed loss formulation:

  • Posterior: $q_{\varphi}(z \mid x_1) = \mathcal{N}(\mu_{\varphi}(x_1), \sigma^2_{\varphi}(x_1))$, where $x_1$ is the input sample and $\hat{x}$ denotes its reconstruction
  • Prior: $p(z) = \mathcal{N}(0, I)$
  • Decoder: $p_{\theta}(x_1 \mid z)$ is either Laplace (giving an L1 reconstruction loss) or Gaussian (giving an L2 loss)
  • Negative ELBO per sample (the quantity minimized):

$$\mathcal{L}_{\mathrm{ELBO}}(\varphi, \theta; x_1) = \mathbb{E}_{q_{\varphi}(z \mid x_1)}\big[-\log p_{\theta}(x_1 \mid z)\big] + \beta \cdot \mathrm{KL}\big[q_{\varphi}(z \mid x_1) \,\|\, p(z)\big]$$

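Taking a unit-scale Laplace decoder, $-\log p_{\theta}(x_1 \mid z) = \|x_1 - \hat{x}\|_1 + \text{const}$, so a single-sample estimate, with $\hat{x}$ the decoding of one draw $z \sim q_{\varphi}(z \mid x_1)$, gives: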

$$\|x_1 - \hat{x}\|_1 + \beta \cdot \mathrm{KL}\big(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, I)\big)$$

  • Three-Term Loss:

$$\mathcal{L}_{\mathrm{VAE}} = \lambda_r \cdot \mathrm{L1}(x_1, \hat{x}) + \lambda_{\mathrm{KL}} \cdot \mathrm{KL}\big(q_{\varphi}(z \mid x_1) \,\|\, p(z)\big) + \lambda_p \cdot \mathrm{LPIPS}(x_1, \hat{x})$$

with $\lambda_r = 3.0$, $\lambda_{\mathrm{KL}} = 3 \times 10^{-6}$, and $\lambda_p = 3.0$.

$$\mathcal{L}_{\mathrm{total}} = \mathcal{L}_{\mathrm{VAE}} + \lambda_G \cdot \mathcal{L}_{\mathrm{GAN}}$$

where $\lambda_G \approx 1 \times 10^{-2}$.
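Because both the posterior and the prior are Gaussian, the KL term has a closed form. The derivation below assumes the posterior factorizes across latent dimensions (a diagonal covariance, as is standard for VAEs). It is a textbook result, not something specific to Wan-VAE.

```latex
% Per latent dimension k, with q = \mathcal{N}(\mu_k, \sigma_k^2) and p = \mathcal{N}(0, 1):
\log q(z_k) - \log p(z_k)
  = -\tfrac{1}{2}\log \sigma_k^2 - \frac{(z_k - \mu_k)^2}{2\sigma_k^2} + \frac{z_k^2}{2}

% Expectation under q, using E_q[(z_k - \mu_k)^2] = \sigma_k^2 and E_q[z_k^2] = \mu_k^2 + \sigma_k^2:
\mathrm{KL}\big(q \,\|\, p\big)
  = -\tfrac{1}{2}\log \sigma_k^2 - \tfrac{1}{2} + \tfrac{1}{2}\left(\mu_k^2 + \sigma_k^2\right)

% Summing over the factorized latent dimensions:
\mathrm{KL}\big(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, I)\big)
  = \frac{1}{2} \sum_k \left(\mu_k^2 + \sigma_k^2 - \log \sigma_k^2 - 1\right)
```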

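The LPIPS term compares deep-network features of the input and its reconstruction, improving perceptual similarity beyond what the pixel-wise L1 term captures. The GAN loss, introduced only in the final fine-tuning stage, uses a 3D discriminator that scores whole video volumes, promoting natural video textures and motion.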
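A minimal sketch of how the three-term objective and the optional adversarial term combine, written in plain Python on flattened toy tensors. The reduction choices (mean over pixels for L1, sum over latent dimensions for KL), the log-variance parameterization, and the function names are assumptions. The LPIPS and GAN values stand in for outputs of external networks (an LPIPS model and a 3D discriminator), which are not implemented here.

```python
import math
from typing import Sequence


def l1_loss(x: Sequence[float], x_hat: Sequence[float]) -> float:
    """Mean absolute error between the input and its reconstruction."""
    return sum(abs(a - b) for a, b in zip(x, x_hat)) / len(x)


def kl_to_standard_normal(mu: Sequence[float], log_var: Sequence[float]) -> float:
    """Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over latent dims; log_var = log sigma^2."""
    return 0.5 * sum(m * m + math.exp(lv) - lv - 1.0 for m, lv in zip(mu, log_var))


def wan_vae_loss(x: Sequence[float], x_hat: Sequence[float],
                 mu: Sequence[float], log_var: Sequence[float],
                 lpips_value: float, gan_value: float = 0.0,
                 lambda_r: float = 3.0, lambda_kl: float = 3e-6,
                 lambda_p: float = 3.0, lambda_g: float = 1e-2) -> float:
    """L_total = L_VAE + lambda_G * L_GAN; gan_value stays 0 outside fine-tuning.

    lpips_value and gan_value are placeholders for the outputs of an LPIPS
    network and a 3D discriminator, which are not implemented here.
    """
    l_vae = (lambda_r * l1_loss(x, x_hat)
             + lambda_kl * kl_to_standard_normal(mu, log_var)
             + lambda_p * lpips_value)
    return l_vae + lambda_g * gan_value


loss = wan_vae_loss(x=[0.5, -0.2], x_hat=[0.4, -0.1],
                    mu=[0.0, 0.0], log_var=[0.0, 0.0], lpips_value=0.05)
print(round(loss, 6))  # 0.45
```

Because the KL term is summed over many latent elements while $\lambda_{\mathrm{KL}}$ is tiny, the choice of reduction directly changes its effective strength, which is why it is flagged as an assumption.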

3. Training Procedure and Data Pipeline

A multi-stage training pipeline progresses from efficient 2D image pretraining to high-fidelity 3D video reconstruction:

  1. Stage 1: Train a 2D variant of the architecture, using only 2D convolutions, as an image VAE on billions of 128×128 images.
  2. Inflation: The 2D VAE weights are “inflated” to 3D by replicating kernels along the temporal axis, following the approach of MAGVIT-v2, to initialize the video encoder and decoder.
  3. Stage 2: Train on short, low-res clips (5-frame, 128×128 videos) using the three-term loss. The small frame count accelerates convergence.
  4. Stage 3 (Fine-Tuning): Use higher-quality data with resolutions up to 720×720, clips up to 25 frames, and introduce the adversarial (GAN) loss.
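The inflation step can be illustrated on a single 3×3 kernel. The sketch below assumes I3D-style replication with a 1/depth rescaling; the text specifies replication but not the scaling, so the rescale is an assumption. With it, the inflated kernel's response to a static clip equals the 2D kernel's response to one frame, which is what makes the pretrained 2D weights a useful starting point. Function names are hypothetical.

```python
from typing import List

Kernel2D = List[List[float]]
Kernel3D = List[Kernel2D]  # indexed [time][row][col]


def inflate_replicate(w2d: Kernel2D, depth: int = 3) -> Kernel3D:
    """Copy a 2D kernel `depth` times along time, scaling by 1/depth (assumed I3D-style rescale)."""
    return [[[v / depth for v in row] for row in w2d] for _ in range(depth)]


def response_2d(w2d: Kernel2D, patch: Kernel2D) -> float:
    """Kernel response at one output position (cross-correlation, as in conv layers)."""
    return sum(w * p for w_row, p_row in zip(w2d, patch) for w, p in zip(w_row, p_row))


def response_3d(w3d: Kernel3D, clip: Kernel3D) -> float:
    """3D kernel response at one position: the sum of its per-frame 2D responses."""
    return sum(response_2d(w_t, x_t) for w_t, x_t in zip(w3d, clip))


sobel = [[1.0, 2.0, 1.0], [0.0, 0.0, 0.0], [-1.0, -2.0, -1.0]]  # toy "pretrained" 2D kernel
frame = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
static_clip = [frame, frame, frame]  # the same frame repeated over the temporal window

w3d = inflate_replicate(sobel)
print(response_2d(sobel, frame))                # -24.0
print(round(response_3d(w3d, static_clip), 6))  # -24.0
```

One trade-off worth noting: if causal padding uses zeros, the very first frame sees only one replicated slice, so its response is one third of the 2D response. Placing the 2D weights only in the current-frame tap (zeros elsewhere) avoids this, at the cost of starting with no temporal mixing.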

Hardware and Throughput:

  • Training: Mixed-precision bf16 on A100 GPU clusters.
  • Inference: The feature cache enables chunked encoding (the first frame alone, then 4-frame chunks) of 720×720, 25-frame videos on a 40 GB GPU, delivering a 2.5× speedup over the HunYuanVideo VAE and encoding each clip in under 0.5 s.
  • Convergence: 200K steps (2D), 100K (3D low-res), 50K (fine-tune high-res).
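The chunking pattern can be written down directly. Assuming, consistent with the latent layout in Section 1, that the first frame is encoded alone and each later chunk holds 4 frames, every chunk yields exactly one latent frame. The helper below (a hypothetical name) returns the frame ranges.

```python
from typing import List, Tuple


def encode_chunks(num_frames: int, chunk: int = 4) -> List[Tuple[int, int]]:
    """Half-open [start, end) frame ranges: the first frame alone, then `chunk`-frame chunks."""
    if num_frames < 1 or (num_frames - 1) % chunk != 0:
        raise ValueError("num_frames must have the form 1 + k * chunk")
    ranges = [(0, 1)]
    for start in range(1, num_frames, chunk):
        ranges.append((start, start + chunk))
    return ranges


schedule = encode_chunks(25)
print(schedule)       # [(0, 1), (1, 5), (5, 9), (9, 13), (13, 17), (17, 21), (21, 25)]
print(len(schedule))  # 7 chunks -> 7 latent frames, matching 1 + T/4 for T = 24
```

Because only one chunk (plus a small cache) needs to be resident at a time, peak activation memory roughly tracks the chunk size rather than the clip length, which is what lets long clips fit on a fixed-size GPU.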

4. Innovations and Comparison to Prior Art

Wan-VAE departs from preceding video VAEs in the following respects:

  • Causal 3D Convolutions with RMSNorm: Together these ensure strict temporal causality, which permits a feature-cache mechanism suitable for streaming videos of unbounded length. For a temporal kernel size of 3, only the features of the two previous frames need caching; for stride-2 downsampling, only one.
  • Compact Latent Dimension: A moderate C = 16 latent channels (prior work uses between 4 and 64) maintains reconstruction quality while keeping computational overhead low.
  • 2D-to-3D Inflation: Initializing the 3D kernels by replicating pretrained 2D image-VAE weights substantially accelerates convergence of video VAE training.
  • Performance: At 720×720 and 25 frames, Wan-VAE achieves the highest PSNR and the best qualitative fidelity (textures, faces, text, high motion) among the open and commercial video VAEs compared. It outperforms HunYuanVideo VAE and CogVideoX VAE in perceptual quality and runs 2.5× faster than HunYuanVideo VAE (Wan et al., 26 Mar 2025).
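The causality argument for RMSNorm can be checked numerically. In the sketch below, features are stored per frame as channel vectors (spatial dimensions omitted), RMSNorm has no learned scale, and GroupNorm is simplified to a single group. Real GroupNorm uses several channel groups but still pools statistics over time within each group, so the conclusion carries over. Function names are hypothetical.

```python
import math
from typing import List

Frames = List[List[float]]  # indexed [time][channel]; spatial dims omitted for brevity


def rms_norm(frames: Frames, eps: float = 1e-6) -> Frames:
    """Scale each frame's channel vector by its own RMS; no statistics cross frames."""
    out = []
    for f in frames:
        rms = math.sqrt(sum(v * v for v in f) / len(f) + eps)
        out.append([v / rms for v in f])
    return out


def group_norm_one_group(frames: Frames, eps: float = 1e-6) -> Frames:
    """Single-group GroupNorm: mean and variance pooled over channels AND time."""
    values = [v for f in frames for v in f]
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [[(v - mean) / math.sqrt(var + eps) for v in f] for f in frames]


clip = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
edited = [[1.0, 2.0], [3.0, 4.0], [50.0, -6.0]]  # only the last (future) frame differs

# RMSNorm: outputs for frames 0 and 1 ignore the future frame.
print(rms_norm(clip)[:2] == rms_norm(edited)[:2])                          # True
# GroupNorm: pooled statistics include the future frame, so earlier outputs change.
print(group_norm_one_group(clip)[:2] == group_norm_one_group(edited)[:2])  # False
```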

5. Feature-Cache Mechanism and Causality

The feature‐cache mechanism leverages the strict causal structure of 3D convolutions, allowing efficient sequential (streaming) encoding and decoding of video with minimal temporal memory footprint:

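  • Streaming Support: For a causal convolution with temporal kernel size 3 (e.g., 3×3×3 or 3×1×1), only the features of the two most recent frames are cached. For stride-2 temporal downsampling, only the most recent frame is needed.
  • Autoregressive Compatibility: GroupNorm computes its statistics over all frames in a clip, so earlier outputs would depend on future frames; RMSNorm normalizes each position over channels only, avoiding this information “smearing”. This preserves strict autoregressivity and enables efficient, deterministic streaming over unbounded video sequences.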
  • Industrial and Research Relevance: This structure considerably reduces memory overhead for deployment, enabling chunked inference on consumer- or server-grade GPUs.
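The cache logic reduces to keeping the last two input frames between chunks. The Python sketch below uses scalar per-frame features and a single temporal kernel of size 3 (channels and spatial taps omitted, weights arbitrary), and assumes zero padding at the start of the sequence. It checks that chunked processing with a two-frame cache reproduces the full-sequence causal convolution exactly. Function names are hypothetical.

```python
from typing import List, Sequence, Tuple


def causal_conv_full(x: Sequence[float], w: Sequence[float]) -> List[float]:
    """Temporal kernel of size 3, causal: zero-pad two frames at the start only."""
    padded = [0.0, 0.0] + list(x)
    return [w[0] * padded[t] + w[1] * padded[t + 1] + w[2] * padded[t + 2]
            for t in range(len(x))]


def causal_conv_chunk(chunk: Sequence[float], w: Sequence[float],
                      cache: List[float]) -> Tuple[List[float], List[float]]:
    """Process one chunk using the two cached input frames; return (outputs, new cache)."""
    window = cache + list(chunk)
    out = [w[0] * window[t] + w[1] * window[t + 1] + w[2] * window[t + 2]
           for t in range(len(chunk))]
    return out, window[-2:]  # the cache keeps only the two most recent input frames


w = [0.25, -0.5, 1.0]                     # arbitrary toy weights
video = [float(i * i) for i in range(9)]  # 9 frames of toy per-frame features
chunks = [video[0:1]] + [video[i:i + 4] for i in range(1, 9, 4)]  # 1 + 4 + 4 frames

cache = [0.0, 0.0]                        # matches the zero causal padding
streamed: List[float] = []
for c in chunks:
    out, cache = causal_conv_chunk(c, w, cache)
    streamed.extend(out)

print(streamed == causal_conv_full(video, w))  # True
```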

6. Implementation Details and Practical Considerations

The following high-level pseudocode covers the inflate-from-2D initialization, the stagewise optimization, and the adversarial fine-tuning stage:

initialize 2D-VAE encoder ϕ_2D, decoder θ_2D
train ϕ_2D, θ_2D on images with L = 3·L1 + 3e-6·KL + 3·LPIPS

ϕ ← inflate3D(ϕ_2D);  θ ← inflate3D(θ_2D)
initialize 3D discriminator D                 # used only in the fine-tuning stage
for stage in [low_res_short_clips, high_res_fine_tune]:
    for each batch of video V ∈ ℝ^{B×(1+T)×H×W×3}:
        μ, σ = encoder_ϕ(V)                    # posterior q_ϕ(z|V) = N(μ, σ²)
        z = μ + σ ⊙ ε,  ε ~ N(0, I)            # reparameterized sample
        V̂ = decoder_θ(z)                       # reconstruction
        rec_loss = L1(V, V̂)                    # pixel L1
        lp_loss  = LPIPS(V, V̂)                 # perceptual
        kl_loss  = KL(N(μ, σ²) || N(0, I))
        if stage == high_res_fine_tune:
            gan_loss = generator_loss(D(V̂))    # adversarial term for the VAE
        else:
            gan_loss = 0
        total_loss = 3·rec_loss + 3·lp_loss + 3e-6·kl_loss + λ_G·gan_loss
        backpropagate(total_loss)
        update(ϕ, θ) via AdamW(lr=1e-4, wd=1e-3)
        if stage == high_res_fine_tune:
            update(D) on discriminator_loss(D(V), D(stop_grad(V̂)))

  • Hyperparameters: AdamW optimizer (learning rate $10^{-4}$, weight decay $10^{-3}$), mixed-precision bf16, and per-GPU batch sizes of 512 (images), 64 (short videos), and 16 (long high-res videos).
  • Convergence Profile: Stage 1: ~200K steps; Stage 2: ~100K; Stage 3: ~50K.
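The stage schedule can also be captured as a small configuration. The dataclass layout below is hypothetical; the values are transcribed from this section and Section 3, with the fine-tuning stage recording the maximum resolution and clip length (training there uses sizes up to these limits). The check encodes the shape constraints from Section 1.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class StageConfig:
    name: str
    frames: int         # frames per sample; 1 means image-only training
    resolution: int     # square side in pixels (a maximum for the fine-tuning stage)
    batch_per_gpu: int
    approx_steps: int
    use_gan: bool


# Values transcribed from the text; field names are illustrative, not from the Wan code.
STAGES = (
    StageConfig("2d_image_vae", 1, 128, 512, 200_000, False),
    StageConfig("3d_low_res", 5, 128, 64, 100_000, False),
    StageConfig("3d_high_res_fine_tune", 25, 720, 16, 50_000, True),
)
OPTIMIZER = {"name": "AdamW", "lr": 1e-4, "weight_decay": 1e-3, "precision": "bf16"}


def check_stage(stage: StageConfig) -> None:
    """Enforce the 1 + 4k frame count and the 8x8 spatial downsampling constraint."""
    if stage.frames < 1 or (stage.frames - 1) % 4 != 0:
        raise ValueError(f"{stage.name}: frames must be 1 + 4k")
    if stage.resolution % 8 != 0:
        raise ValueError(f"{stage.name}: resolution must be divisible by 8")


for s in STAGES:
    check_stage(s)
print([s.name for s in STAGES if s.use_gan])  # ['3d_high_res_fine_tune']
```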

7. Role within the Wan Video Foundation Model Suite

Wan-VAE forms the video representation backbone of the broader Wan architecture. Its compressed latents are designed for direct manipulation by the diffusion-transformer (DiT) model, enabling efficient scaling to model sizes of up to 14B parameters and robust handling of diverse video generation tasks. Beyond basic video synthesis, the combined system covers image-to-video, video editing, and personal video generation, among eight downstream tasks. Full open-sourcing of the models and code aims to foster reproducibility and further advances in both industry and academia (Wan et al., 26 Mar 2025).

References (1)
