
snnTorch: Spiking Neural Networks in PyTorch

Updated 20 November 2025
  • snnTorch is an open-source Python library that integrates spiking neural network architectures into PyTorch, enabling modular and extensible SNN research.
  • It offers a range of neuron models, surrogate gradient methods, and backpropagation-compatible input encoders, all designed to fit standard deep learning workflows.
  • The library supports quantization-aware training and hardware acceleration, making it suitable for both high-performance and embedded neuromorphic applications.

snnTorch is an open-source Python library designed for constructing and training spiking neural networks (SNNs) within the PyTorch framework. Its emphasis is on modular, extensible, and efficient integration of spiking neuron models, input encoding schemes, surrogate-gradient training, and compatibility with standard deep learning pipelines. snnTorch exposes neuron modules, spike encoders, loss functions, and surrogate gradients as first-class PyTorch objects, enabling seamless research and prototyping in computational neuroscience, neuromorphic engineering, and event-based machine learning domains (Eshraghian et al., 2021).

1. Design Goals and Architecture

snnTorch was developed to address the need for SNN workflows that mirror established practices in deep learning, facilitating gradient-based optimization, input processing, and model evaluation. Its architecture is based on several guiding principles:

  • Modularity: All neuron models, encoders, and surrogates are implemented as independent PyTorch modules, allowing arbitrary compositions and easy swapping.
  • Extensibility: New neuron models, custom surrogates, and local learning rules can be introduced by subclassing base classes or integrating external modules.
  • Efficiency: Supports GPU/CPU backends, optional quantization-aware routines, and hardware-specific accelerators. Key state variables are managed by the library, reducing user-side overhead.
  • Interoperability: Designed to enable usage alongside standard PyTorch features including torch.optim, nn.Sequential, DataLoader, and mixed-precision/inference utilities (Sun et al., 2022, Eshraghian et al., 2021).

The API mirrors common PyTorch design idioms, e.g., .forward() and .init_hidden(), so spiking layers integrate directly into standard training loops; a brief interoperability sketch follows.
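
For illustration, the following is a minimal interoperability sketch (assumed layer sizes and hyperparameters, not an official example); it uses snn.Leaky with the init_hidden and output constructor flags described in the library's tutorials so that spiking layers compose with nn.Sequential and torch.optim like any other PyTorch module.

```python
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import utils

# A spiking layer with init_hidden=True manages its own membrane state,
# so it drops into nn.Sequential like any standard module.
net = nn.Sequential(
    nn.Linear(784, 128),
    snn.Leaky(beta=0.9, init_hidden=True),
    nn.Linear(128, 10),
    snn.Leaky(beta=0.9, init_hidden=True, output=True),  # returns (spikes, membrane)
)

optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)  # standard torch.optim usage

utils.reset(net)                  # clear hidden states before a new batch
x = torch.rand(32, 784)           # one time step of rate-coded input (arbitrary values)
spk_out, mem_out = net(x)         # output-layer spikes and membrane potentials
```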

2. Neuron Models and Dynamics

snnTorch supplies canonical spiking neuron models with rigorous discrete-time formulations:

  • Leaky Integrate-and-Fire (LIF): The default dynamics are

$$V_t = \beta V_{t-1} + I_t, \qquad S_t = \Theta(V_t - V_{th}), \qquad V_t \leftarrow V_t (1 - S_t)$$

where the membrane decay is $\beta = \exp(-\Delta t/\tau_m)$, $V_{th}$ is the firing threshold, and a hard reset drives $V_t \to 0$ on spike emission. Soft-reset variants instead subtract $V_{th}$ on spike (Vora et al., 4 Feb 2025, Eshraghian et al., 2021). A minimal update sketch appears at the end of this section.

  • Current-based LIF (CUBA): The synaptic current is low-pass filtered and then leakily integrated by the membrane:

$$i_t = \alpha i_{t-1} + w x_t, \qquad u_t = \beta u_{t-1} + i_t - z_{t-1}\, u_{thr}$$

with separate synaptic and membrane time constants.

  • Recurrent LIF/SLSTM/Synaptic: Extensions incorporate feedback via recurrent weights or specialized kernels; the spiking LSTM integrates gating with threshold-spike-reset logic (Eshraghian et al., 2021).

All models are solved with forward Euler discretization per time step, with spike generation via binary thresholding. The state of each neuron is tracked explicitly, and reset policies (hard, soft) are configurable.
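
As a minimal sketch (plain PyTorch, not snnTorch's internal implementation), the hard-reset LIF update above can be written as follows; the decay and threshold values are illustrative defaults.

```python
import torch

def lif_step(v, i_in, beta=0.9, v_th=1.0):
    """One forward-Euler LIF step with hard reset, following
    V_t = beta*V_{t-1} + I_t, S_t = Theta(V_t - V_th), V_t <- V_t*(1 - S_t)."""
    v = beta * v + i_in                       # leaky integration
    spk = (v >= v_th).float()                 # Heaviside thresholding
    v = v * (1.0 - spk)                       # hard reset to zero on spike
    return spk, v

# Usage: unroll over T time steps for a batch of input currents.
T, batch, n = 50, 8, 64
v = torch.zeros(batch, n)
inputs = torch.rand(T, batch, n) * 0.3        # arbitrary input currents
for t in range(T):
    spk, v = lif_step(v, inputs[t])
```

snnTorch's built-in LIF modules implement the same update, with the reset policy exposed as a configuration option.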

3. Surrogate Gradient Methods

Direct gradient propagation through the Heaviside spike function is ill-posed due to non-differentiability. snnTorch embeds several surrogate gradient functions for robust backpropagation through time (BPTT):

  • Fast Sigmoid Surrogate: Used in DRiVE and IPU releases, defined as

$$\sigma'(u) = \frac{1}{2(1 + |u|)}$$

and implemented with surrogate.fast_sigmoid() (Vora et al., 4 Feb 2025, Sun et al., 2022).

  • Arctan (default):

$$\tilde S(U) = \frac{1}{\pi}\arctan\!\big(\pi(U-\theta)\big), \qquad \frac{\partial \tilde S}{\partial U} = \frac{1}{\pi\big(1 + (\pi(U-\theta))^2\big)}$$

Used as a standard smooth surrogate (Eshraghian et al., 2021).

  • Straight-through estimator, triangular, and custom surrogates are also provided. Any surrogate can be plugged in via the spike_grad argument to neuron modules.

Surrogates allow standard PyTorch .backward() calls and gradient-based optimization; a sketch of a user-defined surrogate is shown below.
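
As an illustration (a minimal sketch, not the library's internal implementation), a custom surrogate can be expressed as a torch.autograd.Function whose forward pass is the Heaviside step and whose backward pass substitutes the fast-sigmoid expression quoted above; it is then handed to a neuron module via the spike_grad argument described in the text.

```python
import torch
import snntorch as snn

class FastSigmoidSurrogate(torch.autograd.Function):
    """Heaviside spike in the forward pass; fast-sigmoid-style gradient
    in the backward pass, following sigma'(u) = 1 / (2(1 + |u|)) above."""

    @staticmethod
    def forward(ctx, mem):
        ctx.save_for_backward(mem)
        return (mem > 0).float()                         # binary spike

    @staticmethod
    def backward(ctx, grad_output):
        (mem,) = ctx.saved_tensors
        return grad_output / (2.0 * (1.0 + mem.abs()))   # surrogate gradient

def fast_sigmoid_surrogate():
    return FastSigmoidSurrogate.apply

# Plugged into a neuron module via the spike_grad argument:
lif = snn.Leaky(beta=0.9, spike_grad=fast_sigmoid_surrogate())
```

The library-maintained version of this surrogate is available as surrogate.fast_sigmoid(), as noted above.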

4. Input Encoding and Output Decoding

snnTorch supplies BP-compatible encoders mapping continuous or categorical input into spike trains:

  • Rate (Poisson) Coding: Each feature $x \in [0,1]$ is treated as the firing probability per time step, yielding a Bernoulli process over $T$ steps (sketched at the end of this section):

$$P(S[t] = 1) = x$$

Used, e.g., in DRiVE for grayscale image encoding (Vora et al., 4 Feb 2025, Eshraghian et al., 2021).

  • Latency (time-to-first-spike): Translates feature values into precise spike times.
  • Delta (event) Encoding: Emits spikes when input changes by a set threshold.

Spike outputs can be decoded as spike counts (rate), time-to-first-spike, or, in temporal tasks, by custom losses over spike rasters.
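
A minimal sketch of the rate-coding rule above using torch.bernoulli; this is only an illustration of the underlying sampling, while the library's own encoders (rate, latency, and delta generators) live in snntorch.spikegen.

```python
import torch

def rate_encode(x, num_steps=50):
    """Bernoulli (Poisson-like) rate coding: each feature in [0, 1] is the
    per-step firing probability, sampled independently for num_steps steps.
    Returns a tensor of shape [num_steps, *x.shape]."""
    x = x.clamp(0.0, 1.0)
    probs = x.unsqueeze(0).expand(num_steps, *x.shape)
    return torch.bernoulli(probs)

# Example: encode a batch of flattened grayscale images over T = 50 steps.
images = torch.rand(8, 128 * 128)                 # stand-in for normalized pixel intensities
spike_train = rate_encode(images, num_steps=50)   # shape [50, 8, 16384]
```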

5. Training, Quantization, and Hardware Integration

Training Workflow

Standard SNN training in snnTorch involves Poisson-rate input encoding, explicit unrolling in time, and application of cross-entropy or regression losses computed over spike rates or temporal patterns. Built-in loss functions, e.g., ce_rate_loss() and ce_temporal_loss(), are available (Eshraghian et al., 2021).

A representative training loop for a multi-layer SNN includes membrane state initialization, forward unrolling over time, loss calculation over accumulated spikes, backpropagation with surrogate gradients, and an optimizer step; a sketch of such a loop is given below. An example is supplied for the DRiVE model (Vora et al., 4 Feb 2025).
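
A minimal sketch of that loop, assuming Bernoulli rate encoding as in Section 4, snn.Leaky neurons, and the ce_rate_loss helper from snntorch.functional named in the previous paragraph; layer sizes and hyperparameters are illustrative, not the DRiVE configuration.

```python
import torch
import torch.nn as nn
import snntorch as snn
import snntorch.functional as SF

T = 50                                     # number of simulation time steps
fc1, lif1 = nn.Linear(784, 128), snn.Leaky(beta=0.9)
fc2, lif2 = nn.Linear(128, 10), snn.Leaky(beta=0.9)
params = list(fc1.parameters()) + list(fc2.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)
loss_fn = SF.ce_rate_loss()                # cross-entropy over output spike rates

def train_step(data, targets):
    mem1 = lif1.init_leaky()               # membrane state initialization
    mem2 = lif2.init_leaky()
    spk_rec = []
    spikes = torch.bernoulli(data.clamp(0, 1).unsqueeze(0).expand(T, *data.shape))
    for t in range(T):                     # explicit unrolling in time
        spk1, mem1 = lif1(fc1(spikes[t]), mem1)
        spk2, mem2 = lif2(fc2(spk1), mem2)
        spk_rec.append(spk2)
    spk_rec = torch.stack(spk_rec)         # [T, batch, classes]
    loss = loss_fn(spk_rec, targets)       # rate-based cross-entropy loss
    optimizer.zero_grad()
    loss.backward()                        # BPTT through surrogate gradients
    optimizer.step()
    return loss.item()

# Usage with one random batch (stand-in for a DataLoader batch):
loss = train_step(torch.rand(32, 784), torch.randint(0, 10, (32,)))
```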

Quantization-Aware Training

snnTorch supports quantization for both weights (via integration with Brevitas) and membrane potentials (the SQUAT methodology) (Venkatesh et al., 15 Apr 2024). Uniform and threshold-centered (exponential) state quantization strategies are implemented, with parameters for bit-width, quantization range, and exponential warping. The state_quant callback enforces discretization of state variables within neuron modules; a sketch follows the table below. Ablation studies demonstrate that 8-bit quantization results in negligible accuracy loss, whereas 2–4 bit regimes require combined QAT and state quantization for acceptable performance.

| Regime | 8-bit Acc. (FashionMNIST, %) | 4-bit Acc. (%) | 2-bit Acc. (%) |
|---|---|---|---|
| Full precision (reference) | 90.5 | – | – |
| QAT + SQUAT | 90.5 | 88.3 | 47.6 |
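
A minimal sketch of attaching membrane-state quantization to a neuron; it assumes the state_quant factory under snntorch.functional.quant with num_bits and uniform arguments, as used in the library's quantization examples, and the 4-bit setting here is illustrative.

```python
import snntorch as snn
from snntorch.functional import quant

# Build a 4-bit quantizer for membrane potentials; uniform=True selects the
# uniform discretization, as opposed to the threshold-centered variant above.
q_mem = quant.state_quant(num_bits=4, uniform=True)

# Attach it to a neuron module so the membrane state is discretized each step.
lif = snn.Leaky(beta=0.9, state_quant=q_mem)
```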

Hardware Accelerators

snnTorch provides a dedicated backend for Graphcore IPUs, exploiting MIMD parallelism, custom C++ codelets for at-spike logic, and half-precision support for memory and compute efficiency (Sun et al., 2022). Reported benchmarks indicate up to a 10-fold speedup and 3–5× higher throughput per watt compared with an NVIDIA A100 on SNN workloads. Population coding is recommended for maximizing accelerator utilization at single time steps. The IPU module is import-compatible and does not alter the public API.

6. Example Applications and Performance

Recent works illustrate snnTorch's utility as a research platform:

  • DRiVE Vehicle Detection: Implements a 3-layer SNN (LIF with $\beta = 0.95$) using snnTorch to classify static images via Poisson spike encoding (an illustrative architecture sketch follows this list). Key configuration:
    • Input: 128×128 grayscale images, Poisson encoding, T=50 steps.
    • Network: [16384] → [64] → [64] → [2] LIF layers with BatchNorm.
    • Loss: Cross-entropy rate loss, AdamW, early stopping.
    • Test accuracy: 94.82%, AUC=0.99, outperforming or matching S-ResNet38, Spikformer V2, and CSNN-blurr9 on the same dataset (Vora et al., 4 Feb 2025).
  • Neuromorphic and Embedded Deployments: IPU acceleration and state/weight quantization pathways demonstrate snnTorch’s focus on efficient, scalable SNN training for both resource-rich and edge platforms (Sun et al., 2022, Venkatesh et al., 15 Apr 2024).
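
For illustration, the following sketch matches the configuration listed above (16384 → 64 → 64 → 2 LIF layers with BatchNorm, $\beta = 0.95$, T = 50); it is an assumed reconstruction for exposition, not the authors' released DRiVE code.

```python
import torch
import torch.nn as nn
import snntorch as snn

class VehicleSNN(nn.Module):
    """Illustrative 3-layer LIF classifier: 16384 -> 64 -> 64 -> 2,
    BatchNorm on the hidden layers, membrane decay beta = 0.95."""

    def __init__(self, beta=0.95, num_steps=50):
        super().__init__()
        self.num_steps = num_steps
        self.fc1, self.bn1, self.lif1 = nn.Linear(128 * 128, 64), nn.BatchNorm1d(64), snn.Leaky(beta=beta)
        self.fc2, self.bn2, self.lif2 = nn.Linear(64, 64), nn.BatchNorm1d(64), snn.Leaky(beta=beta)
        self.fc3, self.lif3 = nn.Linear(64, 2), snn.Leaky(beta=beta)

    def forward(self, spikes):                # spikes: [T, batch, 16384]
        mem1, mem2, mem3 = self.lif1.init_leaky(), self.lif2.init_leaky(), self.lif3.init_leaky()
        spk_out = []
        for t in range(self.num_steps):
            s1, mem1 = self.lif1(self.bn1(self.fc1(spikes[t])), mem1)
            s2, mem2 = self.lif2(self.bn2(self.fc2(s1)), mem2)
            s3, mem3 = self.lif3(self.fc3(s2), mem3)
            spk_out.append(s3)
        return torch.stack(spk_out)           # [T, batch, 2] output spike raster

# Usage: Poisson-encoded 128x128 grayscale images over T = 50 steps (see Section 4).
net = VehicleSNN()
x = torch.bernoulli(torch.rand(50, 4, 128 * 128) * 0.5)
out_spikes = net(x)
```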

7. Extending and Customizing snnTorch

snnTorch is designed for straightforward extension:

  • Custom Neuron Models: Subclass base neuron modules with arbitrary dynamics or reset rules.
  • Custom Surrogate Gradients: Define a new PyTorch autograd.Function for $\tilde S(u)$.
  • Hybrid Learning Rules: Combine built-in backpropagation with local plasticity rules (e.g., e-prop, DECOLLE) via the functional API and custom loss definitions.
  • Integration with third-party quantization libraries: Supported directly for weight QAT; see the Brevitas example (Venkatesh et al., 15 Apr 2024) and the sketch after this list.

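As an example of the weight-QAT pathway, a minimal sketch that swaps a standard linear layer for a Brevitas quantized layer feeding an snnTorch neuron; it assumes brevitas.nn.QuantLinear with a weight_bit_width argument, and the 4-bit setting here is illustrative.

```python
import torch
import snntorch as snn
from brevitas.nn import QuantLinear

# 4-bit weight quantization-aware linear layer in place of nn.Linear;
# the spiking dynamics downstream are unchanged.
fc = QuantLinear(784, 10, bias=True, weight_bit_width=4)
lif = snn.Leaky(beta=0.9)

mem = lif.init_leaky()
x = torch.rand(32, 784)            # one time step of rate-coded input
spk, mem = lif(fc(x), mem)         # quantized weights, surrogate-gradient training as usual
```
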
Developer documentation and tutorial resources are hosted at https://snntorch.readthedocs.io/en/latest/tutorials/index.html (Eshraghian et al., 2021).


snnTorch advances the research ecosystem for SNNs by providing modular, scalable, and well-optimized components for deep, recurrent, and quantized spiking architectures. Its empirical validation on visual and temporal tasks, alongside hardware-specific optimizations and quantization-aware routines, establishes it as an adaptable platform bridging neuroscience-inspired models and modern machine learning workflows (Vora et al., 4 Feb 2025, Sun et al., 2022, Venkatesh et al., 15 Apr 2024, Eshraghian et al., 2021).
