miVAE: Multi-Modal Identifiable VAE
- The paper introduces miVAE, which unifies neural activity and visual stimuli by disentangling shared and modality-specific latent spaces.
- It employs a two-level disentanglement strategy with identifiable exponential-family priors and cross-modal losses for zero-shot alignment across subjects.
- Experimental validation shows high cross-modal alignment (Pearson-R up to about 0.96) and effective score-based attribution for both neural and stimulus features.
The multi-modal identifiable variational autoencoder (miVAE) is a generative modeling framework developed to unify and disentangle the latent structure underlying complex visual stimuli and simultaneously recorded neural responses in the primary visual cortex (V1), with explicit accommodation for cross-individual variability and modality-specific structure. By employing a two-level disentanglement strategy coupled with identifiable exponential-family priors, miVAE enables robust cross-modal alignment and interpretable analysis of neural coding, supporting zero-shot generalization across subjects without the need for subject-specific fine-tuning.
1. Model Structure and Generative Assumptions
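miVAE models two observed modalities: neural population activity $\mathbf{x} \in \mathbb{R}^{N \times T}$, recorded from $N$ neurons over $T$ time points, and dynamic visual stimuli $\mathbf{s}$, represented as a sequence of $T$ movie frames. Each modality is processed by its own neural network (encoder), which parameterizes (possibly time-varying) distributions over distinct subspaces of latent variables:
- The neural encoder maps neural activity $\mathbf{x}$ and per-subject neuron coordinates $\mathbf{c}$ to the neural-specific latent $\mathbf{z}_x$, together with its own estimate of the shared latent $\mathbf{z}$.
- The visual encoder maps the visual stimulus $\mathbf{s}$ to the stimulus-specific latent $\mathbf{z}_s$ and the shared latent $\mathbf{z}$.
The generative model is defined as
$$p(\mathbf{x}, \mathbf{s}, \mathbf{z}, \mathbf{z}_x, \mathbf{z}_s \mid \mathbf{c}) = p(\mathbf{x} \mid \mathbf{z}, \mathbf{z}_x)\, p(\mathbf{s} \mid \mathbf{z}, \mathbf{z}_s)\, p(\mathbf{z})\, p(\mathbf{z}_x \mid \mathbf{c})\, p(\mathbf{z}_s),$$
with exponential-family priors of the conditional iVAE form, for example
$$p(\mathbf{z}_x \mid \mathbf{c}) = \prod_{i} \frac{Q_i(z_{x,i})}{Z_i(\mathbf{c})} \exp\Big(\sum_{j} T_{i,j}(z_{x,i})\, \lambda_{i,j}(\mathbf{c})\Big),$$
while $p(\mathbf{z})$ and $p(\mathbf{z}_s)$ are similarly parameterized. Conditional decoders $p(\mathbf{x} \mid \mathbf{z}, \mathbf{z}_x)$ and $p(\mathbf{s} \mid \mathbf{z}, \mathbf{z}_s)$ reconstruct each observed modality from its respective latents.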
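A minimal sketch of an iVAE-style conditional exponential-family prior over the neural-specific latent, assuming each latent component is Gaussian with sufficient statistics $(z, z^2)$ and natural parameters produced from a neuron-coordinate vector `c` by a hypothetical function `coord_to_natural_params`. Both the Gaussian choice and the toy coordinate mapping are illustrative assumptions, not the paper's parameterization.

```python
import math
from typing import Sequence


def coord_to_natural_params(c: Sequence[float], dim: int) -> list[tuple[float, float]]:
    """Hypothetical map from coordinates c to Gaussian natural parameters per latent dim.

    For a Gaussian with mean mu and variance var: eta1 = mu / var, eta2 = -1 / (2 * var).
    Here mu and var are toy deterministic functions of c, purely for illustration.
    """
    params = []
    for i in range(dim):
        mu = 0.1 * (i + 1) * sum(c)
        var = math.exp(0.05 * i)
        params.append((mu / var, -1.0 / (2.0 * var)))
    return params


def expfam_log_prior(z: Sequence[float], c: Sequence[float]) -> float:
    """log p(z | c) = sum_i [eta1_i * z_i + eta2_i * z_i**2 - A(eta_i)] over factorized components."""
    total = 0.0
    for z_i, (eta1, eta2) in zip(z, coord_to_natural_params(c, len(z))):
        # Gaussian log-partition in natural parameters, with the base measure 1/sqrt(2*pi) folded in
        log_partition = -eta1**2 / (4.0 * eta2) - 0.5 * math.log(-2.0 * eta2) + 0.5 * math.log(2.0 * math.pi)
        total += eta1 * z_i + eta2 * z_i**2 - log_partition
    return total
```

Changing `c` shifts the natural parameters and therefore the prior, which is what makes the neural-specific prior coordinate-dependent; the product over components mirrors the factorized form that iVAE identifiability arguments rely on.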
2. Two-Level Disentanglement and Identifiability
The latent space is partitioned into a shared component ($\mathbf{z}$), invariant across modalities and individuals, and modality-specific subspaces: neural-specific ($\mathbf{z}_x$) and stimulus-specific ($\mathbf{z}_s$). The neural-specific latent $\mathbf{z}_x$ captures idiosyncratic neuroanatomical and functional properties of individual subjects (via coordinate-dependent priors), while $\mathbf{z}_s$ captures features of the stimulus uncorrelated with neural responses. The shared latent $\mathbf{z}$ encodes the stimulus-driven features common across individuals and modalities.
Identifiability is achieved under the identifiable VAE (iVAE) framework by enforcing:
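- Conditional independence: given the latents, neural activity $\mathbf{x}$ and stimulus $\mathbf{s}$ are independent, $p(\mathbf{x}, \mathbf{s} \mid \mathbf{z}, \mathbf{z}_x, \mathbf{z}_s) = p(\mathbf{x} \mid \mathbf{z}, \mathbf{z}_x)\, p(\mathbf{s} \mid \mathbf{z}, \mathbf{z}_s)$, and the latent components are independent under their (conditional) priors
- Sufficiently rich exponential-family priors for the latent subspaces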
- KL regularization to enforce proximity of inferred posteriors to their respective conditional priors
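The multi-modal loss includes a KL-divergence penalty for each latent and the negative expected log-likelihoods (reconstruction terms). For example, the neural-specific penalty is
$$\mathrm{KL}\big(q(\mathbf{z}_x \mid \mathbf{x}, \mathbf{c}) \,\|\, p(\mathbf{z}_x \mid \mathbf{c})\big).$$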
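Assuming both the posterior and the conditional prior are diagonal Gaussians (a common VAE choice, not confirmed by the text), each KL penalty, such as the neural-specific one, has a closed form. A minimal helper:

```python
import math
from typing import Sequence


def kl_diag_gauss(mu_q: Sequence[float], logvar_q: Sequence[float],
                  mu_p: Sequence[float], logvar_p: Sequence[float]) -> float:
    """KL(q || p) between diagonal Gaussians, summed over latent dimensions."""
    kl = 0.0
    for mq, lq, mp, lp in zip(mu_q, logvar_q, mu_p, logvar_p):
        var_q, var_p = math.exp(lq), math.exp(lp)
        # Per-dimension closed form: 0.5 * [log(var_p / var_q) + (var_q + (mq - mp)^2) / var_p - 1]
        kl += 0.5 * (lp - lq + (var_q + (mq - mp) ** 2) / var_p - 1.0)
    return kl


# Posterior N(0, 1) against a coordinate-conditioned prior N(1, 1) in one dimension
print(kl_diag_gauss([0.0], [0.0], [1.0], [0.0]))  # 0.5
```

Working with log-variances keeps the variances positive without explicit constraints, which is why encoders typically output them.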
3. Variational Objective and Inference Procedures
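Variational inference in miVAE uses factorized (mean-field) variational distributions over the latents, conditioned on the observed modalities. The evidence lower bound (ELBO) combines the reconstruction and KL terms:
$$\mathcal{L}_{\text{ELBO}} = \mathbb{E}_{q}\big[\log p(\mathbf{x} \mid \mathbf{z}, \mathbf{z}_x) + \log p(\mathbf{s} \mid \mathbf{z}, \mathbf{z}_s)\big] - \mathrm{KL}\big(q(\mathbf{z}, \mathbf{z}_x, \mathbf{z}_s \mid \mathbf{x}, \mathbf{s}, \mathbf{c}) \,\|\, p(\mathbf{z}, \mathbf{z}_x, \mathbf{z}_s \mid \mathbf{c})\big),$$
where the joint prior factorizes as $p(\mathbf{z}, \mathbf{z}_x, \mathbf{z}_s \mid \mathbf{c}) = p(\mathbf{z})\, p(\mathbf{z}_x \mid \mathbf{c})\, p(\mathbf{z}_s)$.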
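A single-sample Monte Carlo estimate of this ELBO can be sketched under stated assumptions: diagonal Gaussian posteriors and priors passed as (mean, log-variance) pairs keyed `"z"`, `"zx"`, `"zs"`, unit-variance Gaussian decoders, and caller-supplied decoder functions `dec_x` and `dec_s`. All of these names are hypothetical stand-ins for the paper's networks.

```python
import math
import random
from typing import Callable, Sequence

Gauss = tuple[Sequence[float], Sequence[float]]  # (mean, log-variance)
Decoder = Callable[[list[float], list[float]], list[float]]


def reparameterize(mu: Sequence[float], logvar: Sequence[float], rng: random.Random) -> list[float]:
    # z = mu + sigma * eps, eps ~ N(0, 1): the sample stays a differentiable function of (mu, logvar)
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


def gauss_loglik(obs: Sequence[float], mean: Sequence[float], var: float = 1.0) -> float:
    # Unit-variance Gaussian likelihood by default (an assumption for this sketch)
    return sum(-0.5 * math.log(2 * math.pi * var) - (o - m) ** 2 / (2 * var) for o, m in zip(obs, mean))


def kl_diag_gauss(mu_q, logvar_q, mu_p, logvar_p) -> float:
    return sum(0.5 * (lp - lq + (math.exp(lq) + (mq - mp) ** 2) / math.exp(lp) - 1.0)
               for mq, lq, mp, lp in zip(mu_q, logvar_q, mu_p, logvar_p))


def elbo_estimate(x: Sequence[float], s: Sequence[float], post: dict[str, Gauss],
                  prior: dict[str, Gauss], dec_x: Decoder, dec_s: Decoder,
                  rng: random.Random) -> float:
    """Single-sample ELBO: E_q[log p(x|z,z_x) + log p(s|z,z_s)] - KL(q || p).

    Because q and p both factorize over ("z", "zx", "zs"), the joint KL equals
    the sum of the three per-subspace KLs.
    """
    z, zx, zs = (reparameterize(*post[k], rng) for k in ("z", "zx", "zs"))
    recon = gauss_loglik(x, dec_x(z, zx)) + gauss_loglik(s, dec_s(z, zs))
    kl = sum(kl_diag_gauss(*post[k], *prior[k]) for k in ("z", "zx", "zs"))
    return recon - kl
```

The mean-field factorization is what lets the joint KL split into per-subspace terms; training would maximize this estimate (or minimize its negative) averaged over a batch.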
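A novel aspect is the cross-modal loss, which uses the shared-latent distribution inferred from one modality as the prior for the posterior inferred from the other, reinforcing alignment in the shared latent space:
$$\mathcal{L}_{\text{cross}} = \mathrm{KL}\big(q(\mathbf{z} \mid \mathbf{x}) \,\|\, q(\mathbf{z} \mid \mathbf{s})\big) + \mathrm{KL}\big(q(\mathbf{z} \mid \mathbf{s}) \,\|\, q(\mathbf{z} \mid \mathbf{x})\big).$$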
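Reading the cross-modal loss as a symmetric KL between diagonal Gaussian shared-latent distributions (one plausible instantiation, assumed here), it can be sketched as follows. The encoder outputs in the example are made-up values.

```python
import math
from typing import Sequence


def kl_diag_gauss(mu_q, logvar_q, mu_p, logvar_p) -> float:
    """KL(q || p) between diagonal Gaussians, summed over dimensions."""
    return sum(0.5 * (lp - lq + (math.exp(lq) + (mq - mp) ** 2) / math.exp(lp) - 1.0)
               for mq, lq, mp, lp in zip(mu_q, logvar_q, mu_p, logvar_p))


def cross_modal_loss(neural: tuple[Sequence[float], Sequence[float]],
                     stimulus: tuple[Sequence[float], Sequence[float]]) -> float:
    """Symmetric KL between the shared-latent Gaussians inferred from each modality.

    Each direction treats one modality's distribution as the prior for the other's
    posterior, so minimizing the sum pulls both encoders toward the same shared code.
    """
    return kl_diag_gauss(*neural, *stimulus) + kl_diag_gauss(*stimulus, *neural)


q_neural = ([0.0, 0.2], [0.0, -0.5])    # (mean, log-variance) from a hypothetical neural encoder
q_stimulus = ([1.0, 0.2], [0.0, -0.5])  # (mean, log-variance) from a hypothetical stimulus encoder
print(cross_modal_loss(q_neural, q_stimulus))  # 1.0
```

Using both KL directions makes the penalty symmetric in the two modalities, so neither encoder is privileged as the fixed target.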
This structure ensures that $\mathbf{z}$ captures the common signal linking neural and stimulus domains, while $\mathbf{z}_x$ and $\mathbf{z}_s$ absorb domain-specific variations.
4. Cross-Individual and Cross-Modal Alignment
A primary aim is zero-shot alignment of shared neural representations ($\mathbf{z}$) across individuals. The model uses subject-specific priors for the neural-specific latent $\mathbf{z}_x$ but conditions $\mathbf{z}$ solely on the observed stimuli (or population activity), which makes the learned manifold transferable across subjects.
After training, the shared-latent distributions from the neural and stimulus encoders are aligned by a minimal affine transformation, $\tilde{\mathbf{z}} = \mathbf{A}\mathbf{z} + \mathbf{b}$, that maps one encoder's shared latents onto the other's, with matrix $\mathbf{A}$ and offset $\mathbf{b}$ fitted to match first and second moments (or optimized with a KL loss). This procedure yields Pearson-R scores of roughly 0.90 to 0.96 for held-out individuals and stimuli.
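A simplified sketch of moment-matched affine alignment, assuming a diagonal $\mathbf{A}$ so that each shared dimension is aligned independently. This is a restriction relative to the full matrices described above; the function names are hypothetical.

```python
import statistics
from typing import Sequence

Matrix = list[list[float]]  # rows = time points, columns = shared-latent dimensions


def fit_diag_affine(src: Matrix, tgt: Matrix) -> tuple[list[float], list[float]]:
    """Per-dimension affine fit tgt ~= a * src + b by matching means and standard deviations."""
    a, b = [], []
    for d in range(len(src[0])):
        xs = [row[d] for row in src]
        ys = [row[d] for row in tgt]
        scale = statistics.pstdev(ys) / statistics.pstdev(xs)
        a.append(scale)
        b.append(statistics.fmean(ys) - scale * statistics.fmean(xs))
    return a, b


def apply_affine(z: Matrix, a: Sequence[float], b: Sequence[float]) -> Matrix:
    """Apply the fitted per-dimension scale and offset to every row."""
    return [[ai * zi + bi for zi, ai, bi in zip(row, a, b)] for row in z]
```

Moment matching recovers scale and offset but not sign flips or rotations between dimensions. A full-matrix version would whiten with one covariance and recolor with the other, and the KL-optimized variant mentioned above would instead be fitted by gradient descent.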
Experimental results show that miVAE achieves Pearson-R scores of 0.8694 (Stage 1 encoding), 0.8809 (Stage 2 linear mapping), 0.9149 (Stage 2 nonlinear coding), and 0.8984–0.9635 (cross-individual alignment). Both the multi-modal and cross-modal losses are necessary, and removing the neural-specific latent sharply impairs performance. A total latent dimensionality of 16 (8 shared, 4 + 4 modality-specific) was found optimal.
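The scores above are Pearson correlations. A minimal sketch of one common way to compute them: correlate each column (a latent dimension or a neuron) over time, then average across columns. The column-wise averaging is an assumption, since the exact aggregation is not specified here.

```python
import math
from typing import Sequence


def pearson_r(a: Sequence[float], b: Sequence[float]) -> float:
    """Pearson correlation between two equal-length sequences."""
    n = len(a)
    ma, mb = sum(a) / n, sum(b) / n
    cov = sum((x - ma) * (y - mb) for x, y in zip(a, b))
    var_a = sum((x - ma) ** 2 for x in a)
    var_b = sum((y - mb) ** 2 for y in b)
    return cov / math.sqrt(var_a * var_b)


def mean_pearson(pred: list[list[float]], target: list[list[float]]) -> float:
    """Average Pearson-R over columns, correlating over rows (time points)."""
    cols = len(pred[0])
    return sum(pearson_r([r[d] for r in pred], [r[d] for r in target]) for d in range(cols)) / cols
```

Pearson-R is invariant to per-column affine rescaling, so it measures whether predicted and target signals co-vary, not whether their absolute scales match.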
5. Score-Based Attribution Mechanism
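miVAE introduces a score-based attribution method that assigns importance weights to the neural units or stimulus features associated with each shared latent dimension. To identify neural contributions to a shared latent dimension $z_k$, the Fisher score
$$\nabla_{\mathbf{x}} \log p(z_k \mid \mathbf{x}) = \nabla_{\mathbf{x}} \log p(\mathbf{x} \mid z_k) - \nabla_{\mathbf{x}} \log p(\mathbf{x})$$
is computed (the prior term $\log p(z_k)$ does not depend on $\mathbf{x}$ and drops out), with the marginal term $\nabla_{\mathbf{x}} \log p(\mathbf{x})$ typically approximated or omitted. The elementwise gradient magnitudes $|\partial \log p(z_k \mid \mathbf{x}) / \partial x_i|$ act as importance scores, partitioning the neural population into functionally distinct subpopulations.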
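A dependency-free sketch of the attribution step: per-neuron gradient magnitudes computed by central finite differences (in practice one would use automatic differentiation). The example substitutes a hypothetical linear-Gaussian encoder $q(z_k \mid \mathbf{x})$ for the posterior, which sidesteps the marginal term entirely; it is a stand-in, not the paper's estimator.

```python
import math
from typing import Callable, Sequence


def score_attribution(log_density: Callable[[list[float]], float],
                      x: Sequence[float], eps: float = 1e-5) -> list[float]:
    """|d log_density / d x_i| for each neuron i, via central finite differences."""
    scores = []
    for i in range(len(x)):
        x_plus, x_minus = list(x), list(x)
        x_plus[i] += eps
        x_minus[i] -= eps
        scores.append(abs((log_density(x_plus) - log_density(x_minus)) / (2 * eps)))
    return scores


def toy_log_posterior(w: Sequence[float], z_k: float, var: float) -> Callable[[list[float]], float]:
    """Hypothetical Gaussian encoder q(z_k | x) = N(z_k; w . x, var), a stand-in for p(z_k | x)."""
    def log_q(x: list[float]) -> float:
        mean = sum(wi * xi for wi, xi in zip(w, x))
        return -0.5 * math.log(2 * math.pi * var) - (z_k - mean) ** 2 / (2 * var)
    return log_q


# Analytic gradient is (z_k - w . x) * w / var, so a zero-weight neuron gets zero importance
log_q = toy_log_posterior(w=[2.0, -0.5, 0.0], z_k=1.0, var=0.5)
print(score_attribution(log_q, [0.1, 0.4, -0.3]))  # approximately [4.0, 1.0, 0.0]
```

Taking magnitudes discards the sign, so a neuron whose activity pushes $z_k$ down counts as important as one that pushes it up.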
On the stimulus side, analogous gradients with respect to the input, $\nabla_{\mathbf{s}} \log p(z_k \mid \mathbf{s})$, highlight the spatial and temporal regions of the visual input that drive the shared code. Empirically, the most "important" neuron subset (approximately 700 units) achieves higher stimulus-classification accuracy (91.29%) than the non-selected group (82.92%) or the full population (87.24%), indicating that attribution isolates a more discriminative subpopulation. Attribution maps in the stimulus domain preferentially highlight edge- and luminance-defined regions of the movies.
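The split into an attribution-selected subset and its complement can be sketched as below. This assumes neurons are ranked by their mean absolute attribution across samples; the ranking criterion is an assumption, and `k` would be about 700 for the subset described above.

```python
from typing import Sequence


def select_top_neurons(score_rows: Sequence[Sequence[float]], k: int) -> tuple[list[int], list[int]]:
    """Split neuron indices into the top-k by mean |attribution| and the complement."""
    n_neurons = len(score_rows[0])
    mean_scores = [sum(abs(row[i]) for row in score_rows) / len(score_rows) for i in range(n_neurons)]
    ranked = sorted(range(n_neurons), key=lambda i: mean_scores[i], reverse=True)
    return sorted(ranked[:k]), sorted(ranked[k:])


# Two samples x four neurons; mean scores are approximately [0.2, 0.8, 0.5, 0.65]
selected, complement = select_top_neurons([[0.1, 0.9, 0.5, 0.7], [0.3, 0.7, 0.5, 0.6]], k=2)
print(selected, complement)  # [1, 3] [0, 2]
```

The two index lists can then feed separate downstream classifiers, which is the kind of comparison behind the 91.29% versus 82.92% result.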
6. Experimental Validation and Quantitative Findings
The evaluation protocol used the Sensorium 2023 V1 dataset (10 mice, ~78,000 neurons, 30 Hz, 36×64 movies), with training on data from 7 mice and testing on 3 held-out animals for zero-shot transfer assessment. Calcium traces were preprocessed into deconvolved spike rates and temporally aligned to video frames.
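If neural samples and movie frames carry separate timestamps, the alignment step can be sketched as linear interpolation onto frame times. This is an illustrative assumption: the exact alignment procedure, and the deconvolution that precedes it, are not specified here, and `align_to_frames` is a hypothetical name.

```python
from bisect import bisect_right
from typing import Sequence


def align_to_frames(t_neural: Sequence[float], values: Sequence[float],
                    t_frames: Sequence[float]) -> list[float]:
    """Linearly interpolate a trace sampled at sorted times t_neural onto frame times.

    Frame times outside the recorded range are clamped to the first/last sample.
    """
    out = []
    for t in t_frames:
        j = bisect_right(t_neural, t)
        if j == 0:
            out.append(values[0])
        elif j == len(t_neural):
            out.append(values[-1])
        else:
            t0, t1 = t_neural[j - 1], t_neural[j]
            w = (t - t0) / (t1 - t0)
            out.append(values[j - 1] + w * (values[j] - values[j - 1]))
    return out
```

Because `bisect_right` guarantees `t0 <= t < t1` in the interior branch, the denominator is always positive for strictly increasing timestamps.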
Training used AdamW (batch size 32, learning rate …, 400 epochs, cosine annealing) on eight A100 GPUs. Key outcomes include:
| Task/Stage | Pearson-R / accuracy (higher is better) | Notes |
|---|---|---|
| Stage 1 encoding | 0.8694 | miVAE neural decoding |
| Stage 2 latent coding | 0.8809 (linear), 0.9149 (nonlinear) | alignment transforms |
| Cross-individual | 0.8984 (Stage 1), 0.9635 (Stage 2, nonlinear) | transfer |
| Neuron selection accuracy | 91.29% (attribution-selected), 82.92% (complement), 87.24% (all) | stimulus classification |
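The cosine-annealing schedule listed in the training setup follows a standard closed form (the one documented for PyTorch's `CosineAnnealingLR`), sketched below per epoch over the 400 training epochs. The base learning rate is a placeholder parameter, and annealing to `min_lr = 0` is an assumption.

```python
import math


def cosine_annealing_lr(epoch: int, base_lr: float, total_epochs: int = 400, min_lr: float = 0.0) -> float:
    """Cosine-annealed learning rate: decays from base_lr at epoch 0 to min_lr at total_epochs."""
    progress = min(epoch, total_epochs) / total_epochs
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# base_lr is a placeholder; the value used for miVAE is not given here
schedule = [cosine_annealing_lr(e, base_lr=1e-3) for e in (0, 200, 400)]
```

The schedule stays near `base_lr` early, halves at the midpoint, and flattens toward `min_lr`, which tends to give smoother convergence than step decay.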
Ablation confirmed necessity of both the cross-modal and multi-modal losses, and of the neural-specific latent path. Larger dataset size (more mice, more trials, smaller neuron subsets) improved performance monotonically.
7. Interpretability and Broader Applicability
The disentangled shared latent () in miVAE encodes reproducible, stimulus-specific features robustly invariant to individual heterogeneity. Score-based attribution yields interpretable, distinct neural subpopulations with differing temporal dynamics and high stimulus discrimination capacity, while stimulus-side attribution emphasizes regions relevant for primary visual processing (e.g., those sensitive to spatial edges or luminance).
miVAE’s structure, which leverages modality-agnostic identifiable priors and bidirectional multi-modal/cross-modal losses, is not tied to V1. It could be adapted to other sensory cortices (auditory, somatosensory) and to integrative models combining behavioral and neural data in decision-making contexts. Each such adaptation requires defining appropriate domain-specific priors for the measurement setup, while the shared latent is still extracted via the cross-modal variational objective.
In summary, miVAE synthesizes two-level disentanglement, identifiable exponential-family probabilistic structure, advanced cross-modal variational learning, and attribution-based interpretability, providing a scalable and generalizable framework for the neurocomputational modeling of sensory representations and their individual-specific manifestations across large subject cohorts.