
Sparton: Fast and Memory-Efficient Triton Kernel for Learned Sparse Retrieval

Published 26 Mar 2026 in cs.IR | (2603.25011v1)

Abstract: State-of-the-art Learned Sparse Retrieval (LSR) models, such as Splade, typically employ a Language Modeling (LM) head to project latent hidden states into a lexically-anchored logit matrix. This intermediate matrix is subsequently transformed into a sparse lexical representation through element-wise operations (ReLU, Log1P) and max-pooling over the sequence dimension. Despite its effectiveness, the LM head creates a massive memory bottleneck due to the sheer size of the vocabulary (V), which can range from 30,000 to over 250,000 tokens in recent models. Materializing this matrix creates a significant memory bottleneck, limiting model scaling. The resulting I/O overhead between operators further throttles throughput and runtime performance. In this paper, we propose Sparton, a fast memory-efficient Triton kernel tailored for the LM head in LSR models. Sparton utilizes a fused approach that integrates the tiled matrix multiplication, ReLU, Log1P, and max-reduction into a single GPU kernel. By performing an early online reduction directly on raw logit tiles, Sparton avoids materializing the full logit matrix in memory. Our experiments demonstrate that the Sparton kernel, in isolation, achieves up to a 4.8x speedup and an order-of-magnitude reduction in peak memory usage compared to PyTorch baselines. Integrated into Splade (|V| ~ 30k), Sparton enables a 33% larger batch size and 14% faster training with no effectiveness loss. On a multilingual backbone (|V| ~ 250k), these gains jump to a 26x larger batch size and 2.5x faster training.

Summary

  • The paper introduces Sparton, which fuses matrix multiplication, element-wise operations, and reduction to drastically cut memory usage in Learned Sparse Retrieval.
  • It achieves up to 4.8× speedup and 12× memory reduction by leveraging vendor-optimized GEMM and reordering max-reduction to minimize activation storage.
  • Sparton enables processing of sequence lengths up to 8192 and larger batch sizes, unlocking scalable LSR on commodity and datacenter GPUs.


Motivation and Problem Context

Learned Sparse Retrieval (LSR) models have proven highly effective at producing sparse, high-dimensional representations tied to the lexical vocabularies of pretrained language models. Architectures like Splade use a language modeling (LM) head to project hidden states into a logit matrix. That matrix then passes through ReLU, log1p and max-pooling to yield lexical salience scores. However, the LM head creates severe memory bottlenecks, particularly as vocabulary sizes ($|\mathcal{V}|$) in modern models reach 30k–250k tokens. Materializing the full logit tensor $L \in \mathbb{R}^{B \times S \times |\mathcal{V}|}$ in high-bandwidth memory (HBM) makes memory usage surge and incurs excessive I/O. This sharply limits scalability in batch size ($B$), sequence length ($S$) and vocabulary size, and it throttles GPU throughput (Figure 1).
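To make the scale concrete, the arithmetic below compares the size of one materialized logit tensor with the size of the max-pooled output. It is a back-of-the-envelope sketch under three assumptions:

- elements are 2 bytes (bf16/fp16);
- only the logit tensor itself is counted, although eager execution may allocate further tensors of the same shape for the ReLU and log1p outputs;
- the helper names `logit_bytes` and `pooled_bytes` are hypothetical.

```python
def logit_bytes(batch: int, seq_len: int, vocab: int, bytes_per_elem: int = 2) -> int:
    """Bytes needed to materialize the full B x S x |V| logit tensor."""
    return batch * seq_len * vocab * bytes_per_elem


def pooled_bytes(batch: int, vocab: int, bytes_per_elem: int = 2) -> int:
    """Bytes for the max-pooled B x |V| output (what a fused kernel keeps)."""
    return batch * vocab * bytes_per_elem


if __name__ == "__main__":
    GIB, MIB = 1024 ** 3, 1024 ** 2
    # |V| = 30522 is the BERT-style vocabulary; 250_000 approximates a multilingual one.
    for vocab in (30_522, 250_000):
        full = logit_bytes(batch=128, seq_len=512, vocab=vocab)
        pooled = pooled_bytes(batch=128, vocab=vocab)
        print(f"|V|={vocab:>7}: logits {full / GIB:5.1f} GiB, pooled {pooled / MIB:5.1f} MiB")
```

Take $B=128$ and $S=512$. The logit tensor alone then occupies about 3.7 GiB at $|\mathcal{V}|=30522$ and about 30.5 GiB at $|\mathcal{V}| \approx 250$k. In both cases the pooled output is exactly $S = 512$ times smaller.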

Figure 1: LM implementations in PyTorch and Sparton. Data in HBM (grey) is loaded into SRAM per block for parallel computation in Streaming Multiprocessors (green).

PyTorch's standard and compiled execution modes do not fuse the critical matrix multiplication (matmul) with the subsequent reductions. The result is repeated memory transfers and hard scalability ceilings. Recent sparse retrieval approaches that tile the logit matrix reduce forward memory consumption. They still leave the backward memory overhead unresolved, because PyTorch's autograd retains the intermediate activations. There is thus a pressing need for a kernel that fuses matmul, element-wise operations and reduction into a single memory-minimizing pass.
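For reference, the unfused pipeline can be written as three separate stages, each producing a full intermediate. This pure-Python sketch shows the semantics only; it is not PyTorch's actual code path. It rests on these assumptions:

- the LM head reduces to a plain projection onto an output embedding matrix `emb`, with no bias or extra transform;
- `splade_head_unfused` is a hypothetical name;
- padding is handled by multiplying the activations by the attention mask before the max, which works because $\log(1+\mathrm{ReLU}(x)) \ge 0$.

```python
import math
from typing import List

Matrix = List[List[float]]


def splade_head_unfused(hidden: List[Matrix], emb: Matrix, mask: List[List[int]]) -> Matrix:
    """Unfused LM-head pooling that materializes every intermediate.

    hidden: B x S x d hidden states; emb: |V| x d output embeddings (assumed,
    no bias); mask: B x S attention mask (1 = real token, 0 = padding).
    Returns the B x |V| sparse representation.
    """
    # Stage 1: matmul -> full B x S x |V| logit tensor (the memory hot spot).
    logits = [[[sum(h * e for h, e in zip(tok, row)) for row in emb]
               for tok in seq] for seq in hidden]
    # Stage 2: element-wise ReLU + log1p -> another B x S x |V| tensor.
    acts = [[[math.log1p(max(x, 0.0)) for x in tok] for tok in seq] for seq in logits]
    # Stage 3: zero out padded positions, then max-pool over the sequence axis.
    pooled = []
    for b, seq in enumerate(acts):
        masked = [[a * mask[b][s] for a in tok] for s, tok in enumerate(seq)]
        pooled.append([max(column) for column in zip(*masked)])
    return pooled


if __name__ == "__main__":
    hidden = [[[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]]  # B=1, S=3, d=2
    emb = [[1.0, -1.0], [0.5, 2.0]]                  # |V|=2
    mask = [[1, 1, 0]]                                # last token is padding
    print(splade_head_unfused(hidden, emb, mask))
```

Stages 1 and 2 each hold $B \times S \times |\mathcal{V}|$ values. That is the state a fused kernel is designed to avoid.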

Sparton Kernel Design

Sparton addresses these bottlenecks by using Triton for operator fusion and early online reduction. The forward pass exploits the monotonicity of $f(x) = \log(1 + \mathrm{ReLU}(x))$. Because $f$ is non-decreasing, $\max_s f(L_{b,s,v}) = f(\max_s L_{b,s,v})$. The max reduction can therefore move ahead of the nonlinearities, and the maxima over the sequence dimension can be computed directly on raw logits. As a result, only $B \times |\mathcal{V}|$ activations are written to HBM, which shrinks activation storage and bandwidth requirements by a factor of $S$.
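The reordering can be checked numerically. The sketch below takes a single (batch, vocabulary) column of raw logits and a 0/1 attention mask. It compares two orders:

- pooling after the activation (the unfused order);
- pooling the raw logits first and applying $f$ once.

Masking padding with $-\infty$ before the max is the raw-logit counterpart of zeroing it after the activation, since $f(-\infty) = 0$. The function names are illustrative.

```python
import math
import random
from typing import Sequence


def f(x: float) -> float:
    """Splade activation log(1 + ReLU(x)); non-decreasing in x."""
    return math.log1p(max(x, 0.0))


def pool_after(logits: Sequence[float], mask: Sequence[int]) -> float:
    """Unfused order: activate every position, zero padding, then take the max."""
    return max(f(x) * m for x, m in zip(logits, mask))


def pool_before(logits: Sequence[float], mask: Sequence[int]) -> float:
    """Reordered: max over raw logits (padding -> -inf), then activate once."""
    return f(max(x if m else -math.inf for x, m in zip(logits, mask)))


if __name__ == "__main__":
    rng = random.Random(0)
    for _ in range(10_000):
        n = rng.randint(1, 32)
        logits = [rng.uniform(-5.0, 5.0) for _ in range(n)]
        mask = [rng.randint(0, 1) for _ in range(n)]
        assert math.isclose(pool_after(logits, mask), pool_before(logits, mask))
    print("max-then-activate matches activate-then-max on all random cases")
```

The maximum commutes with any non-decreasing function. As a result, $f$ is evaluated once per (batch, vocabulary) pair rather than once per token.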

Algorithmically, Sparton computes logits one vocabulary tile at a time using vendor-optimized GEMM (cuBLAS/rocBLAS). It immediately applies a fused Triton kernel to each tile that performs a masked max-reduction and stores only the max values and their argmax indices. ReLU and log1p are then applied only to the reduced outputs, so the full $B \times S \times |\mathcal{V}|$ matrix is never materialized.
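The tile-then-reduce schedule can be sketched in pure Python. This illustrates the data flow only, under these assumptions:

- a nested loop stands in for the per-tile cuBLAS/rocBLAS GEMM and the Triton reduction;
- `tile_v` is a hypothetical tile-size parameter;
- ties keep the earliest sequence position.

Only one tile's running maxima are live at any moment. What gets stored is the $B \times |\mathcal{V}|$ output plus the argmax indices.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def tiled_forward(hidden: List[Matrix], emb: Matrix, mask: List[List[int]],
                  tile_v: int = 2) -> Tuple[Matrix, List[List[int]]]:
    """Tile the vocabulary, max-reduce each logit tile over the sequence at
    once, and apply ReLU + log1p only to the reduced values.

    Returns (B x |V| sparse output, B x |V| argmax sequence positions).
    A fully padded row yields output 0.0 and argmax -1.
    """
    batch, vocab = len(hidden), len(emb)
    out = [[0.0] * vocab for _ in range(batch)]
    argmax = [[-1] * vocab for _ in range(batch)]
    for b in range(batch):
        for v0 in range(0, vocab, tile_v):
            tile = emb[v0:v0 + tile_v]  # slice of the |V| x d embedding matrix
            # Running (online) max for this tile only: O(tile_v) live values.
            best = [-math.inf] * len(tile)
            best_s = [-1] * len(tile)
            for s, tok in enumerate(hidden[b]):
                if not mask[b][s]:
                    continue  # padding never competes for the max
                for j, row in enumerate(tile):
                    # Stand-in for one entry of the per-tile GEMM.
                    logit = sum(h * e for h, e in zip(tok, row))
                    if logit > best[j]:  # strict '>' keeps the earliest position on ties
                        best[j], best_s[j] = logit, s
            # Nonlinearities touch only the reduced B x |V| values.
            for j in range(len(tile)):
                out[b][v0 + j] = math.log1p(max(best[j], 0.0))
                argmax[b][v0 + j] = best_s[j]
    return out, argmax


if __name__ == "__main__":
    hidden = [[[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]]
    emb = [[1.0, -1.0], [0.5, 2.0], [-2.0, -2.0]]
    print(tiled_forward(hidden, emb, [[1, 1, 0]], tile_v=2))
```

Changing `tile_v` alters the schedule but not the result. This is why the vocabulary dimension can be tiled safely.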

The backward pass is similarly fused. Gradients flow only through the stored argmax positions. Each vocabulary entry contributes a gradient to the hidden state at its max sequence position and to its own embedding row. This reduces the saved forward state from $\mathcal{O}(BS|\mathcal{V}|)$ to $\mathcal{O}(B|\mathcal{V}|)$, eliminating redundant storage and minimizing I/O.
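The gradient routing follows from the chain rule. Let $m_{b,v} = \max_s L_{b,s,v}$, attained at position $s^*$, with output $\log(1+\mathrm{ReLU}(m_{b,v}))$. The local derivative is then $1/(1+m_{b,v})$ when $m_{b,v} > 0$ and $0$ otherwise, and it flows only to $h_{b,s^*}$ and to embedding row $e_v$.

The pure-Python sketch below is not the Sparton kernel itself. It makes these assumptions:

- the saved state is the raw max logit plus its argmax;
- the ReLU derivative at 0 is treated as 0;
- the function names are hypothetical.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def forward_saved_state(hidden: List[Matrix], emb: Matrix, mask: List[List[int]]
                        ) -> Tuple[Matrix, List[List[int]]]:
    """Keep only the max raw logit and its sequence position per (b, v):
    O(B * |V|) saved state instead of O(B * S * |V|)."""
    max_logit = [[-math.inf] * len(emb) for _ in hidden]
    argmax = [[-1] * len(emb) for _ in hidden]
    for b, seq in enumerate(hidden):
        for s, tok in enumerate(seq):
            if not mask[b][s]:
                continue  # padding is never selected
            for v, row in enumerate(emb):
                logit = sum(h * e for h, e in zip(tok, row))
                if logit > max_logit[b][v]:
                    max_logit[b][v], argmax[b][v] = logit, s
    return max_logit, argmax


def sparse_backward(grad_out: Matrix, max_logit: Matrix, argmax: List[List[int]],
                    hidden: List[Matrix], emb: Matrix) -> Tuple[List[Matrix], Matrix]:
    """Route d(loss)/d(output) back to hidden states and embeddings, touching
    only the stored argmax position for each (b, v)."""
    dim = len(emb[0])
    d_hidden = [[[0.0] * dim for _ in seq] for seq in hidden]
    d_emb = [[0.0] * dim for _ in emb]
    for b in range(len(hidden)):
        for v in range(len(emb)):
            m, s = max_logit[b][v], argmax[b][v]
            if s < 0 or m <= 0.0:
                continue  # fully padded row, or ReLU inactive: zero gradient
            g = grad_out[b][v] / (1.0 + m)  # d/dm log1p(m) = 1 / (1 + m)
            for k in range(dim):
                d_hidden[b][s][k] += g * emb[v][k]   # d logit / d h = e_v
                d_emb[v][k] += g * hidden[b][s][k]   # d logit / d e_v = h
    return d_hidden, d_emb
```

Nothing of size $B \times S \times |\mathcal{V}|$ is read or written in this sketch. The backward work scales with $B \cdot |\mathcal{V}| \cdot d$, where $d$ is the hidden size.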

Experimental Results

Sparton was benchmarked against PyTorch (eager and compiled) and a tiled-only LM head along three axes: batch size, sequence length and vocabulary size (Figure 2). End-to-end measurements covered latency, peak memory and retrieval effectiveness on state-of-the-art sparse encoders.

Figure 2: Scaling Sparton (without backbone) across three dimensions: Batch Size ($S=512$, $|\mathcal{V}|=30522$), Sequence Length ($B=128$, $|\mathcal{V}|=30522$), Vocabulary Size ($B=256$, $S=512$).

Key empirical findings:

  • Speedup and Memory Reduction: Sparton delivers up to $4.8\times$ speedup and $12\times$ peak memory reduction versus compiled PyTorch baselines.
  • Scalability: Only Sparton handles sequence lengths up to $8192$ (with $B=128$, $|\mathcal{V}|=30522$); at that length all baselines run out of memory. Sparton never materializes the $B \times S \times |\mathcal{V}|$ logit tensor, so its peak memory grows far more slowly than that of the baselines as batch size, sequence length or vocabulary size increases. The baselines scale linearly or worse.
  • Training Efficiency: When integrated into the Splade encoder, Sparton enables a 33% larger batch size and 14% faster training with no drop in retrieval effectiveness (measured via NDCG@10 on small-Beir). On multilingual backbones ($|\mathcal{V}| \approx 250$k), the gains grow to a $26\times$ larger batch size and $2.5\times$ faster training.

Practical Implications

Sparton removes the memory bottleneck that has limited LSR training. It makes substantially larger batch sizes, sequence lengths and vocabularies feasible on both commodity and datacenter GPUs. Practitioners can therefore scale LSR models past the memory ceilings of the conventional PyTorch implementation. This supports faster research iteration and deployment in retrieval scenarios that require high throughput or multilingual coverage.

The kernel design is hardware-conscious. It orchestrates block-level tiling and memory access patterns to maximize on-chip SRAM utilization and minimize HBM transfers, making better use of the compute and bandwidth of modern GPUs (e.g., NVIDIA A100/H100).

Theoretical Implications and Future Directions

The operator fusion and reordering in Sparton demonstrate that activation storage for certain monotonic pipelines can be reduced by orders of magnitude. This paradigm could be extended to other model heads and architectures (beyond sparse retrieval), particularly in regimes with highly sparse regularization and sequence reductions.

Potential future developments include leveraging low-precision formats (FP8), specialized hardware features (Tensor Memory Accelerator), and exploring additional fusion opportunities across GPU primitives in Triton and vendor libraries. These optimizations would further push the envelope in LSR training speed, efficiency, and scalability.

Conclusion

Sparton presents a formalized, Triton-optimized approach for memory-efficient and accelerated LM head computation within Learned Sparse Retrieval models (2603.25011). By fusing core matrix and reduction operations, Sparton eliminates key limitations of PyTorch and enables LSR pipelines to scale seamlessly in batch, sequence, and vocabulary dimensions. Its adoption paves a practical path for scalable sparse retrieval and opens theoretical avenues for operator fusion and memory minimization in deep learning kernels.
