mlx-snn: SNN Library for Apple Silicon
- mlx-snn is a native spiking neural network library designed for Apple Silicon that leverages the MLX computation stack and a research-focused API.
- It implements six neuron models, supports comprehensive surrogate gradients, and offers multiple spike encoding schemes for efficient SNN training.
- Empirical results on MNIST show up to 97.28% accuracy with 2–2.5× faster training and 3–10× reduced GPU memory usage compared to snnTorch.
mlx-snn is a native spiking neural network (SNN) library designed for Apple Silicon and built on the Apple MLX array-computation stack. It addresses the gap in the SNN ecosystem where all major frameworks (such as snnTorch, Norse, SpikingJelly, and Lava) target PyTorch or custom backends, which are not optimized for the unified memory and compute characteristics of Apple hardware. mlx-snn offers a research-focused API with six neuron models, comprehensive surrogate gradient support, multiple spike encoding schemes (including EEG-specific), and a backpropagation-through-time (BPTT) training pipeline. Leveraging MLX features such as unified CPU/GPU memory, deferred/lazy evaluation, and composable functional transforms, mlx-snn provides efficient, low-overhead SNN experimentation and training on MacBook-class devices. Empirical evaluation on MNIST classification demonstrates up to 97.28% accuracy, with 2.0–2.5× speedup and 3–10× lower GPU memory footprint compared to snnTorch on the same Apple M3 Max hardware (Qin, 3 Mar 2026).
1. System Architecture and MLX Integration
mlx-snn is architected around the Apple MLX computation stack, which provides:
- Unified Memory: Both CPU and GPU share physical RAM, obviating the need for device placement operations such as `.to(device)`; arrays are accessible to both compute units without explicit copies or code intervention.
- Lazy Evaluation: MLX builds deferred computation graphs that are only executed upon explicit calls (e.g., `mx.eval(...)`). This allows global optimization over all temporal steps within an SNN sequence, amortizing kernel launches and reducing peak memory usage.
- Composable Transforms: Essential functionals such as `mx.grad` (automatic differentiation), `mx.compile` (JIT compilation), `mx.vmap`, `mx.jvp`, and `mx.linearize` enable transformation and vectorization of pure functions, enhancing both development flexibility and runtime performance in SNN contexts.
mlx-snn exploits these features as follows:
- All neuron and encoder operations utilize `mlx.core` tensors.
- Neuron states are implemented as immutable MLX arrays, stored in Python dictionaries (e.g., `{'mem':..., 'syn':..., ...}`), facilitating compatibility with pure-function transforms and JIT compilation.
- The training loop constructs a global computation graph covering all T time steps and executes in a single shot. This minimizes repeated kernel launches and peak memory requirements.
- Surrogate gradient computation follows a straight-through estimator (STE) pattern, employing `mx.stop_gradient` to circumvent MLX custom function VJP shape constraints while still supporting end-to-end differentiation.
- The library components are modular, grouped into `mlxsnn.neurons` (six neuron types), `mlxsnn.surrogates` (four surrogate gradients), `mlxsnn.encoding` (four encoding schemes), and `mlxsnn.training` (BPTT, built-in loss functions, and utility APIs).
- An API compatibility layer provides close alignment to snnTorch naming, constructor signatures, state management, and `forward` call patterns.
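The state-handling convention described above can be illustrated with a minimal sketch. NumPy stands in for `mlx.core` here so the example runs on any machine; `synaptic_step`, `run`, and the parameters `alpha`, `beta`, and `threshold` are hypothetical names, and the update rule is a textbook second-order synaptic neuron with reset by subtraction, not necessarily the library's exact implementation.

```python
import numpy as np

def synaptic_step(x, state, alpha=0.8, beta=0.9, threshold=1.0):
    """Pure step: reads `state`, returns (spikes, new_state), never mutates."""
    syn = alpha * state["syn"] + x            # filtered input current
    mem = beta * state["mem"] + syn           # leaky membrane integration
    spk = (mem >= threshold).astype(x.dtype)  # hard threshold
    mem = mem - spk * threshold               # reset by subtraction (assumed)
    return spk, {"syn": syn, "mem": mem}

def run(inputs, state):
    """Unroll all T steps; each iteration rebinds `state` to a fresh dict."""
    spikes = []
    for x_t in inputs:                        # inputs: (T, batch, features)
        spk, state = synaptic_step(x_t, state)
        spikes.append(spk)
    return np.stack(spikes), state

T, batch, features = 4, 2, 3
init = {"syn": np.zeros((batch, features)), "mem": np.zeros((batch, features))}
inputs = np.full((T, batch, features), 0.5)
spikes, final = run(inputs, init)
print(spikes.shape, final["mem"][0, 0])
```

Because every step returns a new dictionary instead of writing into the old one, the initial state is left untouched, which is the property that function transforms and graph tracing rely on.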
2. Neuron Models and Discrete-Time Update Equations
mlx-snn implements six neuron models with direct, explicit updates in discrete time:
- Leaky Integrate-and-Fire (LIF): membrane decay $\beta$ and firing threshold $V_{\mathrm{thr}}$.
- Integrate-and-Fire (IF): the special case of LIF without leak ($\beta = 1$).
- Izhikevich (2D dynamical system): discretized via Euler; four presets (regular spiking RS, intrinsically bursting IB, chattering CH, fast spiking FS).
- Adaptive LIF (ALIF): the threshold adapts with spike count.
- Synaptic: filtered input current.
- Alpha (dual-exponential synapse): models rapid rise and slower decay dynamics.
All updates are implemented using MLX arrays and permit batched, vectorized computation for scalable simulation.
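As a concrete illustration of a batched discrete-time update, the following sketch implements one LIF step in the commonly used form: decay the membrane by `beta`, add the input, threshold, then reset by subtraction. The function name `lif_step`, the reset choice, and the use of NumPy in place of MLX arrays are assumptions made for illustration; setting `beta=1.0` gives the IF special case.

```python
import numpy as np

def lif_step(x, mem, beta=0.9, threshold=1.0):
    """One LIF update (assumed textbook form with reset by subtraction)."""
    mem = beta * mem + x                        # leak, then integrate input
    spk = (mem >= threshold).astype(mem.dtype)  # Heaviside threshold
    mem = mem - spk * threshold                 # soft reset on units that fired
    return spk, mem

def count_spikes(inputs, num_steps, beta):
    """Drive each unit with a constant input and count emitted spikes."""
    mem = np.zeros_like(inputs)
    counts = np.zeros_like(inputs)
    for _ in range(num_steps):
        spk, mem = lif_step(inputs, mem, beta=beta)
        counts += spk
    return counts

inputs = np.array([0.5, 0.25, 0.0])
lif_counts = count_spikes(inputs, num_steps=10, beta=0.9)  # leaky
if_counts = count_spikes(inputs, num_steps=10, beta=1.0)   # no leak (IF)
print(lif_counts, if_counts)
```

With the same drive, the leaky unit fires less often than the non-leaky one because part of the accumulated potential decays away between steps.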
3. Surrogate Gradient Techniques and Training Dynamics
Spiking neurons apply a hard threshold nonlinearity whose gradient is zero almost everywhere, so training via backpropagation-through-time (BPTT) requires surrogate gradients. mlx-snn uses the straight-through estimator (STE) pattern: the forward pass emits the exact thresholded spike, while backpropagation replaces its gradient with that of a smooth surrogate.
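The pattern can be written out explicitly. The following is a sketch in assumed notation, consistent with the `mx.stop_gradient` construction mentioned in Section 1 but not taken verbatim from the library: $x$ is the membrane potential minus the threshold, $\Theta$ is the Heaviside step, $\tilde{\sigma}$ is a smooth surrogate, and $\operatorname{sg}$ is the stop-gradient operator (identity in the forward pass, zero derivative in the backward pass).

```latex
% Assumed notation: x = U - V_thr, \Theta = Heaviside step,
% \tilde{\sigma} = smooth surrogate, sg = stop-gradient.
\begin{aligned}
S &= \operatorname{sg}\!\big(\Theta(x) - \tilde{\sigma}(x)\big) + \tilde{\sigma}(x) \\
\text{forward:}\quad S &= \Theta(x) - \tilde{\sigma}(x) + \tilde{\sigma}(x) = \Theta(x) \\
\text{backward:}\quad \frac{\partial S}{\partial x} &= 0 + \tilde{\sigma}'(x)
\end{aligned}
```

The forward value is therefore the exact binary spike, while the only term that contributes to the gradient is the surrogate.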
Supported surrogates include:
- Fast sigmoid
- Arctan
- Straight-through (piecewise linear)
- Custom: Arbitrary user-supplied smooth function, wrapped for STE relaxation.
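For orientation, the sketch below writes out commonly used forms of these surrogates and their derivatives in NumPy and checks the analytic derivatives against finite differences. The slope parameter `k`, its default values, the boxcar `width`, and all function names are assumptions chosen for illustration; the library's own parameterization may differ.

```python
import numpy as np

def fast_sigmoid(x, k=25.0):
    """Smooth stand-in for the step function: x / (1 + k|x|)."""
    return x / (1.0 + k * np.abs(x))

def fast_sigmoid_grad(x, k=25.0):
    """Derivative of fast_sigmoid: 1 / (1 + k|x|)^2."""
    return 1.0 / (1.0 + k * np.abs(x)) ** 2

def arctan_surrogate(x, k=2.0):
    """Arctan-shaped step: arctan(pi*k*x/2) / pi + 1/2."""
    return np.arctan(np.pi * k * x / 2.0) / np.pi + 0.5

def arctan_grad(x, k=2.0):
    """Derivative of arctan_surrogate."""
    return (k / 2.0) / (1.0 + (np.pi * k * x / 2.0) ** 2)

def boxcar_grad(x, width=1.0):
    """Piecewise-linear surrogate: unit gradient inside a window around 0."""
    return (np.abs(x) < width / 2.0).astype(float)

# Central finite differences as an independent check of the derivatives.
x = np.array([-0.8, -0.05, 0.02, 0.4])
eps = 1e-6
fd_fast = (fast_sigmoid(x + eps) - fast_sigmoid(x - eps)) / (2 * eps)
fd_atan = (arctan_surrogate(x + eps) - arctan_surrogate(x - eps)) / (2 * eps)
print(np.allclose(fd_fast, fast_sigmoid_grad(x), atol=1e-6))
print(np.allclose(fd_atan, arctan_grad(x), atol=1e-6))
```

All three gradients peak at the threshold and shrink (or vanish) away from it, so weight updates concentrate on neurons that were close to firing.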
The BPTT loop builds a full temporal graph, allowing gradients to be propagated through all intermediate neuron states. The workflow is fully compatible with MLX's autograd and lazy evaluation, making it possible to combine SNNs with arbitrary differentiable models.
4. Spike Encoding Schemes
mlx-snn provides four spike encoding methods for mapping continuous or time-series data into spike trains:
- Rate Coding (Poisson): Each scalar feature $x \in [0, 1]$ is interpreted as the probability of emitting a spike at every time step, $P(s[t] = 1) = x$.
- Latency Coding (Time-to-First-Spike): Each feature $x$ is mapped to a spike time $t(x)$, either linearly or exponentially; a spike is emitted only once per sequence.
- Delta Modulation: Emits a spike when the change between consecutive samples exceeds a threshold, $|x[t] - x[t-1]| > \theta$; suitable for temporal and streaming data.
- EEG Encoder: Specialized encoding for multi-lead EEG data; supports rate, delta, and threshold-crossing encoding on a per-channel basis, producing one spike train per channel.
Encoders are designed for both prototypical vision experiments (e.g., MNIST) and bio-signal pipelines, and integrate seamlessly with subsequent network layers.
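The first three schemes can be sketched in a few lines each. This is a NumPy illustration under stated assumptions, not the library's encoder API: the function names, the linear latency mapping, and the absolute-difference rule for delta modulation are all hypothetical choices made to keep the example short.

```python
import numpy as np

def rate_encode(x, num_steps, rng):
    """Bernoulli spikes with per-step probability x (x in [0, 1])."""
    return (rng.random((num_steps,) + x.shape) < x).astype(np.float32)

def latency_encode(x, num_steps):
    """Linear time-to-first-spike: larger x fires earlier, exactly once."""
    t_fire = np.round((1.0 - x) * (num_steps - 1)).astype(int)
    spikes = np.zeros((num_steps,) + x.shape, dtype=np.float32)
    np.put_along_axis(spikes, t_fire[None, ...], 1.0, axis=0)
    return spikes

def delta_encode(signal, threshold):
    """Spike when the step-to-step change exceeds threshold; signal is (T, ...)."""
    diff = np.diff(signal, axis=0, prepend=signal[:1])
    return (np.abs(diff) > threshold).astype(np.float32)

rng = np.random.default_rng(0)
x = np.array([0.0, 0.5, 1.0])
rate = rate_encode(x, num_steps=5, rng=rng)    # shape (5, 3)
ttfs = latency_encode(x, num_steps=5)          # one spike per feature
signal = np.array([0.0, 0.1, 0.5, 0.55, 0.0])
delta = delta_encode(signal, threshold=0.2)    # spikes on large jumps only
print(rate.shape, ttfs.argmax(axis=0), delta)
```

All three return arrays with time as the leading axis, matching the `spikes_in[t]` indexing convention used in the model definition in Section 6.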
5. Empirical Benchmarking and Performance Analysis
mlx-snn was benchmarked on MNIST digit classification using an Apple M3 Max MacBook Pro (36 GB unified memory), running MLX 0.29.3 and Python 3.9.18. Baselines included snnTorch 0.9.4 with PyTorch 2.8.0 on both Apple MPS GPU and CPU backends.
The network topology was a 2-layer LIF SNN (784 inputs → $H$ hidden units → 10 outputs, unrolled over $T$ time steps), with five hyperparameter configurations:
| Configuration | $\beta$ | Hidden size | Acc (mlx-snn/MLX) | Time/epoch (s) | GPU Mem (MB) | Acc (snnTorch/MPS) | Time/epoch (s) | GPU Mem (MB) | Acc (snnTorch/CPU) | Time/epoch (s) |
|---|---|---|---|---|---|---|---|---|---|---|
| C1 | 0.85 | 256 | 97.28% | 4.0 | 61 | 98.00% | 8.8 | 241 | 98.01% | 12.8 |
| C2 | 0.9 | 256 | 97.02% | 4.3 | 65 | 98.03% | 8.9 | >241 | 97.97% | 13.5 |
| C3 | — | 256* | 96.91% | 2.4 | — | 98.03% | 4.8 | — | 98.17% | 16.7 |
| C4 | — | 128 | 96.90% | 4.3 | — | 97.84% | 9.0 | — | 97.74% | 10.9 |
| C5 | 0.95 | 128 | 94.98% | 4.4 | — | 97.09% | 9.0 | — | 97.00% | 11.1 |
(*Batch size and learning rates varied per configuration.)
Peak GPU memory usage for mlx-snn ranged from 61 to 138 MB, while snnTorch on MPS required 241 to 1453 MB. Per-epoch training was 2.0–2.5× faster with mlx-snn than with snnTorch on the MPS GPU, and memory usage was 3–10× lower.
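The ratios can be recomputed from the reported figures with a few lines of arithmetic. The sketch below uses the per-epoch times from the table and the memory range endpoints quoted in the text; pairing the lower endpoints (61 MB vs. 241 MB) and the upper endpoints (138 MB vs. 1453 MB) is an assumption, since the table lists memory only for C1 and C2.

```python
# (mlx-snn, snnTorch/MPS) per-epoch times in seconds, from the benchmark table
epoch_times = {
    "C1": (4.0, 8.8),
    "C2": (4.3, 8.9),
    "C3": (2.4, 4.8),
    "C4": (4.3, 9.0),
    "C5": (4.4, 9.0),
}
speedups = {
    cfg: round(torch_t / mlx_t, 2)
    for cfg, (mlx_t, torch_t) in epoch_times.items()
}

# Peak-memory range endpoints quoted in the text (MB); pairing is assumed.
mem_ratio_low = round(241 / 61, 2)
mem_ratio_high = round(1453 / 138, 2)

print(speedups)
print(mem_ratio_low, mem_ratio_high)
```

For the five tabulated rows the speedup falls between roughly 2.0× and 2.2×, and the endpoint pairing gives memory ratios of roughly 4× and 10.5×.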
Surrogate gradient performance (C5, 10 epochs):
| Surrogate | Accuracy (%) | Training Time (s) |
|---|---|---|
| Fast sigmoid | 93.65 | 44.7 |
| Arctan | 92.44 | 43.8 |
| Straight-through | 46.28 | 46.3 |
6. Usage Paradigms and Implementation
mlx-snn is distributed via PyPI and installed using:
```
pip install mlx-snn
```
A typical SNN definition, using MLX and mlx-snn modules, follows snnTorch conventions:
```python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import mlxsnn

class SpikingMLP(nn.Module):
    def __init__(self, num_steps=25, beta=0.9):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.lif1 = mlxsnn.Leaky(beta=beta)
        self.fc2 = nn.Linear(128, 10)
        self.lif2 = mlxsnn.Leaky(beta=beta, reset_mechanism="none")
        self.num_steps = num_steps

    def __call__(self, spikes_in):
        s1 = self.lif1.init_state(spikes_in.shape[1], 128)
        s2 = self.lif2.init_state(spikes_in.shape[1], 10)
        for t in range(self.num_steps):
            x = self.fc1(spikes_in[t])
            spk, s1 = self.lif1(x, s1)
            x = self.fc2(spk)
            _, s2 = self.lif2(x, s2)
        return s2["mem"]  # final membrane potentials
```
To optimize training performance:
- Wrap the step function in `mx.compile`.
- Build the full T-step computation graph before any `mx.eval`.
- Use large batch sizes to amortize compute overhead.
- Avoid in-place operations to ensure graph purity and compatibility with JIT transforms.
7. Best Practices, Limitations, and Development Roadmap
Recommended practices for high accuracy and efficiency:
- $\beta = 0.85$ for hidden layer decay (the best-performing benchmark configuration, C1).
- Hidden size 256 yields the highest mlx-snn accuracy in the benchmarks (97.28%).
- Adam optimizer, learning rate .
- The fast-sigmoid surrogate offers the best balance of accuracy and speed among the surrogates tested.
Deployment on Apple Silicon is streamlined: explicit device transfers are unnecessary; MLX automatically manages memory placement. The use of mx.profile aids in analyzing memory and compute overheads.
Limitations in version 0.2.1:
- JIT compilation of temporal SNN loops is not yet supported (`mx.compile` cannot process Python `for` loops).
- No integrated neuromorphic dataset support (e.g., N-MNIST, DVS-Gesture, SHD).
- Published benchmarks are restricted to MNIST.
- The STE surrogate introduces a roughly 1% accuracy penalty versus native custom VJPs.
Future development aims include:
- v0.3.0: Liquid State Machine (LSM) support, reservoir computing modules, EEG classification.
- v0.4.0: Full `mx.compile` optimization of forward passes, neuromorphic data loaders, visualization suite, benchmarks on complex datasets (CIFAR, DVS).
- v1.0.0: Stable API, extensive validation, release of comprehensive documentation, and a stable PyPI package.
In summary, mlx-snn enables SNN research, development, and production on Apple hardware, delivering competitive accuracy and resource efficiency by exploiting MLX features while maintaining a familiar and portable API for the broader spiking neural network community (Qin, 3 Mar 2026).