
mlx-snn: SNN Library for Apple Silicon

Updated 7 March 2026
  • mlx-snn is a native spiking neural network library designed for Apple Silicon that leverages the MLX computation stack and a research-focused API.
  • It implements six neuron models, supports comprehensive surrogate gradients, and offers multiple spike encoding schemes for efficient SNN training.
  • Empirical results on MNIST show up to 97.28% accuracy with 2–2.5× faster training and 3–10× reduced GPU memory usage compared to snnTorch.

mlx-snn is a native spiking neural network (SNN) library designed for Apple Silicon and built on the Apple MLX array-computation stack. It addresses the gap in the SNN ecosystem where all major frameworks (such as snnTorch, Norse, SpikingJelly, and Lava) target PyTorch or custom backends, which are not optimized for the unified memory and compute characteristics of Apple hardware. mlx-snn offers a research-focused API with six neuron models, comprehensive surrogate gradient support, multiple spike encoding schemes (including EEG-specific), and a backpropagation-through-time (BPTT) training pipeline. Leveraging MLX features such as unified CPU/GPU memory, deferred/lazy evaluation, and composable functional transforms, mlx-snn provides efficient, low-overhead SNN experimentation and training on MacBook-class devices. Empirical evaluation on MNIST classification demonstrates up to 97.28% accuracy, with 2.0–2.5× speedup and 3–10× lower GPU memory footprint compared to snnTorch on the same Apple M3 Max hardware (Qin, 3 Mar 2026).

1. System Architecture and MLX Integration

mlx-snn is architected around the Apple MLX computation stack, which provides:

  • Unified Memory: The CPU and GPU share the same physical RAM, so device-placement calls such as .to(device) are unnecessary; arrays stay in one memory pool and operations can run on either compute unit without copies.
  • Lazy Evaluation: MLX builds deferred computation graphs that are only executed upon explicit calls (e.g., mx.eval(...)). This allows global optimization over all temporal steps within an SNN sequence, amortizing kernel launches and reducing peak memory usage.
  • Composable Transforms: Essential functionals such as mx.grad (automatic differentiation), mx.compile (JIT compilation), mx.vmap, mx.jvp, and mx.linearize enable transformation and vectorization of pure functions, enhancing both development flexibility and runtime performance in SNN contexts.

mlx-snn exploits these features as follows:

  • All neuron and encoder operations work on mlx.core arrays.
  • Neuron states are implemented as immutable MLX arrays, stored in Python dictionaries (e.g., {'mem':..., 'syn':..., ...}), facilitating compatibility with pure-function transforms and JIT compilation.
  • The training loop constructs a global computation graph covering all T time steps and executes in a single shot. This minimizes repeated kernel launches and peak memory requirements.
  • Surrogate gradient computation follows a straight-through estimator (STE) pattern, employing mx.stop_gradient to circumvent MLX custom function VJP shape constraints while still supporting end-to-end differentiation.
  • The library components are modular, grouped into mlxsnn.neurons (six neuron types), mlxsnn.surrogates (four surrogate gradients), mlxsnn.encoding (four encoding schemes), and mlxsnn.training (BPTT, built-in loss functions, and utility APIs).
  • An API compatibility layer provides close alignment to snnTorch naming, constructor signatures, state management, and forward call patterns.
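The state-dictionary convention described above can be illustrated without MLX at all. The following is a minimal plain-Python sketch, assuming scalar states and the Synaptic update equations from Section 2; the function name `synaptic_step` and its default parameters are hypothetical and are not part of the mlx-snn API.

```python
def synaptic_step(x, state, alpha=0.8, beta=0.9, threshold=1.0):
    """One pure-functional update of a synaptic neuron (illustrative only).

    `state` is a dict with keys "syn" and "mem". It is never mutated; a new
    dict is returned, which is the property that keeps a step function
    compatible with tracing, differentiation, and compilation.
    """
    # S[t] = Theta(U[t] - V_thr); the strict inequality is an assumption.
    spike = 1.0 if state["mem"] > threshold else 0.0
    # I_syn[t+1] = alpha * I_syn[t] + X[t+1]
    syn = alpha * state["syn"] + x
    # U[t+1] = beta * U[t] + I_syn[t+1] - S[t] * V_thr
    mem = beta * state["mem"] + syn - spike * threshold
    return spike, {"syn": syn, "mem": mem}


state0 = {"syn": 0.0, "mem": 0.0}
spike, state1 = synaptic_step(1.0, state0, alpha=0.5)
```

Because the caller always receives a fresh dictionary, the same `state0` can be reused to replay a step with different inputs, which is convenient when comparing hyperparameters.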

2. Neuron Models and Discrete-Time Update Equations

mlx-snn implements six neuron models with direct, explicit updates in discrete time ($t = 0, 1, \ldots, T-1$):

  • Leaky Integrate-and-Fire (LIF):

$$U[t+1] = \beta \, U[t] + X[t+1] - S[t] \, V_{\mathrm{thr}}, \quad S[t] = \Theta(U[t] - V_{\mathrm{thr}})$$

with decay $\beta$ and threshold $V_{\mathrm{thr}}$.

  • Integrate-and-Fire (IF):

$$U[t+1] = U[t] + X[t+1] - S[t] \, V_{\mathrm{thr}}$$

(special case of LIF with $\beta = 1$).

  • Izhikevich (2D dynamical system):

$$\frac{dv}{dt} = 0.04 v^2 + 5v + 140 - u + I, \quad \frac{du}{dt} = a(bv - u)$$

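Discretized with the forward Euler method; four presets are provided: regular spiking (RS), intrinsically bursting (IB), chattering (CH), and fast spiking (FS).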

  • Adaptive LIF (ALIF):

$$U[t+1] = \beta U[t] + X[t+1] - S[t] V_{\mathrm{thr}}, \quad A[t+1] = \rho A[t] + S[t], \quad V_{\mathrm{eff}}[t] = V_{\mathrm{thr}} + b A[t]$$

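The effective threshold $V_{\mathrm{eff}}$ tracks recent spiking activity through the adaptation variable $A$, which accumulates emitted spikes and decays with factor $\rho$.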

  • Synaptic (filtered input current):

$$I_{\mathrm{syn}}[t+1] = \alpha I_{\mathrm{syn}}[t] + X[t+1], \quad U[t+1] = \beta U[t] + I_{\mathrm{syn}}[t+1] - S[t] V_{\mathrm{thr}}$$

  • Alpha (dual-exponential synapse):

$$I_{\mathrm{exc}}[t+1] = \alpha I_{\mathrm{exc}}[t] + X[t+1], \quad I_{\mathrm{inh}}[t+1] = \alpha I_{\mathrm{inh}}[t] + I_{\mathrm{exc}}[t+1], \quad U[t+1] = \beta U[t] + I_{\mathrm{inh}}[t+1] - S[t] V_{\mathrm{thr}}$$

Models rapid rise and slower decay dynamics.

All updates are implemented using MLX arrays and permit batched, vectorized computation for scalable simulation.
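To make the LIF and IF recurrences concrete, here is a small plain-Python sketch of a single neuron driven by a constant input. It is an illustration of the equations above rather than the library's implementation; the function name `simulate_lif`, the zero initial potential, and the strict threshold comparison are assumptions.

```python
def simulate_lif(inputs, beta=0.9, threshold=1.0):
    """Return the spike train S[0..T-1] for a list of input currents."""
    mem = 0.0  # U[0], assumed to start at rest
    spikes = []
    for x in inputs:
        # S[t] = Theta(U[t] - V_thr)
        spike = 1 if mem > threshold else 0
        spikes.append(spike)
        # U[t+1] = beta * U[t] + X[t+1] - S[t] * V_thr
        mem = beta * mem + x - spike * threshold
    return spikes


constant_drive = [0.3] * 25
lif_spikes = simulate_lif(constant_drive, beta=0.9)  # leaky neuron
if_spikes = simulate_lif(constant_drive, beta=1.0)   # IF as the beta = 1 case
```

With this drive the leaky neuron fires less often than the non-leaky one, because part of the accumulated potential decays away at every step.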

3. Surrogate Gradient Techniques and Training Dynamics

Spiking neurons apply a hard threshold nonlinearity whose gradient is zero almost everywhere, so training via backpropagation-through-time (BPTT) requires surrogate gradients. mlx-snn uses the straight-through estimator pattern:

$$\mathrm{output} = \mathrm{stop\_grad}(\Theta(x) - \tilde{\sigma}(x)) + \tilde{\sigma}(x)$$

The forward output is exactly $\Theta(x)$, because the two $\tilde{\sigma}(x)$ terms cancel. In backpropagation the stop_grad term contributes nothing, so the gradient is that of the smooth surrogate, $\tilde{\sigma}'(x)$.

Supported surrogates include:

  • Fast sigmoid:

$$\tilde{\sigma}(x) = \frac{kx}{2(1 + k|x|)} + \frac{1}{2}$$

  • Arctan:

$$\tilde{\sigma}(x) = \frac{1}{\pi}\arctan(\alpha x) + \frac{1}{2}$$

  • Straight-through (piecewise linear):

$$\tilde{\sigma}(x) = \mathrm{clip}(s x + 0.5, 0, 1)$$

  • Custom: Arbitrary user-supplied smooth function, wrapped for STE relaxation.
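The surrogate functions listed above have simple closed-form derivatives, which are what the backward pass effectively uses. The sketch below writes out the fast-sigmoid and arctan cases in plain Python and pairs a Heaviside forward value with a surrogate gradient. All function names are hypothetical, the arctan slope `alpha=2.0` is an arbitrary illustrative value, and taking the Heaviside step as 0 at exactly `x = 0` is an assumption.

```python
import math


def heaviside(x):
    """Hard threshold used in the forward pass."""
    return 1.0 if x > 0 else 0.0


def fast_sigmoid(x, k=25.0):
    return k * x / (2.0 * (1.0 + k * abs(x))) + 0.5


def fast_sigmoid_grad(x, k=25.0):
    # d/dx [k x / (2 (1 + k|x|))] = k / (2 (1 + k|x|)^2)
    return k / (2.0 * (1.0 + k * abs(x)) ** 2)


def arctan_surrogate(x, alpha=2.0):
    return math.atan(alpha * x) / math.pi + 0.5


def arctan_surrogate_grad(x, alpha=2.0):
    # d/dx [arctan(alpha x) / pi] = alpha / (pi (1 + (alpha x)^2))
    return alpha / (math.pi * (1.0 + (alpha * x) ** 2))


def spike_with_surrogate(x, grad_fn=fast_sigmoid_grad):
    """Forward value Theta(x) paired with the gradient used in the backward pass."""
    return heaviside(x), grad_fn(x)


def central_difference(f, x, h=1e-6):
    """Numerical derivative, used only to sanity-check the closed forms."""
    return (f(x + h) - f(x - h)) / (2.0 * h)
```

Comparing `central_difference(fast_sigmoid, x)` with `fast_sigmoid_grad(x)` at a few points is a quick way to validate a custom surrogate before using it in training.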

The BPTT loop builds a full temporal graph, allowing gradients to be propagated through all intermediate neuron states. The workflow is fully compatible with MLX's autograd and lazy evaluation, making it possible to combine SNNs with arbitrary differentiable models.

4. Spike Encoding Schemes

mlx-snn provides four spike encoding methods for mapping continuous or time-series data into spike trains:

  • Rate Coding (Poisson): Each scalar feature $x \in [0,1]$ is interpreted as the probability of emitting a spike at every time step, $S[t] \sim \mathrm{Bernoulli}(x)$.
  • Latency Coding (Time-to-First-Spike): $x$ is mapped to a spike time $t_{\mathrm{spike}}$, either linearly or exponentially; a spike is emitted only once per sequence.
  • Delta Modulation: Emits a spike when $|x[t] - x[t-1]| > \theta$, suitable for temporal and streaming data.
  • EEG Encoder: Specialized encoding for multi-lead EEG data; supports rate, delta, and threshold-crossing on a per-channel basis, producing tensors of shape $[T, B, C]$.

Encoders are designed for both prototypical vision experiments (e.g., MNIST) and bio-signal pipelines, and integrate seamlessly with subsequent network layers.
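The first three encoding rules can be sketched for a single scalar feature in plain Python. This is an illustration of the definitions above, not the mlxsnn.encoding API: the function names are hypothetical, and the linear latency mapping in which larger values fire earlier is an assumed convention.

```python
import random


def rate_encode(x, num_steps, rng):
    """Bernoulli rate coding: spike with probability x at every step."""
    return [1 if rng.random() < x else 0 for _ in range(num_steps)]


def latency_encode(x, num_steps):
    """Linear time-to-first-spike: exactly one spike per sequence."""
    t_spike = round((1.0 - x) * (num_steps - 1))
    return [1 if t == t_spike else 0 for t in range(num_steps)]


def delta_encode(signal, theta):
    """Spike when |x[t] - x[t-1]| > theta; the first sample never spikes."""
    if not signal:
        return []
    spikes = [0]
    for prev, cur in zip(signal, signal[1:]):
        spikes.append(1 if abs(cur - prev) > theta else 0)
    return spikes


rng = random.Random(0)  # seeded so the rate-coded train is reproducible
rate_train = rate_encode(0.3, 25, rng)
latency_train = latency_encode(0.75, 5)
delta_train = delta_encode([0.0, 0.05, 0.5, 0.5, 0.1], theta=0.2)
```

Rate coding spreads information over many stochastic spikes, whereas latency coding carries it in the timing of a single spike, so the two trade robustness against sparsity.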

5. Empirical Benchmarking and Performance Analysis

mlx-snn was benchmarked on MNIST digit classification using an Apple M3 Max MacBook Pro (36 GB unified memory), running MLX 0.29.3 and Python 3.9.18. Baselines included snnTorch 0.9.4 with PyTorch 2.8.0 on both Apple MPS GPU and CPU backends.

The network topology was a 2-layer LIF SNN (784 inputs → $h$ hidden → 10 outputs, $T = 25$ steps), with five hyperparameter configurations:

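| Configuration | $\beta$ | $h$ | mlx-snn/MLX Acc | mlx-snn/MLX Time (s) | mlx-snn/MLX GPU Mem (MB) | snnTorch/MPS Acc | snnTorch/MPS Time (s) | snnTorch/MPS GPU Mem (MB) | snnTorch/CPU Acc | snnTorch/CPU Time (s) |
|---|---|---|---|---|---|---|---|---|---|---|
| C1 | 0.85 | 256 | 97.28% | 4.0 | 61 | 98.00% | 8.8 | 241 | 98.01% | 12.8 |
| C2 | 0.9 | 256 | 97.02% | 4.3 | 65 | 98.03% | 8.9 | >241 | 97.97% | 13.5 |
| C3 | | 256* | 96.91% | 2.4 | | 98.03% | 4.8 | | 98.17% | 16.7 |
| C4 | | 128 | 96.90% | 4.3 | | 97.84% | 9.0 | | 97.74% | 10.9 |
| C5 | 0.95 | 128 | 94.98% | 4.4 | | 97.09% | 9.0 | | 97.00% | 11.1 |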

(*Batch size and learning rates varied per configuration.)

Peak GPU memory usage for mlx-snn ranged from 61 to 138 MB, while snnTorch on MPS required 241–1453 MB. Per-epoch training was 2.0–2.5× faster with mlx-snn than with snnTorch on the MPS GPU, and memory usage was 3–10× lower.
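The quoted ratios can be checked directly against the table entries. The short calculation below uses only the times and the one pair of exact memory figures listed above; the variable names are illustrative.

```python
# (mlx-snn time, snnTorch/MPS time) in seconds, copied from the table above.
times = {
    "C1": (4.0, 8.8),
    "C2": (4.3, 8.9),
    "C3": (2.4, 4.8),
    "C4": (4.3, 9.0),
    "C5": (4.4, 9.0),
}

# Speedup of mlx-snn over snnTorch on MPS, rounded to two decimals.
speedups = {name: round(mps / mlx, 2) for name, (mlx, mps) in times.items()}

# C1 is the only row with an exact memory figure for both frameworks.
c1_memory_ratio = round(241 / 61, 2)
```

The five time ratios come out between 2.0 and 2.2, and the C1 memory ratio is about 3.95, both consistent with the ranges stated above.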

Surrogate gradient performance (C5, 10 epochs):

| Surrogate | Accuracy (%) | Training Time (s) |
|---|---|---|
| Fast sigmoid | 93.65 | 44.7 |
| Arctan | 92.44 | 43.8 |
| Straight-through | 46.28 | 46.3 |

6. Usage Paradigms and Implementation

mlx-snn is distributed via PyPI and installed using:

pip install mlx-snn

A typical SNN definition, using MLX and mlx-snn modules, follows snnTorch conventions:

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import mlxsnn

class SpikingMLP(nn.Module):
    """Two-layer spiking MLP (784 -> 128 -> 10) unrolled over num_steps."""

    def __init__(self, num_steps=25, beta=0.9):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.lif1 = mlxsnn.Leaky(beta=beta)
        self.fc2 = nn.Linear(128, 10)
        # The output layer never resets, so its membrane potential serves as the readout.
        self.lif2 = mlxsnn.Leaky(beta=beta, reset_mechanism="none")
        self.num_steps = num_steps

    def __call__(self, spikes_in):
        # spikes_in is indexed as [time step, batch, feature].
        batch_size = spikes_in.shape[1]
        s1 = self.lif1.init_state(batch_size, 128)
        s2 = self.lif2.init_state(batch_size, 10)
        for t in range(self.num_steps):
            x = self.fc1(spikes_in[t])
            spk, s1 = self.lif1(x, s1)   # hidden spikes and updated hidden state
            x = self.fc2(spk)
            _, s2 = self.lif2(x, s2)     # output spikes are ignored
        return s2["mem"]   # final membrane potentials

To optimize training performance:

  • Wrap the per-step function in mx.compile; the full temporal loop is not yet compiled in version 0.2.1 (see Section 7).
  • Build the full T-step computation graph before any mx.eval.
  • Use large batch sizes to amortize compute overhead.
  • Avoid in-place operations to ensure graph purity and compatibility with JIT transforms.

7. Best Practices, Limitations, and Development Roadmap

Recommended practices for high accuracy and efficiency:

  • $\beta \approx 0.85$–$0.9$ for hidden layer decay.
  • Hidden size $h = 256$, which gave the best reported mlx-snn accuracy (97.28%).
  • Adam optimizer, learning rate $1 \times 10^{-3}$.
  • Fast-sigmoid surrogate ($k = 25$), which reached the highest accuracy of the three benchmarked surrogates at a comparable training time.

Deployment on Apple Silicon is streamlined: explicit device transfers are unnecessary; MLX automatically manages memory placement. The use of mx.profile aids in analyzing memory and compute overheads.

Limitations in version 0.2.1:

  1. JIT compilation of temporal SNN loops is not yet supported (mx.compile cannot process Python for loops).
  2. No integrated neuromorphic dataset support (e.g., N-MNIST, DVS-Gesture, SHD).
  3. Published benchmarks are restricted to MNIST.
  4. The STE surrogate introduces a roughly 1% accuracy penalty versus native custom VJPs.

Future development aims include:

  • v0.3.0: Liquid State Machine (LSM) support, reservoir computing modules, EEG classification.
  • v0.4.0: Full mx.compile optimization of forward passes, neuromorphic data loaders, visualization suite, benchmarks on complex datasets (CIFAR, DVS).
  • v1.0.0: Stable API, extensive validation, release of comprehensive documentation, and a stable PyPI package.

In summary, mlx-snn enables SNN research, development, and production on Apple hardware, delivering competitive accuracy and resource efficiency by exploiting MLX features while maintaining a familiar and portable API for the broader spiking neural network community (Qin, 3 Mar 2026).
