Last Query Transformer RNN
- The architecture reduces computational complexity from O(L^2) to O(L) by using only the final input as the query in self-attention, efficiently capturing global context.
- It integrates a transformer encoder with a recurrent LSTM layer and a DNN output, optimizing sequence modeling for tasks like next-step prediction and classification.
- Empirical results, including a validation AUC of ~0.8165 and ensemble performance up to 0.820, demonstrate its practical scalability in processing sequences up to 1728 events.
A Last Query Transformer RNN is a hybrid neural architecture designed for efficient sequence modeling, specifically in scenarios where only the final output in a sequence is required, such as next-step prediction or binary classification over long history logs. Its innovation is to restrict Transformer self-attention computation to a single query—the last element in the input sequence—while all prior elements serve as keys and values. This reduces computational complexity from quadratic to linear with respect to sequence length, enabling practical processing of very long input histories and outperforming canonical transformer or RNN baselines in relevant tasks.
1. Architectural Principles and Motivation
The canonical Transformer encoder computes attention for every pair of input positions: for a sequence of length $L$, the query ($Q$), key ($K$), and value ($V$) matrices are all $L \times d$ (where $d$ is the embedding dimension), creating an $L \times L$ attention matrix. This quadratic scaling in both memory and computation severely limits practical sequence length in many real-world problems.
The Last Query Transformer RNN modifies this paradigm by generating only a single query vector $q_L \in \mathbb{R}^{1 \times d}$, associated with the final input token $x_L$, while $K$ and $V$ remain $L \times d$. The attention operation

$$\mathrm{Attention}(q_L, K, V) = \mathrm{softmax}\!\left(\frac{q_L K^\top}{\sqrt{d}}\right) V$$

produces an output representation for just the last position, with complexity $O(L)$ in the sequence length. This design precisely matches tasks where only the final output is needed, such as predicting the next correct answer in student knowledge tracing.
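The following is a minimal single-head sketch of this last-query attention in PyTorch; the function and projection names are illustrative assumptions, not the author's implementation.

```python
import torch
import torch.nn.functional as F

def last_query_attention(x: torch.Tensor, w_q: torch.Tensor,
                         w_k: torch.Tensor, w_v: torch.Tensor) -> torch.Tensor:
    """Scaled dot-product attention using only the last position as the query.

    x: (batch, L, d) embedded input sequence; w_q, w_k, w_v: (d, d) projections.
    Returns a (batch, 1, d) context vector for the final position.
    """
    q = x[:, -1:, :] @ w_q                                   # (batch, 1, d): single query
    k, v = x @ w_k, x @ w_v                                  # (batch, L, d): full history
    scores = q @ k.transpose(-2, -1) / q.size(-1) ** 0.5     # (batch, 1, L): O(L), not O(L^2)
    return F.softmax(scores, dim=-1) @ v                     # (batch, 1, d)
```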
Following the modified Transformer encoder, a recurrent layer (typically an LSTM) incorporates sequential and recency dynamics, capturing trends and dependencies beneficial for temporal classification. The final DNN layer computes the required binary or categorical prediction.
2. Mathematical Formulation and Efficiency Gains
Standard Transformer Self-Attention:
- Computation and memory complexity: $O(L^2)$ in the sequence length $L$
Last Query-Only Attention:
- $q_L \in \mathbb{R}^{1 \times d}$ (for the last token only), $K, V \in \mathbb{R}^{L \times d}$
- Complexity: $O(L)$
This reduction enables practical input sequences of up to $L = 1728$ events (Riiid! competition), compared to typical limits of 100–500 in canonical transformer models for educational and time-series data.
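As a back-of-the-envelope illustration at the competition sequence length, counting only attention-score entries per example:

$$\underbrace{L^{2} = 1728^{2} \approx 3.0 \times 10^{6}}_{\text{full self-attention scores}} \qquad \text{vs.} \qquad \underbrace{1 \cdot L = 1728}_{\text{last-query scores}}$$

an $L$-fold (here $1728\times$) reduction in the score matrix alone.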
3. Model Pipeline and Implementation
| Component | Method | Role |
|---|---|---|
| Input Encoding | Categorical + continuous embed | Raw features → dense vectors |
| Transformer Encoder | 1 layer, last query only | Relates last question to full history |
| RNN (LSTM) | 1 layer | Sequential/temporal feature learning |
| Output DNN | MLP (sigmoid for binary) | Final classification |
Input features typically include categorical identifiers (e.g., question ID), elapsed time, timestamp difference, and previous answer correctness. Embeddings and feature fusion occur via a DNN. The last element's representation—post-attention—is passed to LSTM, then to the final prediction head.
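A compact sketch of this pipeline in PyTorch is given below; the embedding sizes, the two continuous features, and the head configuration are illustrative assumptions rather than the competition code.

```python
import torch
import torch.nn as nn

class LastQueryTransformerRNN(nn.Module):
    """Sketch of the pipeline in Section 3; dimensions and feature set are assumptions."""

    def __init__(self, n_questions: int, d_model: int = 128,
                 n_heads: int = 8, lstm_hidden: int = 128):
        super().__init__()
        self.q_emb = nn.Embedding(n_questions, d_model)       # categorical question ID
        self.cont_proj = nn.Linear(2, d_model)                # e.g. elapsed time, timestamp difference
        self.prev_emb = nn.Embedding(3, d_model)              # previous answer: wrong / correct / unknown
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.lstm = nn.LSTM(d_model, lstm_hidden, batch_first=True)
        self.head = nn.Sequential(                            # output MLP (logit; sigmoid lives in the loss)
            nn.Linear(lstm_hidden, lstm_hidden), nn.ReLU(), nn.Linear(lstm_hidden, 1))

    def forward(self, question_ids, cont_feats, prev_correct):
        # Fuse categorical and continuous features into one dense sequence: (batch, L, d_model).
        x = self.q_emb(question_ids) + self.cont_proj(cont_feats) + self.prev_emb(prev_correct)
        last_q = x[:, -1:, :]                                 # only the final position acts as the query
        ctx, _ = self.attn(last_q, x, x)                      # (batch, 1, d_model)
        out, _ = self.lstm(ctx)                               # (batch, 1, lstm_hidden)
        return self.head(out[:, -1, :]).squeeze(-1)           # (batch,) correctness logit
```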
Meta-architectural notes:
- Multi-head attention is used, varying the number of heads ($2, 4, 8, 16, 32$) for ensembles.
- Sequence padding and batching must ensure that only the final query is used per prediction.
- The encoder and RNN components are trained jointly via standard cross-entropy or AUC-optimized loss.
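Continuing from the sketch above, a minimal joint training step might look as follows; the batch shapes, question count, and the use of binary cross-entropy on the last-step correctness label are assumptions.

```python
import torch
import torch.nn as nn

model = LastQueryTransformerRNN(n_questions=13523)         # question vocabulary size is an assumption
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()                          # binary cross-entropy on the final-step label

# Hypothetical batch of padded histories (batch=32, sequence length L=1728).
question_ids = torch.randint(0, 13523, (32, 1728))
cont_feats = torch.randn(32, 1728, 2)
prev_correct = torch.randint(0, 3, (32, 1728))
labels = torch.randint(0, 2, (32,)).float()                 # correctness of the final interaction only

logits = model(question_ids, cont_feats, prev_correct)      # one prediction per sequence
loss = criterion(logits, labels)
loss.backward()
optimizer.step()
```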
4. Performance Metrics and Empirical Findings
The architecture achieved first place in the Riiid! competition for answer correctness prediction:
Single Model Results:
- Validation AUC $\approx 0.8165$
- Public/Private Leaderboard: $0.816/0.818$ AUC
Ensemble:
- Five models with variable head count; private leaderboard: $0.820$ AUC
Key observations:
- Ability to process very long input sequences (up to $1728$ events) contributed most strongly to predictive performance, more than increasing model depth or size.
- Efficiency gains from linear attention computation allowed feasible training and inference—critical for competition settings with large datasets.
- The hybrid design (Transformer for context aggregation, RNN for recency/order modeling) captured both global and local dependencies.
5. Comparative and Theoretical Context
The Last Query Transformer RNN diverges from traditional transformer architectures that require quadratic complexity for full sequence-to-sequence modeling. By targeting the “last query only,” it aligns computational resources with informational relevance—where only the final output is required. This approach generalizes to a broad setting of autoregressive and next-step prediction problems.
The model retains core transformer strengths (contextual aggregation via self-attention) while harnessing RNN capabilities (temporal ordering) in an efficient, end-to-end trainable pipeline. The architecture is not limited to knowledge tracing and can be adapted to other domains, such as financial forecasting, time-series regression, or selective event detection, wherever only the final prediction from a long input history is required.
6. Practical Implications and Scaling Considerations
- When applying to datasets with long event histories, always assess whether only the final output is needed—if so, Last Query attention is dramatically more efficient.
- The architecture supports scaling to sequence lengths unattainable with conventional transformers, making it especially suitable for educational logs, clickstreams, or similar data streams.
- Ensemble variants differing in attention head count display minimal variance, further validating robustness.
- GPU memory consumption and throughput for both training and inference are improved by over an order of magnitude by the complexity reduction.
7. Generalization and Transferability
The core principle—using only the last element as the query in attention—translates to any autoregressive domain where next-step prediction, sequence classification, or rolling forecast is the central task. The hybridization with RNN enables robust modeling of recency and order information, which is preserved even with the reduced attention computation.
This architecture is foundational for modern efficient context-aware models and has direct implications for scaling up sequence modeling in domains with large histories and next-step prediction focus.