SymTorch: A PyTorch Library for Symbolic Distillation

Updated 2 July 2026
  • SymTorch is a PyTorch-native library that performs symbolic distillation by replacing trained neural subnetworks with analytic, human-readable equations.
  • It integrates PySR for evolutionary symbolic regression, using custom operators and Pareto optimization to balance error and complexity.
  • SymTorch facilitates model interpretability, inference acceleration, and out-of-distribution analysis across various architectures.

SymTorch is a PyTorch-native library for the symbolic distillation of neural networks: replacing a trained subnetwork $f$ with a closed-form symbolic equation $s$ that accurately approximates its input–output mapping. By automating the integration of symbolic regression into deep learning workflows, SymTorch removes the engineering barriers that have hindered wider adoption of symbolic distillation. It supports interpretability, inference acceleration, and out-of-distribution analysis across diverse deep learning architectures (Tan et al., 24 Feb 2026).

1. Conceptual Foundations and Motivation

Symbolic distillation sits between local explanation methods (e.g., LIME, SHAP) and full mechanistic interpretability: it yields explicit analytic surrogates for neural modules. For a pretrained block $f: \mathbb{R}^d \to \mathbb{R}^D$, the core objective is to find $s^* \in \mathcal{S}$, where $\mathcal{S}$ is a space of analytic expressions (e.g., trees over $+, \times, \sin, \exp$), via

$$s^* = \arg\min_{s \in \mathcal{S}} \sum_{i=1}^N \|f(x_i) - s(x_i)\|^2 + \lambda\,\mathrm{Complexity}(s),$$

where $x_1, \dots, x_N$ are calibration inputs, $\lambda \ge 0$ sets the accuracy–simplicity trade-off, and $\mathrm{Complexity}(s)$ measures formula size as the node count of the expression tree. Interpretability comes from the explicit, human-readable form of $s$; acceleration comes from the lower cost of evaluating an analytic formula (notably on CPU) compared with the full neural module; and because $s(x)$ is an explicit formula, its behavior outside the training distribution can be analyzed directly.
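To make the objective concrete, here is a minimal sketch in plain Python, assuming scalar inputs and outputs ($d = D = 1$), a hand-written candidate set in place of an actual search, and a hypothetical tuple encoding of expression trees. It is not how SymTorch or PySR represent or search expressions; it only evaluates the penalized loss and takes the argmin.

```python
import math

# Hypothetical toy representation (not PySR's): an expression is "x",
# a float constant, or a tuple (op, child, ...) with op a key of OPS.
OPS = {
    "+": lambda a, b: a + b,
    "*": lambda a, b: a * b,
    "sin": math.sin,
    "exp": math.exp,
}


def evaluate(expr, x):
    """Evaluate an expression tree at a scalar input x."""
    if expr == "x":
        return x
    if isinstance(expr, float):
        return expr
    op, *args = expr
    return OPS[op](*(evaluate(a, x) for a in args))


def node_count(expr):
    """Complexity(s): total number of nodes in the expression tree."""
    if isinstance(expr, tuple):
        return 1 + sum(node_count(a) for a in expr[1:])
    return 1


def objective(expr, xs, ys, lam):
    """Squared error over the samples plus lam * Complexity(s), for D = 1."""
    sse = sum((y - evaluate(expr, x)) ** 2 for x, y in zip(xs, ys))
    return sse + lam * node_count(expr)


# Stand-in for a trained block: f(x) = 2 sin(x) + x, sampled on [-3, 3].
xs = [i / 10 for i in range(-30, 31)]
ys = [2.0 * math.sin(x) + x for x in xs]

candidates = [
    "x",                                               # 1 node, poor fit
    ("*", 3.0, "x"),                                   # 3 nodes, rough fit
    ("+", ("*", 2.0, ("sin", "x")), "x"),              # 6 nodes, exact
    ("+", ("+", ("*", 2.0, ("sin", "x")), "x"), 0.0),  # 8 nodes, exact but padded
]

best = min(candidates, key=lambda s: objective(s, xs, ys, lam=0.1))
print(best, node_count(best))
```

With $\lambda = 0.1$ the exact 6-node formula wins: the padded 8-node variant fits equally well but pays a larger complexity penalty, which is exactly what the $\lambda\,\mathrm{Complexity}(s)$ term is for.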

2. SymTorch Architecture

SymTorch is implemented as a drop-in PyTorch extension built around the SymbolicModel class, which inherits from torch.nn.Module. Key architectural components are:

  • Module Wrapping: Any PyTorch submodule or callable can be wrapped by SymbolicModel, which registers forward hooks to capture and cache all input and output tensors during inference.
  • Dataflow and Caching: Input–output pairs $(x_i, f(x_i))$ are moved from GPU to CPU and cached as the model runs over a calibration set, so symbolic regression can be repeated on the cached samples without additional GPU passes.
  • Neural-to-Symbolic Switching: After symbolic regression, invoking switch_to_symbolic(complexity=K) swaps the neural block for the Pareto-front surrogate of complexity K; switch_to_block() reverts to the original block. Because the wrapper keeps full PyTorch nn.Module semantics, serialization with torch.save, model loading, and compatibility with torch.compile remain intact.
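The wrap, cache, and switch mechanics can be illustrated with a plain-Python mock. This is a hypothetical sketch, not SymTorch's SymbolicModel: ToySymbolicWrapper and add_surrogate are invented names, the mock wraps an ordinary callable instead of an nn.Module, records pairs directly in __call__ instead of through PyTorch forward hooks, omits the GPU-to-CPU transfer, and only borrows the switch_to_symbolic / switch_to_block names from the description above.

```python
from typing import Callable, Dict, List, Tuple

Fn = Callable[[float], float]


class ToySymbolicWrapper:
    """Hypothetical plain-Python mock of a symbolic wrapper (not SymTorch)."""

    def __init__(self, block: Fn):
        self.block = block                          # original "neural" block
        self.cache: List[Tuple[float, float]] = []  # cached (x_i, f(x_i)) pairs
        self.surrogates: Dict[int, Fn] = {}         # complexity -> surrogate s
        self.active: Fn = block                     # what __call__ dispatches to

    def __call__(self, x: float) -> float:
        y = self.active(x)
        # Stands in for a forward hook: cache pairs only while the block is active.
        if self.active is self.block:
            self.cache.append((x, y))
        return y

    def add_surrogate(self, complexity: int, fn: Fn) -> None:
        """Register a surrogate; in practice it would come from symbolic regression."""
        self.surrogates[complexity] = fn

    def switch_to_symbolic(self, complexity: int) -> None:
        self.active = self.surrogates[complexity]

    def switch_to_block(self) -> None:
        self.active = self.block


wrapped = ToySymbolicWrapper(lambda x: 3.0 * x + 1.0)
for x in [0.0, 1.0, 2.0]:          # "calibration" pass fills the cache
    wrapped(x)
print(wrapped.cache)               # [(0.0, 1.0), (1.0, 4.0), (2.0, 7.0)]

wrapped.add_surrogate(5, lambda x: 3.0 * x + 1.0)
wrapped.switch_to_symbolic(complexity=5)
wrapped(10.0)                      # served by the surrogate, not cached
wrapped.switch_to_block()          # back to the original block
```

Keying surrogates by complexity mirrors the idea that a single search yields a whole Pareto front, so switching between accuracy levels needs no new regression run.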

3. PySR Backbone and Symbolic Regression Integration

SymTorch automates symbolic regression by integrating PySR, an evolutionary algorithm-based symbolic regression engine. Core features include:

  • Operator and Search Space Customization: Users can specify permitted operators (e.g., ["+", "*", "inv", "sin", "exp"]) and apply variable transforms, such as PCA projections, to reduce input dimensionality.
  • Evolutionary Search Dynamics: PySR maintains populations of candidate expressions, applying selection (tournament based on loss and parsimony), crossover, mutation, formula simplification, and constant optimization. The Pareto front is continuously updated to track the error–complexity trade-off.
  • Automated Wrapper Integration: The symbolic regression workflow—(1) capturing data, (2) invoking PySR, (3) retrieving Pareto-optimal surrogates—is encapsulated in a single call to .distill(sr_params, fit_params).
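  • Equation Selection: Among the Pareto candidates, ordered by increasing complexity, PySR scores each candidate $s_i$, with loss $\ell_i$ and complexity $c_i$, using

$$\mathrm{score}_i = -\frac{\log \ell_i - \log \ell_{i-1}}{c_i - c_{i-1}}$$

and selects the surrogate maximizing this reduction in log-loss per complexity increment.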
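The front construction and the selection rule can be sketched in a few lines of plain Python. The (complexity, loss) pairs below are hypothetical. Each front member is scored by $-(\log \ell_i - \log \ell_{i-1})/(c_i - c_{i-1})$, the negative change in log-loss per unit of added complexity, as in PySR's documented score. The selection here is a plain argmax, whereas PySR's default "best" model selection also restricts the choice to candidates whose loss is within a fixed factor of the lowest loss.

```python
import math
from typing import List, Tuple

Candidate = Tuple[int, float]  # (complexity c_i, loss l_i)


def pareto_front(cands: List[Candidate]) -> List[Candidate]:
    """Walk candidates by increasing complexity and keep each one whose loss is
    strictly lower than that of every candidate seen before it."""
    front: List[Candidate] = []
    for c, loss in sorted(cands):
        if not front or loss < front[-1][1]:
            front.append((c, loss))
    return front


def scores(front: List[Candidate]) -> List[float]:
    """score_i = -(log l_i - log l_{i-1}) / (c_i - c_{i-1}); the simplest gets 0."""
    out = [0.0]
    for (c0, l0), (c1, l1) in zip(front, front[1:]):
        out.append(-(math.log(l1) - math.log(l0)) / (c1 - c0))
    return out


# Hypothetical (complexity, loss) pairs from a search.
cands = [(1, 4.0), (3, 1.0), (3, 2.0), (5, 0.9), (6, 0.95), (7, 0.01), (9, 0.009)]
front = pareto_front(cands)
sc = scores(front)
best = front[max(range(len(front)), key=sc.__getitem__)]
print(front)
print(best)  # (7, 0.01): the largest log-loss drop per added node
```

Dominated candidates such as (3, 2.0) and (6, 0.95) never reach the front, and the last step from complexity 7 to 9 buys too little loss reduction to be selected.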

4. Practical Implementations and Case Studies

SymTorch is demonstrated across a range of network architectures and domains:

  • Graph Neural Networks (GNNs): Distillation targets edge-MLPs in N-body simulations, recovering interpretable force laws.
  • Physics-Informed Neural Networks (PINNs): For the 1D heat equation, the framework extracts the closed-form analytic solution from a PINN trained on sparse data.
  • Transformer Models: Application to LLMs involves replacing specific MLP (SwiGLU) blocks with symbolic surrogates, supporting inference-time performance studies.
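  • Minimal Example Workflow: wrap the target block in SymbolicModel, run calibration data through the model so the hooks cache input–output pairs, call .distill(sr_params, fit_params) to run PySR on the cached pairs, and then call switch_to_symbolic(complexity=K) to use the chosen surrogate (switch_to_block() restores the neural block).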
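One practical payoff of a closed-form PINN surrogate is that it can be checked directly against the PDE. The sketch below is an illustration under stated assumptions, not SymTorch's procedure and not necessarily the solution reported for this case study: it assumes the heat equation $u_t = \alpha u_{xx}$ with a hypothetical $\alpha = 0.1$, a hypothetical candidate $u(x,t) = e^{-\alpha \pi^2 t}\sin(\pi x)$, and estimates the PDE residual with central finite differences.

```python
import math

ALPHA = 0.1  # hypothetical thermal diffusivity in u_t = ALPHA * u_xx


def u_candidate(x: float, t: float) -> float:
    """Hypothetical distilled surrogate: a separable heat-equation mode."""
    return math.exp(-ALPHA * math.pi ** 2 * t) * math.sin(math.pi * x)


def u_wrong(x: float, t: float) -> float:
    """A plausible-looking formula with the wrong decay rate."""
    return math.exp(-ALPHA * t) * math.sin(math.pi * x)


def heat_residual(u, x: float, t: float, h: float = 1e-3) -> float:
    """Central-difference estimate of u_t - ALPHA * u_xx at (x, t)."""
    u_t = (u(x, t + h) - u(x, t - h)) / (2 * h)
    u_xx = (u(x + h, t) - 2 * u(x, t) + u(x - h, t)) / h ** 2
    return u_t - ALPHA * u_xx


# Interior grid on (0, 1) x (0, 1).
points = [(i / 10, j / 10) for i in range(1, 10) for j in range(1, 10)]


def max_residual(u) -> float:
    return max(abs(heat_residual(u, x, t)) for x, t in points)


print(max_residual(u_candidate))  # small: only finite-difference error remains
print(max_residual(u_wrong))      # clearly nonzero: the formula violates the PDE
```

The same check is much harder to read off a neural PINN, where the residual depends on every weight; with a formula, a wrong decay rate shows up immediately.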

5. Empirical Evaluation

Key results highlight throughput and accuracy trade-offs:

| Setting | Perplexity (Δ vs. baseline) | Throughput (tok/s) | Notes |
|---|---|---|---|
| Baseline (Qwen2.5-1.5B) | 10.62 | 4879 | No symbolic distillation |
| PCA only (32→8 dims) | 13.73 (+3.11) | 5200 | |
| PCA + SymTorch (3/28 MLPs) | 13.76 (+3.14) | 5280 (+8.3%) | 8.3% speedup |

Additional case study observations:

  • GNN (edge message): All model variants recover the target force law with high fidelity; bottleneck and pruned models yield the cleanest surrogates.
  • PINN (1D heat): SymTorch distills the analytic solution from the trained PINN with two-decimal precision.
  • Computational Overheads: Symbolic regression on small MLPs typically completes within minutes on modern CPUs (400–2000 iterations). The principal cost drivers are operator-set size and input dimensionality; increasing either can extend runs to hours.

6. Limitations and Prospects

Principal limitations:

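  • Scalability: The space of candidate expressions grows combinatorially with input dimensionality, operator-set size, and expression size, and evaluating each candidate scales with the number of samples, so symbolic regression becomes expensive for wide blocks and large calibration sets.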
  • Formula Complexity vs. Human Parsability: Node count does not guarantee interpretability, since two expressions with the same number of nodes incur the same complexity cost even when one is far harder for a human to read.
  • Expressivity Ceiling: Symbolic surrogates may not capture highly nonlinear or complex block behaviors, leaving some modules refractory to closed-form distillation.
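A rough count shows why dimensionality and operator-set size dominate the cost. The sketch below is a back-of-the-envelope assumption, not a model of PySR's actual search: it counts expression trees built only from binary operators and input variables (no unary operators or constants), treating algebraically equivalent trees as distinct.

```python
from math import comb


def catalan(n: int) -> int:
    """Number of distinct full binary tree shapes with n internal nodes."""
    return comb(2 * n, n) // (n + 1)


def num_trees(n_internal: int, n_binary_ops: int, n_vars: int) -> int:
    """Labeled trees: every internal node picks one of n_binary_ops operators,
    and each of the n_internal + 1 leaves picks one of n_vars input variables."""
    return catalan(n_internal) * n_binary_ops ** n_internal * n_vars ** (n_internal + 1)


for n_vars in (1, 8, 32):
    print(n_vars, num_trees(5, 4, n_vars))
```

For trees with 5 internal nodes, going from 8 to 32 input variables multiplies the count by $4^6 = 4096$, and doubling the operator set from 4 to 8 multiplies it by $2^5 = 32$. This is why dimensionality reduction before regression pays off.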

Potential improvements identified by the authors:

  • Adoption of learned linear projections (beyond PCA) to optimize pre-SR dimensionality reduction.
  • Dynamic adaptation of the SR operator set tailored to interim progress.
  • Cross-domain validation of surrogate generalization on out-of-distribution data.
  • Layer selection heuristics to identify transformer blocks offering optimal speed–accuracy trade-offs upon symbolic distillation.

By addressing data transfer, caching, model management, and surrogate switching entirely within a native PyTorch interface, SymTorch positions symbolic distillation as a single-call extension to established deep learning workflows and demonstrates the practical feasibility of analytic interpretability in scientific and engineering settings (Tan et al., 24 Feb 2026).
