SymTorch: A PyTorch Library for Symbolic Distillation
- SymTorch is a PyTorch-native library that performs symbolic distillation by replacing trained neural subnetworks with analytic, human-readable equations.
- It integrates PySR for evolutionary symbolic regression, using custom operators and Pareto optimization to balance error and complexity.
- SymTorch facilitates model interpretability, inference acceleration, and out-of-distribution analysis across various architectures.
SymTorch is a PyTorch-native library that enables the symbolic distillation of neural networks: the process of replacing a trained subnetwork with a closed-form symbolic equation that accurately approximates its input–output mapping. By automating the integration of symbolic regression into deep learning workflows, SymTorch removes the engineering barriers that have hindered widespread adoption of symbolic distillation, facilitating interpretability, inference acceleration, and opportunities for out-of-distribution analysis across diverse deep learning architectures (Tan et al., 24 Feb 2026).
1. Conceptual Foundations and Motivation
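Symbolic distillation occupies a middle ground between local explanation methods (e.g., LIME, SHAP) and full mechanistic interpretability: it yields explicit analytic surrogates for neural modules. For a pretrained block $f_\theta$, the core objective is to find a surrogate $g^\star \in \mathcal{G}$, where $\mathcal{G}$ is a space of analytic expressions (e.g., expression trees over operators such as $+$, $\times$, $\sin$, $\exp$). It is obtained via
$$g^\star = \arg\min_{g \in \mathcal{G}} \; \frac{1}{N}\sum_{i=1}^{N} \big\lVert f_\theta(x_i) - g(x_i) \big\rVert^2 + \lambda\, C(g),$$
where $C(g)$ measures formula size as node count and $\lambda \ge 0$ weights the complexity penalty. Interpretability comes from the explicit, human-readable form of $g^\star$. Acceleration results from the lower cost of evaluating an analytic formula (notably on CPU) compared with the full neural module. Out-of-distribution analysis becomes feasible because $g^\star$ is an analytic expression that can be evaluated and inspected beyond the range of the calibration data.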
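An objective of the form "mean squared error against the teacher plus $\lambda$ times node count" can be made concrete with a small sketch. It assumes expression trees encoded as nested Python tuples, a hypothetical teacher function standing in for $f_\theta$, and a hand-picked candidate list in place of an evolutionary search. It only shows how the error term and the complexity penalty $C(g)$ combine.

```python
import math

# Expression trees as nested tuples: (op, child, ...) or a leaf ("x" or a number).
UNARY = {"sin": math.sin, "exp": math.exp}
BINARY = {"+": lambda a, b: a + b, "*": lambda a, b: a * b}

def evaluate(expr, x):
    """Evaluate a tuple-encoded expression tree at scalar input x."""
    if expr == "x":
        return x
    if not isinstance(expr, tuple):
        return float(expr)                  # numeric constant leaf
    op, *args = expr
    if op in UNARY:
        return UNARY[op](evaluate(args[0], x))
    return BINARY[op](evaluate(args[0], x), evaluate(args[1], x))

def node_count(expr):
    """Complexity C(g): number of nodes in the expression tree."""
    if not isinstance(expr, tuple):
        return 1
    return 1 + sum(node_count(child) for child in expr[1:])

def objective(expr, xs, ys, lam):
    """Mean squared error against the teacher outputs plus lam * C(g)."""
    mse = sum((evaluate(expr, x) - y) ** 2 for x, y in zip(xs, ys)) / len(xs)
    return mse + lam * node_count(expr)

def teacher(x):
    # Hypothetical stand-in for a trained block f_theta.
    return x * x + 0.5 * x

xs = [i / 10 for i in range(-20, 21)]       # calibration inputs
ys = [teacher(x) for x in xs]               # cached teacher outputs

candidates = [
    "x",
    ("*", "x", "x"),
    ("+", ("*", "x", "x"), ("*", 0.5, "x")),
    ("+", ("+", ("*", "x", "x"), ("*", 0.5, "x")), ("*", 0.001, ("sin", "x"))),
]
best = min(candidates, key=lambda g: objective(g, xs, ys, lam=0.01))
print(best, node_count(best))
```

The exact 7-node formula wins: the larger candidate is marginally different in error but pays for five extra nodes, while the smaller ones incur a large error.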
2. SymTorch Architecture
SymTorch is implemented as a drop-in PyTorch extension built around the SymbolicModel class, which inherits from torch.nn.Module. Key architectural components are:
- Module Wrapping: Any PyTorch submodule or callable can be wrapped by `SymbolicModel`, which registers forward hooks to capture and cache all input and output tensors during inference.
- Dataflow and Caching: Input–output pairs $\{(x_i, y_i)\}_{i=1}^{N}$, with $y_i = f_\theta(x_i)$, are moved from GPU to CPU and cached as the model is run over a calibration set. The cached samples are then reused for repeated symbolic regression runs without additional GPU passes.
- Neural-to-Symbolic Switching: After symbolic regression, invoking `switch_to_symbolic(complexity=K)` swaps the neural block for the surrogate $g_K$ of the specified complexity $K$; `switch_to_block()` reverts to the original block. Because the wrapper retains full PyTorch `nn.Module` semantics, serialization with `torch.save`, model loading, and compatibility with `torch.compile` remain intact.
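The wrap, cache, and switch pattern can be illustrated without PyTorch. The sketch below is a hypothetical, framework-free analogue: the class `ToySymbolicWrapper` and its methods are invented for illustration and are not SymTorch's API. Each call on the original block records the input–output pair, much as a forward hook would, and a dictionary of surrogates keyed by complexity supports switching back and forth. Its `distill` stand-in only fits one scale factor per candidate form by least squares rather than running symbolic regression.

```python
class ToySymbolicWrapper:
    """Hypothetical, framework-free analogue of a SymbolicModel-style wrapper.

    Not SymTorch's API: it only mimics (1) caching input/output pairs while the
    original block runs and (2) switching between the block and a surrogate.
    """

    def __init__(self, block):
        self.block = block              # original callable standing in for a neural block
        self.cache_x, self.cache_y = [], []
        self.surrogates = {}            # complexity -> fitted surrogate callable
        self.active = None              # None means "run the original block"
        self.recording = True           # cache pairs only during calibration

    def __call__(self, x):
        if self.active is not None:
            return self.surrogates[self.active](x)
        y = self.block(x)
        if self.recording:              # plays the role of a forward hook
            self.cache_x.append(x)
            self.cache_y.append(y)
        return y

    def distill(self, basis_by_complexity):
        """Fit y ~ a * phi(x) on the cache for each candidate form phi.

        Stand-in for symbolic regression: only the scale a is optimized,
        in closed form by least squares.
        """
        for complexity, phi in basis_by_complexity.items():
            num = sum(y * phi(x) for x, y in zip(self.cache_x, self.cache_y))
            den = sum(phi(x) ** 2 for x in self.cache_x)
            a = num / den
            self.surrogates[complexity] = lambda x, a=a, phi=phi: a * phi(x)
        self.recording = False

    def switch_to_symbolic(self, complexity):
        self.active = complexity

    def switch_to_block(self):
        self.active = None


def block(x):
    # Hypothetical "trained block" that secretly computes 3 * x**2.
    return 3.0 * x ** 2

model = ToySymbolicWrapper(block)
for i in range(1, 11):                  # calibration pass fills the cache
    model(i / 10)

# Keys are node counts of a*x (3 nodes) and a*x**2 (5 nodes).
model.distill({3: lambda x: x, 5: lambda x: x ** 2})
model.switch_to_symbolic(5)
print(model(2.0))                       # surrogate, close to 12.0
model.switch_to_block()
print(model(2.0))                       # original block: 12.0
```

The complexity-5 surrogate recovers the factor 3 up to floating-point rounding. Keying surrogates by complexity mirrors the `switch_to_symbolic(complexity=K)` idea, but none of the real method's internals are reproduced here.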
3. PySR Backbone and Symbolic Regression Integration
SymTorch automates symbolic regression by integrating PySR, an evolutionary algorithm-based symbolic regression engine. Core features include:
- Operator and Search Space Customization: Users can specify the permitted operators (e.g., `["+", "*", "inv", "sin", "exp"]`) and apply variable transforms, such as a PCA projection, to reduce input dimensionality.
- Evolutionary Search Dynamics: PySR maintains populations of candidate expressions and applies selection (tournaments based on loss and parsimony), crossover, mutation, formula simplification, and constant optimization. The Pareto front is continuously updated to track the error–complexity trade-off.
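- Automated Wrapper Integration: The symbolic regression workflow, namely (1) capturing data, (2) invoking PySR, and (3) retrieving Pareto-optimal surrogates, is encapsulated in a single call to `.distill(sr_params, fit_params)`.
- Equation Selection: Among the Pareto candidates, ordered by complexity $c_i$ with loss $\ell_i$, PySR scores each candidate $g_i$ using
$$\mathrm{score}_i = -\frac{\ln \ell_i - \ln \ell_{i-1}}{c_i - c_{i-1}}$$
and selects the surrogate maximizing this log-loss gain per complexity increment. With PySR's `model_selection="best"` setting, the maximum is taken only over candidates whose loss is within a factor of 1.5 of the most accurate one.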
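The Pareto filter and PySR's score, $\mathrm{score}_i = -(\ln \ell_i - \ln \ell_{i-1})/(c_i - c_{i-1})$ between consecutive front entries, can be sketched in a few lines. The candidate expressions and losses below are made up for illustration, and the strings stand in for real expression objects. The optional `loss_factor` filter is modeled on PySR's `model_selection="best"` behavior as documented, not copied from its source.

```python
import math

def pareto_front(candidates):
    """Keep, in order of increasing complexity, only candidates that strictly lower the best loss so far."""
    front, best_loss = [], math.inf
    for expr, complexity, loss in sorted(candidates, key=lambda c: (c[1], c[2])):
        if loss < best_loss:
            front.append((expr, complexity, loss))
            best_loss = loss
    return front

def scores(front):
    """score_i = -(ln l_i - ln l_{i-1}) / (c_i - c_{i-1}); the simplest entry gets 0."""
    out = [0.0]
    for (_, c_prev, l_prev), (_, c, l) in zip(front, front[1:]):
        out.append(-(math.log(l) - math.log(l_prev)) / (c - c_prev))
    return out

def select(front, loss_factor=None):
    """Return the front entry with the highest score.

    If loss_factor is set (e.g. 1.5), only entries with loss < loss_factor * min loss
    are eligible, similar in spirit to PySR's model_selection="best".
    """
    sc = scores(front)
    min_loss = min(loss for _, _, loss in front)
    eligible = [i for i, (_, _, loss) in enumerate(front)
                if loss_factor is None or loss < loss_factor * min_loss]
    return front[max(eligible, key=lambda i: sc[i])]

# Hypothetical (expression, complexity, loss) triples from a search.
candidates = [
    ("x0", 1, 1.0),
    ("x0 * x1", 3, 0.5),
    ("sin(x0) * x1", 4, 0.6),   # dominated: more complex and less accurate than "x0 * x1"
    ("x0 * x1 + 0.3", 5, 0.01),
    ("x0 * x1 + 0.3 + 0.001 * x2", 9, 0.009),
]
front = pareto_front(candidates)
print([expr for expr, _, _ in front])
print(select(front)[0], "|", select(front, loss_factor=1.5)[0])
```

Adding the constant drops the loss fifty-fold for two extra nodes, so `"x0 * x1 + 0.3"` scores highest; the final 9-node refinement barely improves the loss and is not selected.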
4. Practical Implementations and Case Studies
SymTorch is demonstrated across a range of network architectures and domains:
- Graph Neural Networks (GNNs): Distillation targets edge-MLPs in $N$-body simulations, recovering interpretable closed-form force laws.
- Physics-Informed Neural Networks (PINNs): For the 1D heat equation, the framework extracts the analytic solution $u(x, t)$ from a PINN trained on sparse data.
- Transformer Models: Application to LLMs involves replacing specific MLP (SwiGLU) blocks with symbolic surrogates, supporting inference-time performance studies.
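- Minimal Example Workflow: Wrap the target submodule in `SymbolicModel`, run the model over a calibration set to cache input–output pairs, call `.distill(sr_params, fit_params)` to run symbolic regression, and then toggle between the surrogate and the original block with `switch_to_symbolic(complexity=K)` and `switch_to_block()`.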
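For the GNN case, a heavily simplified stand-in for what symbolic regression does on cached edge messages is to test a few candidate force laws against the data. The sketch assumes scalar messages that depend only on the pairwise distance $r$, synthetic data generated from a hypothetical inverse-square law with strength 2.0, and a fixed family $F(r) = a / r^p$ instead of an open-ended expression search. None of these choices come from the paper.

```python
import random

def fit_scale(rs, fs, p):
    """Least-squares scale a for the model F(r) = a * r**(-p); returns (a, mse)."""
    basis = [r ** (-p) for r in rs]
    a = sum(f * b for f, b in zip(fs, basis)) / sum(b * b for b in basis)
    mse = sum((f - a * b) ** 2 for f, b in zip(fs, basis)) / len(fs)
    return a, mse

random.seed(0)
# Hypothetical cached edge messages: pairwise distances r and scalar message values
# generated from an inverse-square law with strength 2.0 plus small Gaussian noise.
rs = [random.uniform(0.5, 3.0) for _ in range(200)]
fs = [2.0 / r ** 2 + random.gauss(0.0, 0.01) for r in rs]

# Fit each candidate exponent and keep the one with the lowest error.
results = {p: fit_scale(rs, fs, p) for p in (1, 2, 3)}
best_p = min(results, key=lambda p: results[p][1])
a, mse = results[best_p]
print(f"best exponent p={best_p}, F(r) ~ {a:.3f} / r^{best_p}, mse={mse:.2e}")
```

The correct exponent stands out because the wrong ones cannot match the curvature of the data over the sampled range of $r$, whatever scale $a$ is chosen.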
5. Empirical Evaluation
Key results for the Transformer experiment (Qwen2.5-1.5B) highlight throughput and accuracy trade-offs:
| Setting | Perplexity | Throughput (tok/s) | Notes |
|---|---|---|---|
| Baseline (Qwen2.5-1.5B) | 10.62 | 4879 | No symbolic distillation |
| PCA only (32→8 dims) | 13.73 (+3.11) | 5200 | |
| PCA + SymTorch (3/28 MLPs) | 13.76 (+3.14) | 5280 (+8.3%) | 8.3% speedup |
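The relative changes in the table follow from simple arithmetic on its entries. The snippet recomputes them from the rounded values shown and uses no data beyond the table.

```python
baseline_ppl, baseline_tps = 10.62, 4879            # Qwen2.5-1.5B baseline row
rows = {
    "PCA only": (13.73, 5200),
    "PCA + SymTorch": (13.76, 5280),
}

deltas = {}
for name, (ppl, tps) in rows.items():
    d_ppl = ppl - baseline_ppl                              # absolute perplexity change
    d_tps = 100.0 * (tps - baseline_tps) / baseline_tps     # relative throughput change in %
    deltas[name] = (d_ppl, d_tps)
    print(f"{name}: perplexity {d_ppl:+.2f}, throughput {d_tps:+.1f}%")
```

From these rounded entries the PCA + SymTorch throughput gain comes out at about +8.2% (the table reports +8.3%), and the PCA-only row corresponds to about +6.6%, so most of the speedup at this setting already comes from the dimensionality reduction.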
Additional case study observations:
- GNN (edge messages): All model variants recover the underlying force law with high fidelity; bottleneck and pruned models yield the cleanest surrogates.
- PINN (1D heat): The PINN attains a low MSE, and SymTorch distills the analytic solution with two-decimal precision.
- Computational Overheads: Symbolic regression on small MLPs typically completes within minutes on modern CPUs (400–2000 iterations). Larger operator sets or higher input dimensionality can extend runs to hours, making the SR search the principal cost driver.
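To see why operator set size and input dimensionality drive the cost, one can count how many syntactically distinct expression trees exist up to a given node count. This is a back-of-the-envelope count under simplifying assumptions: leaves are input variables only, and constants and algebraic equivalences are ignored. It is not how PySR explores the space, only a measure of how large that space is.

```python
from functools import lru_cache

def count_expressions(max_size, n_vars, n_unary, n_binary):
    """Count distinct expression trees with at most max_size nodes.

    Leaves are input variables (constants ignored); internal nodes are unary or
    binary operators. Syntactically different trees are counted separately.
    """
    @lru_cache(maxsize=None)
    def exact(size):
        if size == 1:
            return n_vars                           # a single variable leaf
        total = n_unary * exact(size - 1)           # unary root over one subtree
        for left in range(1, size - 1):             # binary root: split size - 1 nodes in two
            total += n_binary * exact(left) * exact(size - 1 - left)
        return total
    return sum(exact(s) for s in range(1, max_size + 1))

# Growth with input dimensionality (operators fixed: 3 unary, 2 binary).
for n_vars in (2, 8, 32):
    print("vars", n_vars, count_expressions(9, n_vars, n_unary=3, n_binary=2))

# Growth with operator set size (inputs fixed: 8 variables).
for n_unary, n_binary in ((1, 2), (3, 4), (6, 6)):
    print("ops", n_unary + n_binary, count_expressions(9, 8, n_unary, n_binary))
```

Even at a modest size limit of nine nodes, each increase in variables or operators multiplies the count, which is consistent with the runtime growth described above.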
6. Limitations and Prospects
Principal limitations:
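- Scalability: The symbolic regression search space grows combinatorially with input dimensionality and operator set size, and the cost of evaluating each candidate grows with dataset size.
- Formula Complexity vs. Human Parsability: Node count does not guarantee interpretability, since expressions with the same node count, such as $xy + z$ and $\exp(\sin(x/z))$ (five nodes each), can differ considerably in how easily a human can parse them.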
- Expressivity Ceiling: Symbolic surrogates may not capture highly nonlinear or complex block behaviors, leaving some modules refractory to closed-form distillation.
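The gap between node count and readability can be checked on two small expressions, $xy + z$ and $\exp(\sin(x/z))$. The sketch computes node count alongside nesting depth; depth is used here only as one illustrative extra signal and is not a measure proposed in the source.

```python
def node_count(expr):
    """Number of nodes in a tuple-encoded expression tree (the size measure C(g))."""
    if not isinstance(expr, tuple):
        return 1
    return 1 + sum(node_count(child) for child in expr[1:])

def depth(expr):
    """Nesting depth of the tree: a crude proxy for how deeply operations are composed."""
    if not isinstance(expr, tuple):
        return 1
    return 1 + max(depth(child) for child in expr[1:])

flat = ("+", ("*", "x", "y"), "z")              # x*y + z
nested = ("exp", ("sin", ("/", "x", "z")))     # exp(sin(x / z))
for name, expr in [("x*y + z", flat), ("exp(sin(x/z))", nested)]:
    print(name, "nodes:", node_count(expr), "depth:", depth(expr))
```

Both expressions cost five nodes, so a node-count criterion cannot tell them apart, even though the nested composition is harder to interpret physically.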
Potential improvements identified by the authors:
- Adoption of learned linear projections (beyond PCA) to optimize pre-SR dimensionality reduction.
- Dynamic adaptation of the SR operator set based on intermediate search progress.
- Cross-domain validation of surrogate generalization on out-of-distribution data.
- Layer selection heuristics to identify transformer blocks offering optimal speed–accuracy trade-offs upon symbolic distillation.
By addressing data transfer, caching, model management, and surrogate switching entirely within a native PyTorch interface, SymTorch positions symbolic distillation as a single-call extension to established deep learning workflows and demonstrates the practical feasibility of analytic interpretability in scientific and engineering settings (Tan et al., 24 Feb 2026).