
MOM: Memory-Efficient Offloaded Mini-Sequence Inference for Long Context Language Models (2504.12526v1)

Published 16 Apr 2025 in cs.LG, cs.AI, and cs.CL

Abstract: Long-context LLMs exhibit impressive performance but remain challenging to deploy due to high GPU memory demands during inference. We propose Memory-efficient Offloaded Mini-sequence Inference (MOM), a method that partitions critical layers into smaller "mini-sequences" and integrates seamlessly with KV cache offloading. Experiments on various Llama, Qwen, and Mistral models demonstrate that MOM reduces peak memory usage by over 50% on average. On Meta-Llama-3.2-8B, MOM extends the maximum context length from 155k to 455k tokens on a single A100 80GB GPU, while keeping outputs identical and not compromising accuracy. MOM also maintains highly competitive throughput due to minimal computational overhead and efficient last-layer processing. Compared to traditional chunked prefill methods, MOM achieves a 35% greater context length extension. More importantly, our method drastically reduces prefill memory consumption, eliminating it as the longstanding dominant memory bottleneck during inference. This breakthrough fundamentally changes research priorities, redirecting future efforts from prefill-stage optimizations to improving decode-stage residual KV cache efficiency.

Summary

  • The paper introduces MOM, a novel method that combines mini-sequence processing for MLP layers with dynamic KV cache offloading to significantly reduce GPU memory demands.
  • MOM’s approach cuts peak MLP memory usage by over 50% and extends context lengths nearly threefold on high-end GPUs by processing inputs in smaller, efficient chunks.
  • Experiments across models like Llama, Qwen, and Mistral show that MOM maintains throughput and accuracy, enabling practical inference on both high-end and consumer-grade hardware.

Processing long contexts with LLMs during inference faces significant challenges, primarily due to the high GPU memory demands. The peak memory usage typically occurs during the "prefill" stage, where the entire input sequence is processed to compute the initial key-value (KV) cache. This memory peak is dominated by the intermediate activations of the MLP (feed-forward) layers, not the attention layers (especially with optimizations like FlashAttention).

Existing methods attempt to address this. KV cache offloading moves the KV cache to CPU memory or storage, but frequent transfers can slow down decoding. Chunked prefill splits the input sequence into smaller chunks for processing, reducing peak MLP memory but often incurring overhead from repeated forward passes for each chunk, negatively impacting throughput and limiting the extent of context extension compared to theoretical maximums. The Mini-Sequence Transformer (MST) applied similar partitioning ideas for training, but it was not designed for efficient inference.
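
For contrast, the chunked prefill baseline can be sketched as below. This is a minimal illustration assuming a Hugging Face-style causal LM interface; `model`, `input_ids`, and `chunk_size` are placeholders, not the paper's implementation. Note that every chunk traverses the entire Transformer stack, unlike MOM's splitting inside each MLP layer.

```python
import torch

def chunked_prefill(model, input_ids, chunk_size=8192):
    """Sketch of chunked prefill: each chunk runs through the *whole* model,
    growing the KV cache, so per-chunk overhead is paid at every layer."""
    past_key_values = None
    with torch.no_grad():
        for chunk in input_ids.split(chunk_size, dim=1):
            out = model(chunk, past_key_values=past_key_values, use_cache=True)
            past_key_values = out.past_key_values  # KV cache accumulated chunk by chunk
    # Logits of the final position seed the first decoded token.
    return out.logits[:, -1:, :], past_key_values
```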

The paper "MOM: Memory-Efficient Offloaded Mini-Sequence Inference for Long Context LLMs" (2504.12526) proposes a novel approach called MOM that combines Mini-Sequence processing for MLP layers with KV cache offloading to significantly reduce GPU memory consumption during inference, thereby enabling much longer context lengths.

Here's how MOM works:

  1. Mini-Sequence Processing for MLPs: During the prefill stage, MOM partitions the input tensor internally within each MLP layer into smaller "mini-sequences." Instead of processing the entire sequence of length $S$ at once through the MLP (which would require intermediate memory proportional to $S \times I$, where $I$ is the MLP intermediate dimension, typically $4d$ for hidden size $d$), MOM processes these mini-sequences sequentially. If the input is split into $M$ mini-sequences, the peak intermediate memory for the MLP layers is reduced to approximately $\frac{S \cdot I}{M}$. Crucially, when generating the first token during prefill, only the representation of the last input token is fed through the final MLP layer and the language modeling head to produce the first output logit. This avoids unnecessary computation and memory usage for preceding tokens in the final layers during this critical step. For intermediate layers, the outputs of the mini-sequences are concatenated before passing to the next Transformer block (see the sketch after this list).
  2. KV Cache Offloading Integration: MOM integrates seamlessly with existing KV cache offloading mechanisms (such as those provided by Hugging Face Transformers). During the prefill stage, after the attention mechanism computes the KV pairs for a block, these KV pairs are updated and can be offloaded to CPU memory. This is done dynamically, so that GPU memory is primarily dedicated to the MLP computations and the necessary active data. Before the decoding stage begins, the full KV cache required for autoregressive generation is transferred back to the GPU, ensuring efficient token generation without high-latency transfers at every step.
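
The per-layer splitting in step 1 can be sketched roughly as follows. This is not the authors' released code; the Llama-style projection names (`gate_proj`, `up_proj`, `down_proj`) and the choice of `num_mini_seqs` are illustrative assumptions.

```python
import torch
import torch.nn as nn

class MiniSequenceMLP(nn.Module):
    """Gated MLP that processes the sequence dimension in mini-sequences.

    Splitting into M chunks keeps only about (S * I) / M elements of the large
    intermediate activation live at a time, while the concatenated result is
    mathematically identical to running the unsplit MLP.
    """
    def __init__(self, hidden_size, intermediate_size, num_mini_seqs=16):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act = nn.SiLU()
        self.num_mini_seqs = num_mini_seqs

    def forward(self, x):  # x: (batch, seq_len, hidden_size)
        outputs = []
        for mini in x.chunk(self.num_mini_seqs, dim=1):
            # Only a (seq_len / M)-sized slice of the S x I tensor is live here.
            outputs.append(self.down_proj(self.act(self.gate_proj(mini)) * self.up_proj(mini)))
        return torch.cat(outputs, dim=1)
```

During prefill, the final block's MLP and the language modeling head would additionally be applied only to the last token's hidden state, since only that position's logits are needed to emit the first generated token.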

By combining these two techniques, MOM effectively tackles the two main memory bottlenecks: Mini-Sequence processing drastically reduces the MLP intermediate memory peak during prefill, while offloading manages the growing KV cache size, especially for long contexts. The paper argues that this shifts the primary memory constraint from the prefill stage MLP activations to the GPU-resident KV cache needed for the decode stage.
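
To make the size of the eliminated bottleneck concrete, here is a rough back-of-the-envelope calculation; the dimensions and dtype are illustrative assumptions (an 8B Llama-style model with MLP intermediate size $I = 14{,}336$, prefill length $S = 455{,}000$, bf16 activations, $M = 16$ mini-sequences), not figures reported in the paper:

$$
S \cdot I \cdot 2\ \text{bytes} = 455{,}000 \times 14{,}336 \times 2\ \text{bytes} \approx 13\ \text{GB},
\qquad
\frac{S \cdot I}{M} \cdot 2\ \text{bytes} \approx \frac{13\ \text{GB}}{16} \approx 0.8\ \text{GB}.
$$

A gated MLP materializes several tensors of this size simultaneously (gate, up, and activation outputs), so the per-layer saving is correspondingly larger; the decode-stage KV cache, which still grows with $S$, is what remains as the dominant memory consumer.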

The authors evaluated MOM on various LLM models (Llama, Qwen, Mistral) and hardware (A100 80GB, RTX 4080 12GB) and demonstrated significant improvements:

  • Memory Efficiency: MOM reduces peak GPU memory usage by over 50% on average compared to the standard inference setup. This is particularly noticeable as context length increases.
  • Context Length Extension: On an A100 80GB GPU, MOM extended the maximum context length for Llama 3.2-8B from 155,000 tokens (standard) to 455,000 tokens, a nearly threefold increase. This extension is 35% greater than achievable with conventional chunked prefill methods.
  • Throughput: MOM maintains highly competitive throughput. While offloading introduces some overhead, the Mini-Sequence processing itself has minimal computational cost and can even slightly improve speed in some cases due to better cache utilization. Compared to chunked prefill, MOM offers better Time-to-First-Token (TTFT), especially when the chunked-prefill baseline uses small chunk sizes.
  • Accuracy: Logit equivalence tests and Needle-in-a-Haystack evaluations confirmed that MOM preserves the mathematical equivalence of the standard forward pass, resulting in identical outputs and no degradation in accuracy for tasks requiring long-context understanding.
  • Generalizability: The benefits of MOM were demonstrated across different model families (Llama, Qwen, Mistral) and were shown to be effective even on consumer-grade hardware (RTX 4080 12GB) with quantization (bitsandbytes 4-bit), making it practical for wider deployment.

The implementation of MOM is designed to be minimally invasive and compatible with existing frameworks like Hugging Face, requiring minor modifications. The authors have released their implementation publicly on GitHub.
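
As an illustration of the stock Hugging Face machinery this kind of integration builds on (this is not MOM's released code; the model id, quantization settings, and prompt are placeholders, and `cache_implementation="offloaded"` refers to the Transformers built-in CPU-offloaded KV cache rather than MOM itself):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "meta-llama/Llama-3.1-8B"  # placeholder model id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    # 4-bit weights via bitsandbytes, as in the paper's consumer-GPU experiments.
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
    ),
    device_map="auto",
)

long_prompt = open("long_input.txt").read()  # placeholder for a very long context
inputs = tokenizer(long_prompt, return_tensors="pt").to(model.device)
output = model.generate(
    **inputs,
    max_new_tokens=256,
    cache_implementation="offloaded",  # KV cache blocks are kept in CPU RAM
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```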

The paper concludes by highlighting that MOM effectively eliminates the prefill MLP memory bottleneck, making the decode-stage KV cache the new dominant memory consumer. Future research directions include optimizing the integration of MOM with various inference frameworks (like vLLM or sglang) and exploring further KV cache compression techniques specifically for the decoding stage to push the boundaries of long-context inference even further.
