Inferring stochastic low-rank recurrent neural networks from neural data (2406.16749v3)

Published 24 Jun 2024 in cs.LG, q-bio.NC, and stat.ML

Abstract: A central aim in computational neuroscience is to relate the activity of large populations of neurons to an underlying dynamical system. Models of these neural dynamics should ideally be both interpretable and fit the observed data well. Low-rank recurrent neural networks (RNNs) exhibit such interpretability by having tractable dynamics. However, it is unclear how to best fit low-rank RNNs to data consisting of noisy observations of an underlying stochastic system. Here, we propose to fit stochastic low-rank RNNs with variational sequential Monte Carlo methods. We validate our method on several datasets consisting of both continuous and spiking neural data, where we obtain lower dimensional latent dynamics than current state of the art methods. Additionally, for low-rank models with piecewise linear nonlinearities, we show how to efficiently identify all fixed points in polynomial rather than exponential cost in the number of units, making analysis of the inferred dynamics tractable for large RNNs. Our method both elucidates the dynamical systems underlying experimental recordings and provides a generative model whose trajectories match observed variability.

Citations (4)

Summary

  • The paper leverages variational sequential Monte Carlo methods to infer low-dimensional dynamics from noisy neural data via low-rank RNNs.
  • It demonstrates robust performance across synthetic, EEG, and spiking datasets by accurately recovering underlying neural dynamics and latent oscillations.
  • The study provides an efficient fixed-point-finding algorithm for piecewise linear low-rank RNNs, reducing the cost from exponential to polynomial in the number of units and enabling stability analysis of the inferred dynamics.

Inferring Stochastic Low-Rank Recurrent Neural Networks from Neural Data

The paper "Inferring stochastic low-rank recurrent neural networks from neural data" introduces a novel approach for fitting and analyzing low-rank recurrent neural networks (RNNs) to noisy, high-dimensional neural data. The proposed method leverages variational sequential Monte Carlo (SMC) methods to infer the underlying low-dimensional dynamics, making it feasible to handle large networks and substantial trial-to-trial variability. The authors validate their approach using several synthetic and real-world datasets, demonstrating both the mathematical rigor and practical utility of their method.

Theoretical Foundations and Methods

Low-Rank RNNs and Their Dynamics

Low-rank RNNs are a subset of RNNs characterized by the low-rank structure of their connectivity matrix, which allows a direct mapping from high-dimensional neural activity to a low-dimensional latent space. Mathematically, the low-rank structure is exploited to rewrite the high-dimensional dynamics in terms of a few latent variables, significantly reducing the computational complexity. The dynamics of such a low-rank RNN are governed by

$$\tau \frac{d\mathbf{x}}{dt} = -\mathbf{x}(t) + \mathbf{J}\,\phi(\mathbf{x}(t)) + \Gamma_{\mathbf{x}}\,\xi(t),$$

where $\mathbf{J}$ is low-rank, i.e., $\mathbf{J} = \mathbf{M}\mathbf{N}^\mathsf{T}$ with $\mathbf{M}, \mathbf{N} \in \mathbb{R}^{N \times R}$ and $R \ll N$, and $\xi(t)$ is a white-noise process. By assuming an initial condition in the subspace spanned by the columns of $\mathbf{M}$, the dynamics can be projected onto an $R$-dimensional latent space.
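To make the projection concrete, here is a minimal sketch (not the paper's code; the tanh nonlinearity, parameter values, and the choice to confine noise to the column space of $\mathbf{M}$ are illustrative assumptions). It simulates the full $N$-dimensional network with Euler-Maruyama steps and verifies that, when the state starts in span($\mathbf{M}$), the trajectory coincides with an $R$-dimensional latent trajectory:

```python
import numpy as np

rng = np.random.default_rng(0)
n_units, rank = 200, 2          # network size N and rank R (illustrative)
tau, dt, sig = 0.1, 1e-3, 0.05  # time constant, step size, latent noise scale

# Hypothetical low-rank connectivity J = M N^T.
M = np.linalg.qr(rng.normal(size=(n_units, rank)))[0]   # orthonormal columns
N = rng.normal(size=(n_units, rank)) / np.sqrt(n_units)

kappa = rng.normal(size=rank)   # latent state; full state is x = M @ kappa
x = M @ kappa
for _ in range(1000):
    noise = np.sqrt(dt) * rng.normal(size=rank)
    # Full N-dimensional Euler-Maruyama update, noise confined to span(M).
    x = x + (dt / tau) * (-x + M @ (N.T @ np.tanh(x))) + sig * (M @ noise)
    # Equivalent R-dimensional latent update (same noise realization).
    kappa = kappa + (dt / tau) * (-kappa + N.T @ np.tanh(M @ kappa)) + sig * noise

# The two trajectories agree to machine precision: x never leaves span(M).
print(np.max(np.abs(x - M @ kappa)))
```

The equivalence holds because every term of the full update lies in the column space of $\mathbf{M}$, so the $N$-dimensional state is fully described by the $R$ latent coordinates.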

Sequential Monte Carlo for Stochastic RNNs

Fitting stochastic RNNs to neural data is performed with variational sequential Monte Carlo, which approximates the posterior over latent trajectories with a set of weighted particles that are iteratively updated through a resample-propose-reweight cycle. The choice of proposal distribution is crucial, trading off computational cost against inference accuracy; for nonlinear observations, the authors optimize an encoding distribution parameterized by a causal convolutional neural network.
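As a concrete illustration of the resample-propose-reweight cycle, the sketch below implements a bootstrap particle filter for a latent stochastic transition model with linear-Gaussian observations. It uses the prior transition as the proposal rather than the learned encoder of the paper, and all names and parameters are illustrative:

```python
import numpy as np

def smc_log_marginal(y, f, C, sig_z, sig_y, n_particles=200, rng=None):
    """Unbiased SMC estimate of log p(y_{1:T}) for the model
        z_t = f(z_{t-1}) + sig_z * eps_t,   y_t = C @ z_t + sig_y * nu_t,
    using the prior transition as the proposal (bootstrap filter).
    f must be vectorized over a (n_particles, dim_z) array of states.
    """
    rng = np.random.default_rng() if rng is None else rng
    T, n_obs = y.shape
    dim_z = C.shape[1]
    z = sig_z * rng.normal(size=(n_particles, dim_z))   # particles from p(z_0)
    log_lik = 0.0
    for t in range(T):
        # Reweight: Gaussian log-density of y_t under each particle.
        resid = y[t] - z @ C.T
        logw = (-0.5 * np.sum((resid / sig_y) ** 2, axis=1)
                - 0.5 * n_obs * np.log(2 * np.pi * sig_y**2))
        # Accumulate the log-marginal-likelihood estimate (log-mean-exp).
        m = logw.max()
        log_lik += m + np.log(np.mean(np.exp(logw - m)))
        # Resample particles in proportion to their weights.
        w = np.exp(logw - m)
        idx = rng.choice(n_particles, size=n_particles, p=w / w.sum())
        # Propose the next states from the stochastic transition.
        z = f(z[idx]) + sig_z * rng.normal(size=(n_particles, dim_z))
    return log_lik
```

In the low-rank setting, `f` would be one Euler step of the latent dynamics above, e.g. `f = lambda z: z + (dt / tau) * (-z + np.tanh(z @ M.T) @ N)`, applied row-wise to the particle array.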

Empirical Validation

Teacher-Student Setups

The authors validate their approach through various teacher-student setups. By generating synthetic datasets from RNNs trained to emit oscillatory or fixed-point dynamics, the method is shown to recover the underlying systems accurately. For both continuous and Poisson-distributed spiking data, the inferred RNNs match the dynamics of the teacher networks, as evidenced by consistent autocorrelation functions and interspike interval distributions.
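Such diagnostics can be computed directly from sampled trajectories; the helper below (a hypothetical utility, not the paper's code) returns the normalized autocorrelation used to compare teacher and student traces:

```python
import numpy as np

def autocorr(x, max_lag):
    """Normalized autocorrelation of a 1-d trace up to max_lag."""
    x = np.asarray(x, dtype=float) - np.mean(x)
    c = np.correlate(x, x, mode="full")[len(x) - 1:]   # lags 0, 1, 2, ...
    return c[:max_lag] / c[0]

# e.g. worst-case gap between teacher and student statistics:
# acf_gap = np.abs(autocorr(teacher_trace, 200) - autocorr(student_trace, 200)).max()
```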

EEG Data

When applied to EEG data, the method outperforms state-of-the-art approaches in terms of dimensionality reduction. The EEG data, characterized by high trial-to-trial variability, is effectively modeled by stochastic transitions within the RNN framework. With only three latent dimensions, the proposed method achieves reconstruction accuracy comparable to that of models using deterministic transitions and higher-dimensional latent spaces.

Hippocampal Spiking Data

The method is further applied to spiking data from the rat hippocampus, demonstrating its utility in a fundamentally different neural context. The inferred RNNs not only reproduce the spiking statistics at the single-neuron and population levels but also reveal latent oscillations reflective of the local field potential (LFP) theta rhythms. This dual capability suggests a powerful approach for linking spiking activity with LFP dynamics, a key interest in computational neuroscience.

Monkey Reaching Task

Finally, the RNN model is employed in a more complex behavioral paradigm involving a macaque monkey performing a reaching task. By conditioning on reach target positions, the model successfully decodes reach trajectories and uncovers structured latent dynamics. The ability to generalize to unseen reach conditions highlights the robustness and general applicability of the proposed method.

Fixed Point Analysis

A significant theoretical contribution is an efficient algorithm for finding fixed points in low-rank RNNs with piecewise linear activation functions. The authors show that all fixed points can be found at polynomial rather than exponential cost in the number of units, an exponential reduction over naive enumeration of linear regions. This makes in-depth analysis of the inferred dynamics, including stability and attractor structure, tractable for large RNNs.
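To convey the idea in the simplest (rank-1) case: with a ReLU nonlinearity each unit is in one of two linear regimes, and unit $i$ switches regime where $m_i \kappa + b_i = 0$. The latent line is therefore cut into at most $N+1$ intervals, the dynamics are affine on each, and every fixed point is found by solving one scalar equation per interval. The sketch below is illustrative (the per-unit offsets $b_i$ are an assumption, and the paper's algorithm generalizes this region enumeration to rank $R$):

```python
import numpy as np

def rank1_relu_fixed_points(m, n, b):
    """All fixed points of tau*dk/dt = -k + sum_i n_i * relu(m_i * k + b_i).

    Each unit switches regime at k = -b_i / m_i, so the latent line splits
    into at most N+1 intervals on which the right-hand side is affine.
    Returns (kappa, is_stable) pairs; cost is O(N^2), not O(2^N).
    """
    mask = m != 0
    edges = np.concatenate(([-np.inf], np.sort(-b[mask] / m[mask]), [np.inf]))
    fixed_points = []
    for lo, hi in zip(edges[:-1], edges[1:]):
        # Probe a point inside the interval to identify the active units.
        if np.isinf(lo):
            probe = 0.0 if np.isinf(hi) else hi - 1.0
        else:
            probe = lo + 1.0 if np.isinf(hi) else 0.5 * (lo + hi)
        active = m * probe + b > 0
        slope = np.sum(n[active] * m[active])     # affine RHS: slope*k + offset
        offset = np.sum(n[active] * b[active])
        if not np.isclose(slope, 1.0):
            k = offset / (1.0 - slope)            # solve k = slope*k + offset
            if lo <= k <= hi:                     # keep only self-consistent roots
                fixed_points.append((k, slope < 1.0))  # stable iff slope < 1
    return fixed_points

# Example: random rank-1 connectivity vectors and unit offsets.
rng = np.random.default_rng(1)
m, n, b = rng.normal(size=(3, 50))
print(rank1_relu_fixed_points(m, n, b))
```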

Implications and Speculations

By introducing stochastic transitions and leveraging SMC methods, the authors provide a robust framework for capturing the trial-to-trial variability inherent in neural data. This capability is critical for developing accurate, interpretable models of neural dynamics that generalize across experimental conditions. The approach also brings substantial gains in computational efficiency, which is particularly relevant for large-scale neural recordings.

Future research could extend these methods to include more complex noise models or investigate multi-regional interactions within the brain. Additionally, the integration of multimodal datasets could further enhance our understanding of the relationships between spiking activity, local field potentials, and behavioral outputs.

Conclusion

The research presents a comprehensive framework for fitting low-rank recurrent neural networks to neural data characterized by high-dimensional, noisy observations. The combination of low-rank structures, stochastic transitions, and SMC methods enables the extraction of interpretable, low-dimensional latent dynamics, providing a powerful tool for computational neuroscience. The rigorous validation and theoretical contributions make this work a valuable addition to the field, paving the way for further explorations into the dynamics underlying neural activity.