FusedKV: Cross-Layer Transformer Fusion
- FusedKV is a memory-efficient cross-layer key–value cache reconstruction mechanism that fuses bottom- and middle-layer caches using learnable gates to reduce memory by 50%.
- It preserves relative positional encoding (RoPE) through a per-2D symmetry constraint, ensuring accurate attention without positional drift.
- Integration with fused kernels enables efficient Transformer decoding with improved perplexity and task accuracy, balancing memory savings and compute overhead.
FusedKV is a memory-efficient cross-layer key–value (KV) cache reconstruction mechanism for Transformer decoders that performs learnable fusion of bottom- and middle-layer caches to halve memory requirements while improving or maintaining predictive accuracy. Unlike earlier cross-layer KV sharing approaches (such as YOCO and CLA), which replace top-layer caches with those from shallower layers and suffer accuracy degradation, FusedKV reconstructs the upper layers' KV caches on the fly from the most informative source layers using a small set of learnable gates. This design also preserves relative positional encoding (RoPE) structure throughout and is compatible with hardware-efficient kernel fusion.
1. Core Formulation and Mathematical Principles
In a standard decoder with $L$ layers, FusedKV partitions the stack at depth $L/2$ (for even $L$):
- Storage layers ($1 \le \ell \le L/2$): store and maintain native KV caches $K^{(\ell)}$ and $V^{(\ell)}$.
- Reconstruction layers ($L/2 < \ell \le L$): do not store native caches but reconstruct them at each decoding step.
Let $K^{(s)}, V^{(s)} \in \mathbb{R}^{n \times d}$ denote the key and value caches of storage layer $s$, where $n$ is the prompt length and $d$ is the head dimension. For each reconstruction layer $\ell > L/2$:

$$\hat{K}^{(\ell)} = \alpha^{(\ell)} \odot K^{(1)} + \beta^{(\ell)} \odot K^{(L/2)}, \qquad \hat{V}^{(\ell)} = \gamma^{(\ell)} \odot V^{(1)} + \delta^{(\ell)} \odot V^{(L/2)},$$

where $\alpha^{(\ell)}, \beta^{(\ell)}, \gamma^{(\ell)}, \delta^{(\ell)} \in \mathbb{R}^{d}$ are learnable, broadcasted, feature-wise gates (with per-2D symmetry for RoPE), and $\odot$ denotes element-wise multiplication broadcast over all cache entries. Reconstruction is performed entirely on top of already-rotated keys/values, with no additional key/value projections for the reconstruction layers. A symmetry constraint maintains RoPE compatibility (Lin et al., 3 Dec 2025).
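The gated two-source fusion can be sketched in pure Python (a minimal sketch; function and variable names are hypothetical, and the gates broadcast over positions):

```python
def fuse_caches(src_bottom, src_middle, alpha, beta):
    """Reconstruct a cache as alpha ⊙ src_bottom + beta ⊙ src_middle.

    src_bottom, src_middle: lists of n position vectors, each of length d.
    alpha, beta: length-d gate vectors, broadcast over all n positions.
    """
    fused = []
    for xb_vec, xm_vec in zip(src_bottom, src_middle):
        fused.append([a * xb + b * xm
                      for a, xb, b, xm in zip(alpha, xb_vec, beta, xm_vec)])
    return fused

# Toy example: n = 2 positions, d = 4 features (values are made up).
K1 = [[1.0, 2.0, 3.0, 4.0], [0.5, 0.5, 0.5, 0.5]]   # layer-1 keys
Km = [[4.0, 3.0, 2.0, 1.0], [1.0, 1.0, 1.0, 1.0]]   # layer-L/2 keys
alpha = [0.5, 0.5, 0.5, 0.5]
beta = [0.5, 0.5, 0.5, 0.5]
K_hat = fuse_caches(K1, Km, alpha, beta)
```

In a real implementation the gates would be learned tensors and the loop a single broadcasted elementwise kernel; the structure of the computation is identical.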
2. RoPE-Preserving Cross-Layer Fusion
Transformer decoders with RoPE encode each key at position $i$ as $k_i = R(\theta_i)\, W_K x_i$, where $R(\theta_i)$ is a block-diagonal matrix of 2-D rotations with position-dependent angles. In FusedKV, layers store only the post-RoPE output and fuse in this space:

$$\hat{k}_i^{(\ell)} = \alpha^{(\ell)} \odot k_i^{(1)} + \beta^{(\ell)} \odot k_i^{(L/2)}.$$
Crucially, the per-2D symmetry of the gates means each gate acts as a scalar on every 2-D rotation block and therefore commutes with the rotations, so the attention logit between query position $i$ and key position $j$ remains a function only of the relative offset $i - j$, preserving relative location and avoiding positional drift, a known failure mode of naïve fusions. No recomputation or reapplication of RoPE is needed during reconstruction, further reducing compute overhead.
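The commutation argument can be checked numerically: a gate whose two entries within each rotation pair are equal acts as a scalar on that 2-D block, so gating and rotating can be applied in either order. A minimal sketch with made-up angles and gate values:

```python
import math

def rope_rotate(x, theta):
    """Apply 2-D RoPE rotations to pairs (x0,x1), (x2,x3), ... with per-pair angles."""
    out = []
    for p in range(len(x) // 2):
        c, s = math.cos(theta[p]), math.sin(theta[p])
        x0, x1 = x[2 * p], x[2 * p + 1]
        out.extend([c * x0 - s * x1, s * x0 + c * x1])
    return out

def gate(g_half, x):
    """Feature-wise gate with per-2D symmetry: each of the d/2 entries of
    g_half is applied to both coordinates of its rotation pair."""
    return [g_half[i // 2] * xi for i, xi in enumerate(x)]

x = [1.0, 2.0, -0.5, 0.3]   # a pre-rotation key vector (d = 4)
theta = [0.7, -1.2]         # arbitrary per-pair angles for one position
g = [0.9, 1.4]              # d/2 gate entries

lhs = gate(g, rope_rotate(x, theta))   # gate applied after rotation
rhs = rope_rotate(gate(g, x), theta)   # rotation applied after gating
```

Since `lhs == rhs` (up to floating-point rounding), a symmetric gate applied to post-RoPE keys is equivalent to gating before rotation, which is exactly why the fused key is still the rotation of a position-independent vector.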
3. Decoding Architecture and Implementation
The FusedKV procedure is instantiated as follows:
- At each decoding step, the new token's query projection is computed at all $L$ layers.
- Storage layers ($\ell \le L/2$): perform the usual key and value projections (and RoPE rotation) and cache the results.
- Reconstruction layers ($\ell > L/2$): reconstruct keys and values for all positions from the stored caches of layers 1 and $L/2$, applying the per-layer learned gates.
- During attention, the reconstructed $\hat{K}^{(\ell)}$ and $\hat{V}^{(\ell)}$ are used exactly as in a standard decoder.
The approach can be efficiently implemented with kernel fusion: for every token and reconstruction layer, a fused Triton or CUDA kernel reads $K^{(1)}$, $K^{(L/2)}$, $V^{(1)}$, and $V^{(L/2)}$, applies the gates, sums, and writes the results in a single pass, maximizing shared-memory bandwidth and minimizing extra memory I/O. Gradients flow through the fusion gates at every step, and the gates are learned end-to-end with the rest of the model.
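A reference-level sketch of the cache handling above (pure Python, no kernel fusion; class and method names are hypothetical, and a single gate pair is shared between keys and values for brevity, whereas the text allows separate gates):

```python
class FusedKVCache:
    """Toy FusedKV cache: layers 1..L/2 store natively; layers above L/2
    are reconstructed on demand from the layer-1 and layer-L/2 caches."""

    def __init__(self, n_layers, alpha, beta):
        assert n_layers % 2 == 0
        self.half = n_layers // 2
        self.K = {l: [] for l in range(1, self.half + 1)}  # native caches
        self.V = {l: [] for l in range(1, self.half + 1)}
        self.alpha, self.beta = alpha, beta  # dicts: reconstruction layer -> gate

    def append(self, layer, k, v):
        """Storage layers cache their post-RoPE keys/values."""
        self.K[layer].append(k)
        self.V[layer].append(v)

    def get(self, layer):
        """Return the KV cache for any layer; reconstruction layers fuse
        the stored layer-1 and layer-L/2 caches with their learned gates."""
        if layer <= self.half:
            return self.K[layer], self.V[layer]
        a, b = self.alpha[layer], self.beta[layer]
        K = [[ai * x1 + bi * xm for ai, x1, bi, xm in zip(a, k1, b, km)]
             for k1, km in zip(self.K[1], self.K[self.half])]
        V = [[ai * x1 + bi * xm for ai, x1, bi, xm in zip(a, v1, b, vm)]
             for v1, vm in zip(self.V[1], self.V[self.half])]
        return K, V

# Toy usage: L = 4 layers, d = 2, symmetric gates (equal within the 2-D pair).
cache = FusedKVCache(4,
                     alpha={3: [0.25, 0.25], 4: [0.5, 0.5]},
                     beta={3: [0.75, 0.75], 4: [0.5, 0.5]})
cache.append(1, [1.0, 2.0], [1.0, 2.0])
cache.append(2, [3.0, 4.0], [3.0, 4.0])
K3, V3 = cache.get(3)   # reconstructed cache for layer 3
```

A production version would materialize nothing: the fused kernel described above computes the same gated sum tile-by-tile inside attention.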
4. Complexity and Memory Trade-offs
The following table summarizes memory and I/O requirements:
| Method | Cache Memory | Cache I/O per Token |
|---|---|---|
| Vanilla (MHA) | 1× | 1× |
| CLA / YOCO | 0.5× | 1× |
| GQA | 1/$G$× | 1/$G$× |
| FusedKV-Lite | 0.5× | 1× |
| FusedKV | 0.5× | 1.5× |

(Values relative to vanilla MHA; $G$ denotes the GQA head-grouping factor.)
Both FusedKV and FusedKV-Lite cut KV memory by exactly 50% versus vanilla, storing only the caches of the bottom half of the layers. FusedKV incurs extra cache I/O (each reconstruction layer reads two source caches per key/value instead of one, roughly 1.5× vanilla in aggregate), while FusedKV-Lite matches CLA/YOCO in I/O and requires no fusion compute at inference. This overhead is usually offset by bottlenecks elsewhere and can be efficiently handled in fused-kernel implementations.
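The memory halving is simple arithmetic. A back-of-envelope sketch, assuming a hypothetical configuration (32 layers, 32 heads of dimension 128, fp16 caches, 4096-token context; all numbers illustrative, not from the paper):

```python
# Back-of-envelope KV-cache memory for a hypothetical decoder config.
L, H, d, n = 32, 32, 128, 4096   # layers, heads, head dim, context length
bytes_per = 2                    # fp16

vanilla = L * n * H * d * 2 * bytes_per          # K and V for every layer
fusedkv = (L // 2) * n * H * d * 2 * bytes_per   # only layers 1..L/2 stored

print(vanilla // 2**20, fusedkv // 2**20)  # → 2048 1024 (MiB)
```

The same arithmetic holds per sequence in a batch, so the saving compounds directly with batch size and context length.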
5. Empirical Results and Ablation Studies
Across model sizes (332M to 4B parameters) and datasets (FineWeb-Edu, WikiText), FusedKV achieves the following:
- KV memory reduced by exactly 50%.
- Validation perplexity generally improves:
| Model Size | Vanilla | FusedKV | FusedKV-Lite |
|---|---|---|---|
| 332M | 22.85 | 22.35 | 22.78 |
| 650M | 18.47 | 18.09 | 18.55 |
| 1.5B | 13.67 | 13.33 | 13.45 |
| 4B | 9.18 | 8.94 | N/A |
- Downstream five-shot accuracy (e.g., MMLU, HellaSwag, ARC) consistently meets or exceeds vanilla.
- Ablations on FusedKV-Lite's source layers confirm the optimal asymmetry: values from layer 1 and keys from layer $L/2$; reversing the assignment or choosing intermediate source layers degrades outcomes.
- Adding learnable gates to FusedKV-Lite (“FusedKV-Lite-Learnable”) further increases accuracy over simple fixed re-use.
6. FusedKV-Lite: Lightweight Variant
FusedKV-Lite removes all fusion computation at inference. For each reconstruction layer $\ell > L/2$:

$$\hat{K}^{(\ell)} = K^{(L/2)}, \qquad \hat{V}^{(\ell)} = V^{(1)}.$$
This reduces KV memory by 50% and does not increase cache I/O compared to vanilla Transformers. In practice, FusedKV-Lite yields only a modest perplexity increase over full FusedKV (e.g., 13.45 versus 13.33 at 1.5B parameters) while preserving end-to-end throughput. It is best suited for I/O-bound deployments where minimal runtime compute is paramount.
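Because Lite's reconstruction is fixed re-use, it can be implemented as pure aliasing with zero inference compute. A sketch assuming the asymmetric sources reported in the ablations (function name hypothetical):

```python
def lite_get(cache_K, cache_V, layer, half):
    """FusedKV-Lite lookup: storage layers return their own caches;
    reconstruction layers alias stored caches directly (keys from
    layer L/2, values from layer 1) with no arithmetic at all."""
    if layer <= half:
        return cache_K[layer], cache_V[layer]
    return cache_K[half], cache_V[1]

# Toy caches for a 4-layer model (half = 2); strings stand in for tensors.
K = {1: ["k1"], 2: ["k2"]}
V = {1: ["v1"], 2: ["v2"]}
upper = lite_get(K, V, 4, 2)   # layer 4 aliases (K from layer 2, V from layer 1)
```

Since the returned objects are references, not copies, the attention kernel for an upper layer simply reads the existing buffers, which is why Lite's cache I/O matches vanilla.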
7. Practical Integration and Implementation Guidance
To integrate FusedKV into a Transformer decoder:
- Introduce a “KV-reconstruction” hook before attention in every reconstruction layer.
- Bypass the layer’s own native key/value projections.
- At each inference step, fetch the stored post-RoPE caches from layers 1 and $L/2$.
- Apply per-layer, broadcasted gates using a fused GEMM-style kernel.
- Enforce the 2-D symmetry constraint on the fusion parameters to avoid RoPE corruption, e.g., by parameterizing each gate as a $d/2$-vector whose entries are duplicated across each rotation pair, or by directly hard-tying the paired parameters.
- Tune kernel block sizes to optimize memory access patterns and maximize GPU utilization.
- Ensure gradient flows are correctly propagated through all attention paths to the gating vectors.
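One way to hard-enforce the 2-D symmetry from the checklist is to keep only $d/2$ free parameters per gate and expand them when the gate is applied (a minimal sketch; the function name is hypothetical):

```python
def expand_symmetric_gate(g_half):
    """Expand d/2 free gate parameters into a length-d gate whose two
    coordinates within each RoPE rotation pair are tied. Tying guarantees
    the gate commutes with every 2-D rotation, so RoPE is preserved."""
    g = []
    for v in g_half:
        g.extend([v, v])
    return g

full_gate = expand_symmetric_gate([0.9, 1.4])   # d/2 = 2 free params -> d = 4 gate
```

In a training framework the same effect is achieved by registering the $d/2$-vector as the trainable parameter and applying the expansion inside the forward pass, so gradients from both tied coordinates accumulate into one parameter.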
Potential pitfalls include failing to maintain weight symmetry (breaking RoPE), introducing spurious gradient accumulation, and suboptimal kernel tiling leading to memory misalignment. Adhering to the recommended fused-kernel implementation with careful parameterization ensures correctness and maximal performance (Lin et al., 3 Dec 2025).
FusedKV represents a principled, general-purpose cross-layer KV fusion strategy for Transformer inference, offering precise RoPE preservation, 50% reduction in KV cache memory overhead, and, in most setups, improved perplexity and task accuracy with minimal impact on runtime throughput. Its lightweight FusedKV-Lite variant minimizes all runtime compute overheads while maintaining comparable predictive performance, making it particularly attractive for large-scale, I/O-bound inference scenarios.