KAN-Dreamer: KAN-Based MBRL Architecture

Updated 15 December 2025
  • KAN-Dreamer is a model-based reinforcement learning architecture that replaces traditional MLPs and CNNs in DreamerV3 with Kolmogorov–Arnold Networks for improved interpretability.
  • Empirical evaluations on the DeepMind Control Suite’s walker_walk task demonstrate competitive sample efficiency and highlight trade-offs in throughput and computational overhead.
  • The system emphasizes parameter efficiency in low-dimensional latent prediction tasks while facing challenges in visual perception and actor-critic modules due to architectural mismatches.

KAN-Dreamer denotes a series of model-based reinforcement learning (MBRL) architectures that benchmark Kolmogorov–Arnold Networks (KANs) and their efficient variant FastKAN as functional components in DreamerV3-style world models. By substituting conventional multilayer perceptrons (MLPs) and convolutional sub-modules within DreamerV3 with KAN-based alternatives, KAN-Dreamer systematically evaluates KANs as differentiable function approximators in three principal subsystems: visual perception, latent prediction (reward and continuation heads), and behavior learning (actor-critic). Empirical assessments on DeepMind Control Suite’s walker_walk task illuminate the regimes in which KANs and FastKAN match or underperform relative to established MLP/CNN baselines, characterizing trade-offs in parameter efficiency, interpretability, and computational cost (Shi et al., 8 Dec 2025).

1. Theoretical Foundations: Kolmogorov–Arnold Networks

KANs are motivated by the Kolmogorov–Arnold representation theorem, which asserts that any continuous multivariate function $f : [0,1]^d \rightarrow \mathbb{R}$ can be constructed via a finite superposition of continuous univariate functions and addition. Operationally, a KAN layer computes

$$f(x_1,\dots,x_d) = \sum_{q=1}^{m} \varphi_q\!\left(\sum_{p=1}^{d} \gamma_{pq}(x_p)\right),$$

where $\gamma_{pq} : \mathbb{R} \to \mathbb{R}$ (inner functions) and $\varphi_q : \mathbb{R} \to \mathbb{R}$ (outer functions) are trainable univariate mappings, typically merged with a linear branch in practice. KANs generalize universal approximation, offering structural efficiency and potential analytic interpretability through their explicit decomposition (Shi et al., 8 Dec 2025).
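
To make the superposition concrete, the following minimal JAX sketch evaluates $f(x_1,\dots,x_d) = \sum_q \varphi_q\left(\sum_p \gamma_{pq}(x_p)\right)$ for a single input vector. It is illustrative only: the function name is hypothetical, and each univariate mapping is represented here by a fixed-grid Gaussian-RBF expansion for simplicity (an assumption; the paper's KAN layers use spline parameterizations).

import jax.numpy as jnp

def kan_superposition(x, inner_c, outer_c, centers, h):
    # x:        [d]        input vector
    # inner_c:  [m, d, G]  coefficients of the inner functions gamma_pq
    # outer_c:  [m, G]     coefficients of the outer functions phi_q
    # centers:  [G]        shared RBF grid; h: bandwidth
    gamma = jnp.sum(inner_c * jnp.exp(-((x[None, :, None] - centers) / h) ** 2), axis=-1)  # [m, d]
    s = jnp.sum(gamma, axis=-1)   # inner sums s_q = sum_p gamma_pq(x_p), shape [m]
    phi = jnp.sum(outer_c * jnp.exp(-((s[:, None] - centers) / h) ** 2), axis=-1)          # [m]
    return jnp.sum(phi)           # f(x) = sum_q phi_q(s_q)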

FastKAN addresses the prohibitive computational overhead of spline-based KANs by employing a bank of $G$ Gaussian radial basis functions (RBFs) per edge:

$$\varphi(x) = w_b\, b(x) + w_s \sum_{k=1}^{G} c_k\, \exp\!\left(-\left(\frac{x - \mu_k}{h}\right)^{2}\right),$$

with fixed base activation $b(x)$ (e.g., SiLU), $w_b = w_s = 1$, fixed RBF centers $\mu_k$ uniformly spaced on $[-2, 2]$, bandwidth $h$, and trainable coefficients $c_k$. This construction enables a fully vectorized, simplified implementation in JAX that learns only the $c_k$ per edge and eliminates adaptive grid mechanisms (Shi et al., 8 Dec 2025).
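
Written as code for a single edge, the formula above reduces to a few lines; the sketch below is a minimal scalar version of $\varphi(x)$ under the stated defaults (SiLU base, $w_b = w_s = 1$, fixed grid on $[-2, 2]$). The function name and the bandwidth value are illustrative assumptions.

import jax
import jax.numpy as jnp

def fastkan_edge(x, c, h=0.5, G=8):
    # x: scalar input; c: [G] trainable RBF coefficients (the only learned parameters)
    # h: RBF bandwidth (value here is an assumption)
    mu = jnp.linspace(-2.0, 2.0, G)                   # fixed, uniformly spaced centers
    base = jax.nn.silu(x)                             # fixed base activation b(x), w_b = 1
    rbf = jnp.sum(c * jnp.exp(-((x - mu) / h) ** 2))  # w_s = 1
    return base + rbf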

2. KAN-Dreamer Architectural Integration

KAN-Dreamer is instantiated as a drop-in replacement for the MLP and CNN modules in the following DreamerV3 functional units, while preserving the GRU-based RSSM:

  • Visual Perception (Encoder/Decoder):
    • Baseline: CNN encoder / transposed CNN decoder.
    • KAN-Vis: Image flattening and fully connected KAN layers.
    • FKAN-Vis: Fully connected FastKAN layers.
  • Latent Prediction (Reward & Continue Heads):
    • Baseline: 2-layer MLP on $(h_t, z_t)$.
    • KAN-Pred: 2-layer KAN ($G = 8$, spline order $k = 3$).
    • FKAN-Pred: 2-layer FastKAN ($G = 8$ RBFs).
  • Behavior Learning (Actor / Critic):
    • Baseline: 2-layer MLP on $s_t = (h_t, z_t)$.
    • KAN-AC: 2-layer KAN ($G = 8$, $k = 3$).
    • FKAN-AC: 2-layer FastKAN.

In all regimes, KAN/FastKAN layers are parameter-matched to the baselines (total ≈ 10.5M), using fixed input clamp ranges ($[-5, 5]$ for KAN; $[-2, 2]$ for FastKAN) and consistent training hyperparameters (e.g., replay capacity $5 \times 10^6$, batch size 16, sequence length 64, LaProp optimizer, RMSNorm + SiLU activations). Grid sizes and hidden unit counts are subsystem-specific: KAN-Vis uses 8 units per fully connected layer, KAN-Pred 20, and KAN-AC 24; the FastKAN counterparts use 12, 30, and 34, respectively. All FastKAN operations use a single tensordot contraction with fixed grid management (Shi et al., 8 Dec 2025).
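
For orientation, these settings could be collected into a single configuration object along the following lines. This is a sketch only: the key names and layout are assumptions rather than the paper's code, while the values mirror those reported above.

KAN_DREAMER_CONFIG = {  # hypothetical layout; values taken from the reported setup
    "training": {
        "replay_capacity": 5_000_000,
        "batch_size": 16,
        "sequence_length": 64,
        "optimizer": "LaProp",
        "norm_and_activation": "RMSNorm + SiLU",
    },
    "kan": {            # spline-based KAN variants
        "grid_size": 8,               # as reported for the prediction and actor-critic heads
        "spline_order": 3,
        "input_clamp": (-5.0, 5.0),
        "hidden_units": {"vis": 8, "pred": 20, "ac": 24},
    },
    "fastkan": {        # RBF-based FastKAN variants
        "num_rbfs": 8,
        "rbf_center_range": (-2.0, 2.0),
        "input_clamp": (-2.0, 2.0),
        "hidden_units": {"vis": 12, "pred": 30, "ac": 34},
    },
}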

3. Empirical Evaluation and Ablation Analysis

KAN-Dreamer’s evaluation on walker_walk consists of subsystem ablations, comparing final returns, sample efficiency, throughput (FPS), parameter count, and loss trajectories.

| Subsystem  | Model     | Final Score | FPS (Policy) | FPS (Train) | Params |
|------------|-----------|-------------|--------------|-------------|--------|
| Perception | Baseline  | 977         | 62           | 15.9k       | 10.49M |
| Perception | KAN-Vis   | 940         | 72           | 18.5k       | 10.40M |
| Perception | FKAN-Vis  | 950         | 89           | 22.8k       | 10.51M |
| Prediction | Baseline  | 977         | 62           | 15.9k       | 10.49M |
| Prediction | KAN-Pred  | 965         | 45           | 11.5k       | 10.45M |
| Prediction | FKAN-Pred | 951         | 61           | 15.5k       | 10.51M |
| Behavior   | Baseline  | 977         | 62           | 15.9k       | 10.49M |
| Behavior   | KAN-AC    | 957         | 31           | 7.8k        | 10.48M |
| Behavior   | FKAN-AC   | 964         | 57           | 14.5k       | 10.47M |

Key observations include:

  • Sample Efficiency: The baseline reaches a reward of 900 at ~250k steps. FKAN-Pred requires ~220k steps, KAN-Pred ~250k, FKAN-AC ~520k, and FKAN-Vis ~750k; KAN-AC and KAN-Vis are slower, at ~900k and ~930k steps, respectively.
  • Wall-Clock Performance: FKAN-Pred matches the baseline at ~1.0 h; KAN-Pred follows at ~1.5 h, FKAN-Vis at ~2.5 h, and KAN-Vis at ~3.5 h; KAN-AC exceeds 8 h.
  • Loss Dynamics: KAN-Vis and FKAN-Vis plateau at a reconstruction loss of ~100 (indicating poor visual fidelity relative to the baseline), whereas the reward/continue predictors and policy/value heads converge to comparable losses, though with higher variance and slower value approximation in the actor-critic case.

4. Functional Insights and Interpretative Context

The empirical parity of FastKAN and MLP architectures in the low-dimensional latent prediction regime (Reward and Continue heads) results from the alignment of task properties (deterministic, low-dimensional target functions) with the structural decomposition advocated by the Kolmogorov–Arnold theorem. FastKAN’s entirely vectorized, grid-based inference substantially reduces spline overhead, yielding throughput comparable to MLPs.

Conversely, KANs and FastKAN underperform in vision and policy/value estimation due to architectural incompatibilities:

  • Visual Perception Limitation: Fully connected KANs lack convolutional spatial biases, impeding learning of pixel locality and resulting in poor reconstructions and diminished sample efficiency.
  • Policy/Value Approximation Instability: The recursive, non-stationary critic landscape imposes learning dynamics for which current KAN/FastKAN parameterizations are insufficient, manifesting as high variance and slow convergence.
  • Overhead: KAN layers exhibit ~2× computational overhead relative to MLPs. FastKAN addresses this with minimal loss of task-relevant expressivity in latent heads, but not in high-dimensional or structurally non-aligned tasks (Shi et al., 8 Dec 2025).

5. Implementation and Vectorization Strategies

The FastKAN implementation is designed for JAX, employing a fully vectorized approach for scalability. Grid centers $\mu_k$ are initialized uniformly over $[-2, 2]$ for all edges, with no adaptive updates, and all operations are compiled to a single tensordot contraction for batch efficiency. Only the RBF weights $c_k$ are learnable per (output, input, gridpoint) triple. This implementation ensures that FastKAN, in the tested low-dimensional subsystems, can saturate MLP-equivalent training throughput.

A representative pseudocode fragment:

import jax
import jax.numpy as jnp

def FastKANLayer(x, W, b, c, mu, h):
    # x: [batch, d_in], clamped to the fixed grid range
    # W, b: linear (base) branch parameters; the base activation b(x) is SiLU
    # c: [d_out, d_in, G] trainable RBF coefficients (the only per-edge learned weights)
    # mu: [G] fixed RBF centers, uniformly spaced on [-2, 2]; h: bandwidth
    y_base = jax.nn.silu(x) @ W + b                        # base branch, [batch, d_out]
    x_exp = x[..., None] - mu[None, None, :]               # [batch, d_in, G]
    phi = jnp.exp(-((x_exp / h) ** 2))                     # RBF activations, [batch, d_in, G]
    y_rbf = jnp.tensordot(phi, c, axes=([1, 2], [1, 2]))   # single contraction, [batch, d_out]
    return y_base + y_rbf
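
A minimal usage sketch follows; the shapes, initialization scale, and bandwidth choice are illustrative assumptions, not the paper's training configuration.

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
batch, d_in, d_out, G = 16, 32, 64, 8
k1, k2, k3 = jax.random.split(key, 3)

W = 0.1 * jax.random.normal(k1, (d_in, d_out))     # linear-branch weights
b = jnp.zeros((d_out,))                            # linear-branch bias
c = 0.1 * jax.random.normal(k2, (d_out, d_in, G))  # trainable RBF coefficients
mu = jnp.linspace(-2.0, 2.0, G)                    # fixed centers on [-2, 2]
h = mu[1] - mu[0]                                  # bandwidth tied to grid spacing (an assumption)

x = jnp.clip(jax.random.normal(k3, (batch, d_in)), -2.0, 2.0)  # clamp inputs to the grid range
y = FastKANLayer(x, W, b, c, mu, h)
print(y.shape)  # (16, 64)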

6. Limitations and Prospective Advances

KAN-Dreamer’s primary limitations arise from representational mismatch and computational constraints:

  • Spatiotemporal Bias Absence: Off-the-shelf KANs do not retain convolutional priors necessary for high-fidelity image modeling, limiting their utility in visual systems.
  • Critic/Policy Head Instability: Value head instability under KAN/FastKAN parameterizations suggests a need for enhanced compositional or recurrent architectures.
  • Scalability: Despite FastKAN improvements, computational efficiency and expressivity remain suboptimal outside low-dimensional, structurally aligned tasks.

Future research avenues include:

  1. Tailored Hyperparameters and Architectures: Subsystem-specific co-optimization for grid size, RBF bandwidth, and input normalization.
  2. Symbolic Interpretability: Applying symbolic regression to the analytic forms of $\varphi$ and $\gamma$ (facilitated by regression reward heads) to uncover explicit function mappings.
  3. Recurrent and KAN-RSSM Models: Integration of stateful KAN-RNNs or KAN-parametrized transition models for enhanced long-horizon modeling.
  4. Spatial Priors in KANs: Developing convolutional or locality-aware KANs for improved visual perception (Shi et al., 8 Dec 2025).

A plausible implication is that KAN and FastKAN, when appropriately integrated into world model subsystems matching their strengths, can serve as interpretable, competitive alternatives to MLPs in sample-efficient MBRL, but that significant architectural innovation is required to extend these advantages to high-dimensional and recursive domains.

References (1)
