KAN-Dreamer: KAN-Based MBRL Architecture
- KAN-Dreamer is a model-based reinforcement learning architecture that replaces traditional MLPs and CNNs in DreamerV3 with Kolmogorov–Arnold Networks for improved interpretability.
- Empirical evaluations on the DeepMind Control Suite’s walker_walk task demonstrate competitive sample efficiency and highlight trade-offs in throughput and computational overhead.
- The system emphasizes parameter efficiency in low-dimensional latent prediction tasks while facing challenges in visual perception and actor-critic modules due to architectural mismatches.
KAN-Dreamer denotes a series of model-based reinforcement learning (MBRL) architectures that benchmark Kolmogorov–Arnold Networks (KANs) and their efficient variant FastKAN as functional components in DreamerV3-style world models. By substituting conventional multilayer perceptrons (MLPs) and convolutional sub-modules within DreamerV3 with KAN-based alternatives, KAN-Dreamer systematically evaluates KANs as differentiable function approximators in three principal subsystems: visual perception, latent prediction (reward and continuation heads), and behavior learning (actor-critic). Empirical assessments on DeepMind Control Suite’s walker_walk task illuminate the regimes in which KANs and FastKAN match or underperform relative to established MLP/CNN baselines, characterizing trade-offs in parameter efficiency, interpretability, and computational cost (Shi et al., 8 Dec 2025).
1. Theoretical Foundations: Kolmogorov–Arnold Networks
KANs are motivated by the Kolmogorov–Arnold representation theorem, which asserts that any continuous multivariate function can be constructed via a finite superposition of continuous univariate functions and addition. Operationally, a two-layer KAN computes

$$ f(\mathbf{x}) \;=\; \sum_{q=1}^{2n+1} \Phi_q\!\left(\sum_{p=1}^{n} \phi_{q,p}(x_p)\right), $$

where $\phi_{q,p}$ (inner functions) and $\Phi_q$ (outer functions) are trainable univariate mappings, merged with a linear branch in typical applications. KANs generalize universal approximation, offering structural efficiency and potential analytic interpretability through their explicit decomposition (Shi et al., 8 Dec 2025).
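In the deep, layered form used in practice (a standard KAN formulation stated here for context; the notation is not taken from the paper), such superpositions are stacked, with every scalar weight of an MLP replaced by a learnable univariate function on the corresponding edge:

$$ \mathrm{KAN}(\mathbf{x}) \;=\; \bigl(\Phi_{L-1} \circ \cdots \circ \Phi_1 \circ \Phi_0\bigr)(\mathbf{x}), \qquad \bigl(\Phi_l(\mathbf{x})\bigr)_j \;=\; \sum_{i=1}^{n_l} \phi_{l,j,i}(x_i). $$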
FastKAN addresses the prohibitive computational overhead of spline-based KANs by employing a bank of Gaussian radial basis functions (RBFs) per edge:

$$ \phi(x) \;=\; w_b\, b(x) \;+\; \sum_{k=1}^{G} c_k \exp\!\left(-\left(\frac{x-\mu_k}{h}\right)^{2}\right), $$

with fixed base activation $b(x)$ (e.g., SiLU), $G$ basis functions, fixed RBF centers $\mu_k$ uniformly spaced over the clamped input range, bandwidth $h$, and trainable coefficients $c_k$. This construct enables fully vectorized and simplified implementations in JAX, exclusively learning the coefficients $c_k$ per edge and eliminating adaptive grid mechanisms (Shi et al., 8 Dec 2025).
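As a minimal sketch of this per-edge basis expansion (the function name and the SiLU base term are illustrative; the layer-level, fully vectorized version is given in Section 5):

```python
import jax
import jax.numpy as jnp

def fastkan_edge(x, c, mu, h, w_b=1.0):
    # Single FastKAN edge activation: fixed SiLU base branch plus a trainable
    # Gaussian RBF expansion over fixed centers mu with bandwidth h.
    # x: (...,) inputs; c: (G,) trainable coefficients; mu: (G,) fixed centers.
    x = jnp.asarray(x)
    base = w_b * jax.nn.silu(x)                                  # fixed base-activation branch
    rbf = jnp.sum(c * jnp.exp(-(((x[..., None] - mu) / h) ** 2)), axis=-1)
    return base + rbf
```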
2. KAN-Dreamer Architectural Integration
KAN-Dreamer is instantiated as a drop-in replacement strategy for MLP/CNNs in the following DreamerV3 functional units, preserving the GRU-based RSSM:
- Visual Perception (Encoder/Decoder):
- Baseline: CNN encoder / transposed CNN decoder.
- KAN-Vis: Image flattening and fully connected KAN layers.
- FKAN-Vis: Fully connected FastKAN layers.
- Latent Prediction (Reward & Continue Heads):
- Baseline: 2-layer MLP on the latent model state $(h_t, z_t)$.
- KAN-Pred: 2-layer KAN (grid size 8, spline order 3).
- FKAN-Pred: 2-layer FastKAN (8 RBFs).
- Behavior Learning (Actor / Critic):
- Baseline: 2-layer MLP on the latent model state $(h_t, z_t)$.
- KAN-AC: 2-layer KAN (grid size 8, spline order 3).
- FKAN-AC: 2-layer FastKAN.
In all regimes, KAN/FastKAN layers are parameter-matched to the baselines (total ≈ 10.5M parameters), using fixed input clamp ranges ([−5, 5] for KAN; [−2, 2] for FastKAN) and consistent training hyperparameters (shared replay capacity, batch size 16, sequence length 64, LaProp optimizer, RMSNorm + SiLU activations). Grid sizes and hidden unit counts are subsystem-specific: KAN-Vis uses 8 hidden units per fully connected layer, KAN-Pred 20, and KAN-AC 24; the FastKAN counterparts use 12, 30, and 34 respectively. All FastKAN operations use a single tensordot contraction with fixed grid management (Shi et al., 8 Dec 2025).
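These subsystem-specific settings can be summarized in a compact configuration sketch (the dictionary layout and key names are illustrative assumptions, not the paper's configuration schema):

```python
# Illustrative restatement of the reported settings; not the paper's config API.
KAN_DREAMER_SETTINGS = {
    "kan": {
        "input_clamp": (-5.0, 5.0),
        "grid_size": 8, "spline_order": 3,            # as reported for the Pred/AC heads
        "hidden_units": {"vis": 8, "pred": 20, "ac": 24},
    },
    "fastkan": {
        "input_clamp": (-2.0, 2.0),
        "num_rbfs": 8,                                # as reported for the Pred head
        "hidden_units": {"vis": 12, "pred": 30, "ac": 34},
    },
    "shared": {
        "batch_size": 16, "sequence_length": 64,
        "optimizer": "LaProp", "norm": "RMSNorm", "activation": "SiLU",
        "total_params": "~10.5M, parameter-matched to the MLP/CNN baseline",
    },
}
```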
3. Empirical Evaluation and Ablation Analysis
KAN-Dreamer’s evaluation on walker_walk consists of subsystem ablations, comparing final returns, sample efficiency, throughput (FPS), parameter count, and loss trajectories.
| Subsystem | Model | Final Score | FPS (policy) | FPS (training) | Params |
|---|---|---|---|---|---|
| Perception | Baseline | 977 | 62 | 15.9k | 10.49M |
| Perception | KAN-Vis | 940 | 72 | 18.5k | 10.40M |
| Perception | FKAN-Vis | 950 | 89 | 22.8k | 10.51M |
| Prediction | Baseline | 977 | 62 | 15.9k | 10.49M |
| Prediction | KAN-Pred | 965 | 45 | 11.5k | 10.45M |
| Prediction | FKAN-Pred | 951 | 61 | 15.5k | 10.51M |
| Behavior | Baseline | 977 | 62 | 15.9k | 10.49M |
| Behavior | KAN-AC | 957 | 31 | 7.8k | 10.48M |
| Behavior | FKAN-AC | 964 | 57 | 14.5k | 10.47M |
Key observations include:
- Sample Efficiency: Baseline achieves reward 900 at ~250k steps. FKAN-Pred requires ~220k, KAN-Pred ~250k, FKAN-AC ~520k, FKAN-Vis ~750k, KAN-AC and KAN-Vis slower at ~900k and ~930k steps respectively.
- Wall-Clock Performance: FKAN-Pred matches baseline at ~1.0 h; KAN-Pred at ~1.5 h; FKAN-Vis at ~2.5 h; KAN-Vis at ~3.5 h; KAN-AC exceeds 8 h.
- Loss Dynamics: KAN-Vis and FKAN-Vis plateau at a reconstruction loss of ~100 (indicating poor visual fidelity relative to the baseline), whereas the reward/continue predictors and policy/value heads converge to comparable losses, albeit with higher variance and slower value approximation for the actor-critic variants.
4. Functional Insights and Interpretative Context
The empirical parity of FastKAN and MLP architectures in the low-dimensional latent prediction regime (Reward and Continue heads) results from the alignment of task properties (deterministic, low-dimensional target functions) with the structural decomposition advocated by the Kolmogorov–Arnold theorem. FastKAN’s entirely vectorized, grid-based inference substantially reduces spline overhead, yielding throughput comparable to MLPs.
Conversely, KANs and FastKAN underperform in vision and policy/value estimation due to architectural incompatibilities:
- Visual Perception Limitation: Fully connected KANs lack convolutional spatial biases, impeding learning of pixel locality and resulting in poor reconstructions and diminished sample efficiency.
- Policy/Value Approximation Instability: The recursive, non-stationary critic landscape imposes learning dynamics for which current KAN/FastKAN parameterizations are insufficient, manifesting as high variance and slow convergence.
- Overhead: KAN layers exhibit ~2× computational overhead relative to MLPs. FastKAN addresses this with minimal loss of task-relevant expressivity in latent heads, but not in high-dimensional or structurally non-aligned tasks (Shi et al., 8 Dec 2025).
5. Implementation and Vectorization Strategies
The FastKAN implementation is designed for JAX, employing a fully vectorized approach for scalability. Grid centers are initialized uniformly over the clamped input range ([−2, 2]) for all edges, with no adaptive updates, and all operations compile to a single tensordot contraction for batch efficiency. Only the RBF coefficients are learnable, stored per (output, input, gridpoint) triple. This implementation ensures that FastKAN, in the tested low-dimensional subsystems, can saturate MLP-equivalent training throughput.
A representative pseudocode fragment:
```python
import jax.numpy as jnp

def FastKANLayer(x, W, b, c, mu, h):
    # x: [batch, d_in]
    # W, b: linear branch params
    # c: [d_out, d_in, G], mu: [G], h: bandwidth
    y_base = x @ W + b                                     # [batch, d_out]
    x_exp = x[..., None] - mu[None, None, :]               # [batch, d_in, G]
    phi = jnp.exp(-((x_exp / h) ** 2))                     # [batch, d_in, G]
    y_rbf = jnp.tensordot(phi, c, axes=([1, 2], [1, 2]))   # [batch, d_out]
    return y_base + y_rbf
```
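A hedged usage sketch, with shapes, scale factors, and the bandwidth choice picked for illustration only (consistent with the fixed-grid description above, but not taken from the paper's code):

```python
import jax
import jax.numpy as jnp

# Illustrative shapes and initialization; not the paper's training setup.
batch, d_in, d_out, G = 16, 32, 64, 8
key = jax.random.PRNGKey(0)
k_w, k_c, k_x = jax.random.split(key, 3)

W = 0.05 * jax.random.normal(k_w, (d_in, d_out))      # linear-branch weights
b = jnp.zeros((d_out,))                               # linear-branch bias
c = 0.05 * jax.random.normal(k_c, (d_out, d_in, G))   # trainable RBF coefficients
mu = jnp.linspace(-2.0, 2.0, G)                       # fixed centers on the clamp range
h = mu[1] - mu[0]                                     # bandwidth tied to grid spacing (assumption)

x = jnp.clip(jax.random.normal(k_x, (batch, d_in)), -2.0, 2.0)  # clamped inputs
y = FastKANLayer(x, W, b, c, mu, h)                   # -> shape (16, 64)
```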
6. Limitations and Prospective Advances
KAN-Dreamer’s primary limitations arise from representational mismatch and computational constraints:
- Spatiotemporal Bias Absence: Off-the-shelf KANs do not retain convolutional priors necessary for high-fidelity image modeling, limiting their utility in visual systems.
- Critic/Policy Head Instability: Value head instability under KAN/FastKAN parameterizations suggests a need for enhanced compositional or recurrent architectures.
- Scalability: Despite FastKAN improvements, computational efficiency and expressivity remain suboptimal outside low-dimensional, structurally aligned tasks.
Future research avenues include:
- Tailored Hyperparameters and Architectures: Subsystem-specific co-optimization for grid size, RBF bandwidth, and input normalization.
- Symbolic Interpretability: Applying symbolic regression to the analytic forms of the learned inner and outer univariate functions $\phi$ and $\Phi$ (facilitated by regression reward heads) to uncover explicit function mappings; a toy sketch appears after this list.
- Recurrent and KAN-RSSM Models: Integration of stateful KAN-RNNs or KAN-parametrized transition models for enhanced long-horizon modeling.
- Spatial Priors in KANs: Developing convolutional or locality-aware KANs for improved visual perception (Shi et al., 8 Dec 2025).
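As a toy illustration of the symbolic-interpretability direction, a learned univariate edge function could be sampled over its clamp range and fit with a simple symbolic candidate such as a low-degree polynomial. The names `fastkan_edge`, `c_learned`, `mu`, and `h` refer to the illustrative sketch in Section 1 and a hypothetical trained model; this is not the paper's pipeline.

```python
import numpy as np

# Toy readout of an approximate closed form for one learned edge function.
xs = np.linspace(-2.0, 2.0, 256)                      # dense grid over the clamp range
ys = np.asarray(fastkan_edge(xs, c_learned, mu, h))   # evaluate the learned univariate map
coeffs = np.polyfit(xs, ys, deg=3)                    # candidate symbolic form: cubic fit
print("approximate edge function:",
      " + ".join(f"{a:.3f}*x^{p}" for p, a in zip(range(3, -1, -1), coeffs)))
```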
A plausible implication is that KAN and FastKAN, when appropriately integrated into world model subsystems matching their strengths, can serve as interpretable, competitive alternatives to MLPs in sample-efficient MBRL, but that significant architectural innovation is required to extend these advantages to high-dimensional and recursive domains.