GAMformer: In-Context GAM Estimation
- GAMformer is a transformer-based framework that performs rapid, single-pass estimation of generalized additive models (GAMs) via in-context learning, enabling direct recovery of interpretable shape functions.
- The method encodes tabular features into quantile-based one-hot tokens and processes them with a dual-module self-attention architecture, bypassing iterative fitting.
- Empirical evaluations show that GAMformer achieves competitive accuracy against XGBoost and EBMs on synthetic and real-world datasets while maintaining clear visual interpretability.
GAMformer is a transformer-based modeling approach for the rapid estimation of Generalized Additive Models (GAMs) via in-context learning. Eschewing the traditional requirement for iterative model fitting, GAMformer enables single-pass, nonparametric recovery of feature shape functions for tabular data, with design choices that support interpretability and empirical competitiveness on classification benchmarks. The model is trained exclusively on synthetic data yet achieves strong generalization to real-world tabular datasets (Mueller et al., 2024).
1. Generalized Additive Model Framework
A Generalized Additive Model expresses the target as an additive sum of univariate, potentially nonlinear effects: $f(x) = \sum_{j=1}^{p} f_j(x_j)$, where $x = (x_1, \dots, x_p)$ denotes the $p$-dimensional input vector and each $f_j$ is a shape function representing the partial effect of feature $j$. In supervised learning, the model predicts either a real or categorical response $y$ through a link function $g$: $g(\mathbb{E}[y \mid x]) = \sum_{j=1}^{p} f_j(x_j)$. The resulting structure ensures direct interpretability, as plotting $f_j(x_j)$ versus $x_j$ visualizes the effect of individual features.
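As a concrete illustration of the additive structure, the following minimal sketch evaluates a two-feature GAM for binary classification with a logit link. The two shape functions and the names `shape_functions`, `gam_logit`, and `gam_predict_proba` are hypothetical and chosen only for illustration; they do not come from the GAMformer work.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Hypothetical shape functions for a two-feature binary classifier.
shape_functions = [
    lambda x: 0.5 * x,    # f_1: linear effect
    lambda x: np.sin(x),  # f_2: nonlinear effect
]

def gam_logit(X):
    """Sum the per-feature contributions f_j(x_j) for each row of X."""
    return sum(f(X[:, j]) for j, f in enumerate(shape_functions))

def gam_predict_proba(X):
    """Apply the inverse logit link to obtain P(y = 1 | x)."""
    return sigmoid(gam_logit(X))

X = np.array([[0.0, 0.0], [2.0, np.pi / 2]])
logits = gam_logit(X)          # per-row sum of feature effects
probs = gam_predict_proba(X)   # probabilities after the inverse link
```

Because each feature contributes through its own function, plotting `shape_functions[j]` over a grid of values for feature `j` gives the partial-effect plot described above.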
2. In-Context Learning Paradigm
GAMformer reframes shape function estimation as an in-context learning problem, using the transformer's capacity to jointly process labeled training examples and unlabeled test instances in a single sequence. At inference, a small context set $D_{\text{train}} = \{(x^{(i)}, y^{(i)})\}_{i=1}^{n}$ and one or more test points $x_{\text{test}}$ are provided. Feature values are discretized into $B$ quantile-based bins, with each continuous or categorical feature mapped to a one-hot encoding. Each token pairs a feature bin with the associated label embedding, and the tokens form an $n \times p$ grid, which is then extended with test rows carrying dummy label tokens. The transformer outputs a tensor $\hat{F} \in \mathbb{R}^{p \times B \times K}$ encoding shape tables, where $K$ is the number of output classes. Predictions are computed by bin index lookup and summation: $\hat{p}(y = k \mid x) = \mathrm{softmax}_k\big(\sum_{j=1}^{p} \hat{F}[j, b_j(x_j), k]\big)$, where $b_j(x_j)$ is the bin index of $x_j$. This design enables immediate, one-shot estimation of main effect functions from the observed context.
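The binning and lookup steps can be sketched in NumPy as follows. This is an illustrative sketch, not the authors' implementation: the helper names are hypothetical, the shape tables `F` are random stand-ins for the transformer output, and the softmax over summed table entries is an assumption consistent with a cross-entropy objective.

```python
import numpy as np

def quantile_bin_edges(X_context, n_bins):
    """Interior quantile edges per feature, shape (p, n_bins - 1)."""
    qs = np.linspace(0.0, 1.0, n_bins + 1)[1:-1]
    return np.quantile(X_context, qs, axis=0).T

def bin_indices(X, edges):
    """Map each value to a bin index in [0, n_bins - 1], feature by feature."""
    cols = [np.searchsorted(edges[j], X[:, j], side="right")
            for j in range(X.shape[1])]
    return np.stack(cols, axis=1)

def predict_from_shape_tables(F, idx):
    """F: (p, n_bins, K) shape tables; idx: (m, p) bin indices -> (m, K) probabilities."""
    p = F.shape[0]
    logits = sum(F[j, idx[:, j], :] for j in range(p))   # lookup and sum over features
    logits = logits - logits.max(axis=1, keepdims=True)  # numerically stable softmax
    e = np.exp(logits)
    return e / e.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
X_context = rng.normal(size=(200, 3))
X_test = rng.normal(size=(5, 3))
n_bins, n_classes = 8, 2
edges = quantile_bin_edges(X_context, n_bins)
idx = bin_indices(X_test, edges)
F = rng.normal(size=(3, n_bins, n_classes))  # stand-in for the transformer output
probs = predict_from_shape_tables(F, idx)
```

Once the tables exist, prediction needs no further model evaluation: each test row costs one table lookup per feature.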
3. Model Architecture
GAMformer’s architecture comprises three principal components:
- Embedding Layer: All features (continuous and categorical) are mapped to 64 bins, each represented by a one-hot vector in $\{0,1\}^{64}$. These are processed through a small MLP to yield embeddings of dimension $d$. A class-specific label embedding is added to each token.
- Transformer Encoder: The core consists of 12 layers, each with dual-module self-attention: one module attends across features within each row (example-wise), while the other attends across examples within each feature column. This permutation-equivariant scheme removes the need for positional encodings. The encoder contains approximately 50.5M parameters.
- Shape-Function Decoder: For every feature-class pair, embeddings across training examples are aggregated (mean over examples sharing the class label), and a shared MLP decodes these representations into a shape vector with one entry per bin for each class, yielding the discrete nonparametric estimate $\hat{f}_j$.
Together, these components estimate main-effect shape tables in a single forward pass and support interpretable inference for each test instance.
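The dual-module attention pattern can be illustrated with a simplified NumPy sketch. It assumes single-head attention with residual connections only. Layer normalization, feed-forward sublayers, multiple heads, and any masking between context and test rows are omitted, and the names `self_attention` and `dual_axis_block` are hypothetical.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(H, Wq, Wk, Wv):
    """Single-head attention over the second-to-last axis of H, shape (..., T, d)."""
    Q, K, V = H @ Wq, H @ Wk, H @ Wv
    scores = Q @ np.swapaxes(K, -1, -2) / np.sqrt(Q.shape[-1])
    return softmax(scores, axis=-1) @ V

def dual_axis_block(H, params_feat, params_ex):
    """H: (n, p, d). Attend across features per row, then across examples per column."""
    H = H + self_attention(H, *params_feat)   # tokens = features within one example
    Ht = np.swapaxes(H, 0, 1)                 # (p, n, d)
    Ht = Ht + self_attention(Ht, *params_ex)  # tokens = examples within one feature
    return np.swapaxes(Ht, 0, 1)

rng = np.random.default_rng(0)
n, p, d = 6, 4, 8
H = rng.normal(size=(n, p, d))
params_feat = [rng.normal(scale=0.1, size=(d, d)) for _ in range(3)]
params_ex = [rng.normal(scale=0.1, size=(d, d)) for _ in range(3)]

out = dual_axis_block(H, params_feat, params_ex)
perm = rng.permutation(n)
out_perm = dual_axis_block(H[perm], params_feat, params_ex)  # same block, shuffled rows
```

Shuffling the input rows shuffles the output rows in the same way, which is the equivariance property that makes positional encodings unnecessary in this sketch.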
4. Training Methodology
GAMformer is trained solely on synthetic tabular datasets generated from two distinct priors:
- Structural Causal Models (SCMs): Random graph structures with stochastic edge functions.
- Gaussian Process (GP) Priors: Random function draws with varying kernels.
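Each synthetic data batch is split into context and test subsets. The training objective is the cross-entropy loss over held-out test instances: $\mathcal{L} = -\sum_{(x, y) \in D_{\text{test}}} \log \hat{p}(y \mid x, D_{\text{train}})$, where $D_{\text{train}}$ is the context subset, $D_{\text{test}}$ is the held-out subset, and $\hat{p}$ is the predicted class probability. No curvature regularization is used; smoothness of the estimated shape functions $\hat{f}_j$ is induced by pretraining on smoothly varying synthetic priors. Optimization is performed using SGD or Adam. No per-dataset fitting or hyperparameter tuning is performed at inference.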
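A minimal sketch of one synthetic training task is shown below, assuming an additive prior in which each feature effect is drawn from a Gaussian process with an RBF kernel. The actual priors (SCMs and GPs with varying kernels) are richer; the function names, the lengthscale range, and the uniform placeholder predictions are illustrative assumptions.

```python
import numpy as np

def sample_gp_function(x, lengthscale, rng):
    """Draw one function from a zero-mean GP with an RBF kernel, evaluated at points x."""
    diff = x[:, None] - x[None, :]
    K = np.exp(-0.5 * (diff / lengthscale) ** 2) + 1e-6 * np.eye(len(x))  # jitter for stability
    return np.linalg.cholesky(K) @ rng.standard_normal(len(x))

def sample_task(n, p, rng):
    """Sample features, additive GP logits, and binary labels for one task."""
    X = rng.uniform(-3.0, 3.0, size=(n, p))
    logits = np.zeros(n)
    for j in range(p):
        logits += sample_gp_function(X[:, j], rng.uniform(0.5, 2.0), rng)
    y = (rng.uniform(size=n) < 1.0 / (1.0 + np.exp(-logits))).astype(int)
    return X, y

def cross_entropy(probs, y):
    """Mean negative log-likelihood of integer labels y under probs of shape (m, K)."""
    return -np.mean(np.log(probs[np.arange(len(y)), y] + 1e-12))

rng = np.random.default_rng(0)
X, y = sample_task(n=120, p=3, rng=rng)
n_context = 100
X_ctx, y_ctx = X[:n_context], y[:n_context]    # context subset given to the model
X_test, y_test = X[n_context:], y[n_context:]  # held-out subset used for the loss

uniform_probs = np.full((len(y_test), 2), 0.5)  # placeholder for the model's output
loss = cross_entropy(uniform_probs, y_test)
```

In actual training, `uniform_probs` would be replaced by the model's predictions for `X_test` given the context, and the loss would be backpropagated through the transformer.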
5. Empirical Evaluation and Results
GAMformer is evaluated on a battery of synthetic and real-world tabular tasks:
- Toy Examples: On linearly separable and polynomial datasets, GAMformer accurately recovers the underlying shape functions, and its estimates are smoother than those of EBM.
- Synthetic and OpenML Benchmarks: Performance matches XGBoost and EBMs on ~30 classification datasets (up to 2,000 rows, 10 features). Second-order interactions close gaps on "XOR"-like patterns. Critical-difference plots confirm statistical parity with EBMs and XGBoost; with second-order terms, parity or slight outperformance is observed.
- MIMIC-II (ICU Mortality): GAMformer shape functions for features such as Age and PaO2/FiO2 replicate clinical U-shapes and highlight an imputation artifact near a PF ratio of approximately 325. Patients with missing values are isolated more sharply than with EBM.
- Ablations: Robustness holds for context size $n$ up to 2,000 and feature count $p$ up to 10; main effects show superior sample efficiency at small $n$. As with transformers generally, performance degrades if the context size $n$ greatly exceeds the sizes seen in pretraining.
6. Interpretability, Advantages, and Limitations
GAMformer's direct output of discrete, binned shape tables preserves GAM interpretability: each estimated shape function $\hat{f}_j$ can be plotted immediately as a partial-effect curve. No iterative optimization or dataset-specific hyperparameter search is necessary; model inference is strictly a single forward pass. Empirical accuracy matches or exceeds tree and neural boosted baselines, and binned representations naturally accommodate discontinuities, feature artifacts, and missing values.
The approach has several limitations. Inference cost is quadratic in the number of examples $n$ and the number of features $p$, which limits practical application to small and mid-size tabular datasets. Length extrapolation is constrained: performance plateaus or degrades when the model is given contexts much longer than those seen during training. Only main effects and greedily selected pairwise feature terms are modeled directly, and higher-order interactions incur an exponential blowup in the number of terms.
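The scaling behind these limitations can be made concrete with a short calculation. This is an illustrative count under simple assumptions (every interaction up to a given order is included, and one attention score is counted per token pair per layer), not a profile of the actual model, which selects pairwise terms greedily; `n_terms` and `attention_pairs` are hypothetical helpers.

```python
from math import comb

def n_terms(p, order):
    """Number of additive terms if all interactions up to `order` were included."""
    return sum(comb(p, k) for k in range(1, order + 1))

def attention_pairs(n, p):
    """Pairwise attention scores per layer: across examples for each feature,
    plus across features for each example."""
    return p * n * n + n * p * p

main_effects = n_terms(10, 1)        # 10 terms
with_pairs = n_terms(10, 2)          # 10 + 45 = 55 terms
with_triples = n_terms(10, 3)        # 55 + 120 = 175 terms
pairs_small = attention_pairs(2000, 10)
pairs_double = attention_pairs(4000, 10)
```

Doubling the context from 2,000 to 4,000 rows roughly quadruples the example-wise attention cost, which is why long contexts become expensive quickly.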
A plausible implication is that GAMformer's paradigm opens the possibility for amortized, domain-agnostic tabular estimation, provided computational and interaction modeling constraints are managed.
7. Context and Significance
GAMformer introduces a new direction in interpretable tabular modeling by combining transformer-based in-context learning with additive, inherently interpretable regression and classification. By training on purely synthetic data with an architecture tailored for permutation-equivariant, context-driven inference, GAMformer eliminates traditional fitting loops, achieves strong empirical accuracy, and preserves the core interpretive benefits of GAMs, marking a significant advance for single-pass, nonparametric tabular modeling (Mueller et al., 2024).