Subjective Timescale Transformers (STT)
- Subjective Timescale Transformers (STT) are decoder-only models that apply conditional, token-based computation along the temporal axis to optimize efficiency.
- They integrate a lightweight Transition Network to predict residual updates and employ Bayesian surprise signals to decide block execution dynamically.
- Empirical results show that STT achieves up to 37.5% savings in self-attention FLOPs and a 25% KV-cache reduction, at the cost of reduced accuracy on standard language-modeling benchmarks.
Subjective Timescale Transformers (STT) are a class of decoder-only Transformer architectures designed to enhance computational efficiency by introducing conditional computation along the temporal axis. Unlike standard Transformers, which execute a uniform, dense computation at each block for all tokens, STT selectively skips blocks for specific tokens based on learned Bayesian surprise signals. This mechanism enables the model to determine both "where and when to compute," reducing both self-attention computations and KV-cache requirements, with explicit routing governed by predicted and observed token-wise state transitions (Wieser et al., 26 Nov 2025).
1. Architectural Framework and Temporal Conditional Computation
An STT is a modification to the conventional decoder-only Transformer stack, where every other standard block is replaced by an STT layer. Each STT layer consists of two principal components:
- A lightweight Transition Network (TPN) that generates a temporal prior by predicting the next-token residual update based on the previous token's processed state.
- A full Transformer block (self-attention and feedforward) whose execution is dynamically gated by surprise signals computed from model predictions and observed outcomes.
At each timestep $t$ and layer $\ell$, the TPN computes a prior residual prediction $\hat{\Delta}_t^{(\ell)}$ from the previous token's processed state $x_{t-1}^{(\ell)}$. The true residual $\Delta_t^{(\ell)}$ is produced by running the full block on $x_t^{(\ell-1)}$. Both quantities are compared using surprise metrics. A router then determines, per token, whether to execute or skip the full block. This approach extends conditional computation, historically applied only across model depth, into the temporal domain, allowing the network to economize computation on a token-by-token basis (Wieser et al., 26 Nov 2025).
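Structurally, the alternation is straightforward; the sketch below assumes hypothetical factory functions `make_dense_block` and `make_stt_layer` (the paper converts every other block of a pre-trained backbone rather than building a stack from scratch):

```python
import torch.nn as nn

def build_stt_stack(num_layers, make_dense_block, make_stt_layer):
    """Interleave dense decoder blocks with STT layers: every other block is replaced."""
    return nn.ModuleList(
        [make_dense_block() if i % 2 == 0 else make_stt_layer() for i in range(num_layers)]
    )
```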
2. Transition Network and Temporal Change Hypothesis
Each STT layer contains a TPN defined as $\hat{\Delta}_t^{(\ell)} = \mathrm{TPN}_{\theta}^{(\ell)}\big(x_{t-1}^{(\ell)}\big)$, where $\theta$ are the TPN parameters. The full block computes the actual residual $\Delta_t^{(\ell)} = \mathrm{TransformerBlock}^{(\ell)}\big(x_t^{(\ell-1)}\big)$. The pair $\big(\hat{\Delta}_t^{(\ell)}, \Delta_t^{(\ell)}\big)$ forms the "temporal change hypothesis," quantifying expected state evolution and providing the basis for gating decisions. The TPN in all experiments is a 2-layer MLP with hidden size equal to the model dimension $d$, trained with an MSE loss weighted by $0.05$.
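As a concrete illustration, a minimal PyTorch sketch of such a TPN and its training signal is given below; the module name, the GELU activation, and the detach on the block residual are assumptions, while the 2-layer structure, hidden size $d$, and MSE weight follow the description above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TransitionNetwork(nn.Module):
    """2-layer MLP that predicts the next token's residual update from the
    previous token's processed state (hidden size = model dimension d)."""

    def __init__(self, d_model: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),                      # activation choice is an assumption
            nn.Linear(d_model, d_model),
        )

    def forward(self, x_prev: torch.Tensor) -> torch.Tensor:
        # x_prev: states x_{t-1}^{(l)} of shape (batch, seq, d_model)
        return self.net(x_prev)            # predicted residual Δ̂_t


def tpn_loss(delta_hat: torch.Tensor, delta: torch.Tensor, mse_weight: float = 0.05) -> torch.Tensor:
    """MSE between predicted and observed residuals, weighted as in the paper.
    The detach (train the TPN toward the block output, not vice versa) is an assumption."""
    return mse_weight * F.mse_loss(delta_hat, delta.detach())
```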
3. Surprise Signal Computation
STT computes two core surprise metrics per token and layer:
- Expected Change (CE): $\mathrm{CE}_t = D_{\mathrm{st},t} - \big(D_{\mathrm{ch},t} - \log o_{ce}\big)$, where $D_{\mathrm{st},t} = \lVert \Delta_t \rVert^2 / d$ is the error of the static prior (no change) and $D_{\mathrm{ch},t} = \lVert \Delta_t - \hat{\Delta}_t \rVert^2 / d$ is the error of the change hypothesis.
CE quantifies whether the predicted residual (change hypothesis) accounts for the update more accurately than the static prior.
- Unexpected Change (CU): $\mathrm{CU}_t = D_{\mathrm{st},t} - m_{cu}\,\mathrm{MA}(D_{\mathrm{st}})_t$,
where $\mathrm{MA}(\cdot)$ denotes a moving average over the sequence and $o_{ce}$, $m_{cu}$ are offsets. CU captures whether the token's static change magnitude is unusually large relative to preceding tokens.
The gating score for each token is then given by $g_{\mathrm{cont},t} = \sigma(\beta_{ce}\,\mathrm{CE}_t) + \sigma(\beta_{cu}\,\mathrm{CU}_t) - \sigma(\beta_{ce}\,\mathrm{CE}_t)\,\sigma(\beta_{cu}\,\mathrm{CU}_t)$, a probabilistic OR of the two signals, with $\sigma$ denoting the sigmoid and $\beta_{ce}, \beta_{cu}$ learnable inverse temperatures annealed during training, ensuring the gating sharpens over time.
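The surprise signals and gate can be computed as sketched below; the causal moving-average implementation and its window size are assumptions, since only "a moving average over the sequence" is specified.

```python
import torch
import torch.nn.functional as F

def gating_scores(delta, delta_hat, o_ce, m_cu, beta_ce, beta_cu, ma_window=32):
    """Compute CE, CU, and the continuous gate g_cont per token.

    delta, delta_hat: (batch, seq, d) observed and predicted residuals.
    o_ce, m_cu:       offsets; beta_ce, beta_cu: inverse temperatures (scalars).
    """
    # Per-token errors, normalized by the model dimension d (via the mean).
    d_st = delta.pow(2).mean(dim=-1)                    # static prior: "no change"
    d_ch = (delta - delta_hat).pow(2).mean(dim=-1)      # change hypothesis: TPN prediction

    # Expected Change: the change hypothesis beats the static prior (up to offset o_ce).
    ce = d_st - (d_ch - torch.as_tensor(o_ce).log())

    # Unexpected Change: static change unusually large vs. a causal moving average
    # of earlier tokens (window size is an assumption, not from the paper).
    kernel = torch.ones(1, 1, ma_window, device=d_st.device, dtype=d_st.dtype) / ma_window
    ma = F.conv1d(F.pad(d_st.unsqueeze(1), (ma_window - 1, 0)), kernel).squeeze(1)
    cu = d_st - m_cu * ma

    # Probabilistic OR of the two sigmoid-squashed surprise signals.
    a, b = torch.sigmoid(beta_ce * ce), torch.sigmoid(beta_cu * cu)
    return a + b - a * b
```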
4. Routing and Execution Mechanism
Routing is implemented as a fixed-capacity Top-K selection over the continuous gating score. For a chosen capacity $c$ and sequence length $T$, the $k = cT$ tokens with the highest $g_{\mathrm{cont},t}$ per sequence are selected for full-block execution, while the rest receive identity residual updates (their blocks are skipped). During inference, a causal router (a small MLP) trained to imitate the non-causal Top-K behavior enables autoregressive operation by consuming only the current and previous token representations.
Pseudocode for a single STT layer with fixed capacity is:
Input: {x_t^{(ℓ-1)}} for t = 1…T
1. For t = 1…T:
   a) Predict prior residual:   Δ̂_t = TPN^{(ℓ)}(x_{t-1}^{(ℓ)})
   b) Compute actual residual:  Δ_t = TransformerBlock^{(ℓ)}(x_t^{(ℓ-1)})
   c) D_st[t] = ||Δ_t||^2 / d,  D_ch[t] = ||Δ_t - Δ̂_t||^2 / d
   d) CE[t] = D_st[t] - (D_ch[t] - log o_ce)
      CU[t] = D_st[t] - m_cu * MA(D_st)[t]
   e) g_cont[t] = sigmoid(β_ce*CE[t]) + sigmoid(β_cu*CU[t])
                  - sigmoid(β_ce*CE[t])*sigmoid(β_cu*CU[t])
2. Select indices S = TopK(g_cont, k)
3. For t in S:      x_t^{(ℓ)} = x_t^{(ℓ-1)} + g_cont[t] * Δ_t
   For t not in S:  x_t^{(ℓ)} = x_t^{(ℓ-1)}
Output: {x_t^{(ℓ)}} for t = 1…T
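Steps 2–3 and the causal-router objective can be vectorized as sketched below; tensor shapes, the `capacity` fraction argument, and the name `router_logits` (the output of the small causal-router MLP) are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def route_and_mix(x_in, delta, g_cont, capacity: float):
    """Fixed-capacity Top-K routing: executed tokens get gated residuals, the rest are skipped.

    x_in:   (batch, seq, d)  layer input x^{(l-1)}
    delta:  (batch, seq, d)  residuals Δ_t from the full block
    g_cont: (batch, seq)     continuous gating scores
    """
    k = max(1, int(capacity * g_cont.shape[1]))
    topk_idx = g_cont.topk(k, dim=1).indices                     # per-sequence Top-K tokens
    mask = torch.zeros_like(g_cont).scatter_(1, topk_idx, 1.0)   # 1 = execute, 0 = skip
    gate = (mask * g_cont).unsqueeze(-1)
    return x_in + gate * delta, mask                             # skipped tokens: identity update


def causal_router_loss(router_logits, mask):
    """BCE that teaches a causal router to imitate the non-causal Top-K selection."""
    return F.binary_cross_entropy_with_logits(router_logits, mask)
```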
5. Efficient KV-Cache Management
In standard Transformers, each layer appends the key-value pair of every token to that layer's KV-cache. In STT, only tokens for which the block is executed contribute new pairs, reducing memory requirements. For a fixed capacity $c$, only a fraction $c$ of tokens per STT layer add to the cache, so each STT layer saves a fraction $1 - c$ of its cache. Since STT layers are placed at every other block, the relative KV-cache saving over the stack is $(1 - c)/2$. For example, with $c = 0.5$, the KV-cache size is reduced by $25\%$.
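The arithmetic behind these figures can be checked directly; the sketch below assumes the STT block's attention cost scales quadratically in the retained-token fraction ($c$ of queries attending to a $c$-fraction cache), which is consistent with the numbers reported in Section 6.

```python
def stt_savings(c: float):
    """Relative savings for a (dense block, STT block) pair at capacity fraction c.

    Assumes the STT block's attention cost scales as c**2 (c fraction of queries
    attending to a c-fraction KV-cache) and its KV-cache writes scale as c.
    """
    attn_saving = 1.0 - (1.0 + c ** 2) / 2.0   # dense block costs 1, STT block c^2
    kv_saving = 1.0 - (1.0 + c) / 2.0          # dense block writes all KV, STT block a fraction c
    return attn_saving, kv_saving

print(stt_savings(0.5))   # (0.375, 0.25): the 37.5% FLOPs and 25% KV-cache figures
```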
6. Empirical Results: Compute-Accuracy Tradeoffs and Training Dynamics
Experiments utilize the Qwen2.5-0.5B backbone with alternating STT and standard blocks. Key outcomes include:
- Compute savings: With fixed capacity $c = 0.5$, the self-attention cost per two-layer block pair averages to $0.625$ of the dense cost (full block at $1.0$ plus STT block at $0.25$, since both queries and cached keys are halved), yielding $37.5\%$ savings in self-attention FLOPs; KV-cache consumption is reduced by $25\%$.
- Dynamic-capacity variant: Allowing the capacity to be learned during training also yields self-attention and KV-cache savings, with execution rates varying by layer depth.
- Accuracy trade-offs: For the $0.5$B-parameter model at fixed capacity, there is a marked decrease on standard LM benchmarks (MMLU, ARC-C, HellaSwag, TruthfulQA, WinoGrande) relative to the dense model.
Training dynamics reflect a shift in gating strategy: early on, Unexpected Change (the novelty-driven signal, CU) dominates block execution; as the TPNs improve in predictive accuracy, Expected Change (CE) increasingly drives gating. Causal-router inverse temperatures are annealed from $0.1$ to $100$ to sharpen decisions. In dynamic-capacity settings, deeper layers process fewer tokens, consistent with hierarchical predictive coding.
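The functional form of this anneal is not specified here; a geometric interpolation of the inverse temperature from $0.1$ to $100$ over training is one plausible sketch.

```python
def inverse_temperature(step: int, total_steps: int,
                        beta_start: float = 0.1, beta_end: float = 100.0) -> float:
    """Geometric anneal of an inverse temperature; the schedule shape is an assumption."""
    frac = min(max(step / total_steps, 0.0), 1.0)
    return beta_start * (beta_end / beta_start) ** frac
```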
7. Implementation Details and Experimental Configuration
Key implementation parameters:
- Backbone: Pre-trained Qwen2.5-0.5B; STT layers alternate with standard blocks.
- Transition Network (TPN): 2-layer MLP, hidden size equal to the model dimension $d$; MSE loss with weight $0.05$.
- Auxiliary losses: TPN MSE ($0.05$), causal-routing BCE ($0.01$), sparseness loss ($0.001$, dynamic capacity only).
- Optimizer: AdamW with weight decay $0.01$.
- Learning rates: 1e-5 (backbone), 1e-3 (TPN), 1e-2 (router); see the parameter-group sketch after this list.
- Batching/Hardware: 1024-token blocks, per-device batch $8$, gradient accumulation $32$ (effective $256$); bfloat16, activation checkpointing.
- Software Stack: PyTorch, HuggingFace Transformers, Hydra, Accelerate with FSDP. Runs use seed $42$, training on mixed text corpora, and evaluation via the lm-eval harness.
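A sketch of how these per-module learning rates could be wired into AdamW parameter groups; attribute names such as `model.backbone`, `model.tpns`, and `model.routers` are placeholders rather than the paper's actual code.

```python
import torch

def build_optimizer(model):
    """AdamW with separate learning rates for backbone, TPNs, and routers (weight decay 0.01)."""
    param_groups = [
        {"params": model.backbone.parameters(), "lr": 1e-5},   # pre-trained Qwen2.5-0.5B blocks
        {"params": model.tpns.parameters(),     "lr": 1e-3},   # transition networks
        {"params": model.routers.parameters(),  "lr": 1e-2},   # causal routers
    ]
    return torch.optim.AdamW(param_groups, weight_decay=0.01)
```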
Together, these components produce a model that adaptively controls when to compute using tokenwise prediction error, yielding significant compute and memory reductions with dynamic, surprise-driven gating behavior over training (Wieser et al., 26 Nov 2025).