FlashMHF: Multi-Head FFN in Transformers
- FlashMHF is a multi-head feed-forward network (FFN) whose I/O-aware kernel fuses the FFN computation blockwise in on-chip SRAM, improving scalability.
- It uses dynamically gated parallel sub-networks to keep the intermediate-to-head dimension ratio at its optimal value across model scales.
- Empirical evaluations show improved perplexity and downstream accuracy, a reduced memory footprint, and a modest inference speedup.
Multi-Head Feed-Forward Networks (MH-FFN), and in particular the Flash Multi-Head FFN (FlashMHF), extend the feed-forward sublayer of Transformer models. FlashMHF is designed as a direct replacement for the conventional point-wise SwiGLU FFN in each Transformer block. It introduces a multi-head structure and blockwise fused computation, motivated by the structural symmetry between FFNs and multi-head self-attention. Its two central innovations address the scalability and expressivity bottlenecks of naïve multi-head FFN designs: an I/O-aware kernel that computes outputs online in on-chip SRAM, reducing memory pressure, and dynamically gated parallel sub-networks that preserve the optimal intermediate-to-head dimensionality ratio as models scale. FlashMHF improves perplexity and downstream task accuracy, substantially reduces activation memory, and speeds up inference while remaining architecturally compatible with standard Transformer compositions (Zhang et al., 7 Dec 2025).
1. Integration into Transformer Architectures
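FlashMHF is implemented as a drop-in replacement for the traditional FFN in a Transformer block. The baseline (pre-norm) block is

$$Z = X + \mathrm{MHA}(\mathrm{Norm}(X)), \qquad Y = Z + \mathrm{FFN}(\mathrm{Norm}(Z)),$$

where the FFN is originally computed as

$$\mathrm{FFN}(X) = \big(\mathrm{SiLU}(X W_{\text{gate}}) \odot (X W_{\text{up}})\big)\, W_{\text{down}},$$

with input $X \in \mathbb{R}^{L \times d}$, $W_{\text{gate}}, W_{\text{up}} \in \mathbb{R}^{d \times d_{ff}}$, and $W_{\text{down}} \in \mathbb{R}^{d_{ff} \times d}$. FlashMHF replaces this component, computing $Y = Z + \mathrm{FlashMHF}(\mathrm{Norm}(Z))$ with the same input-output dimensionality and identical placement of normalization and residual connections. This keeps it compatible at every integration point.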
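The drop-in property can be illustrated with a minimal NumPy sketch of a pre-norm block whose FFN is passed in as a function. It assumes an RMSNorm without a learned scale and uses a hypothetical `zero_attn` placeholder in place of real attention; the point is only that any FFN with the same `(L, d) -> (L, d)` signature slots into the same residual and normalization positions.

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    # Normalize each token vector by its root-mean-square (learned scale omitted).
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

def silu(z):
    return z / (1.0 + np.exp(-z))

def swiglu_ffn(x, w_gate, w_up, w_down):
    # (SiLU(X W_gate) * (X W_up)) W_down; the (L, d_ff) intermediates are materialized here.
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down

def prenorm_block(x, attn, ffn):
    # Residual and pre-norm placement are identical whichever FFN is plugged in.
    z = x + attn(rms_norm(x))
    return z + ffn(rms_norm(z))

rng = np.random.default_rng(0)
L, d, d_ff = 4, 8, 24
x = rng.standard_normal((L, d))
w_gate, w_up = rng.standard_normal((d, d_ff)), rng.standard_normal((d, d_ff))
w_down = rng.standard_normal((d_ff, d))
zero_attn = lambda z: np.zeros_like(z)  # hypothetical stand-in for attention
y = prenorm_block(x, zero_attn, lambda z: swiglu_ffn(z, w_gate, w_up, w_down))
print(y.shape)  # (4, 8)
```

Swapping in a multi-head FFN only changes the `ffn` argument; the block itself is untouched.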
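The multi-head structure mirrors multi-head attention: the width is split across $H$ heads, each with head dimension $d_h$, such that $d = H d_h$. In attention, head $h$ receives per-head queries $Q_h$, keys $K_h$, and values $V_h$. A naïve MH-FFN would analogously project the input to per-head queries $Q_h = X W^{Q}_h \in \mathbb{R}^{L \times d_h}$, treat learned parameter matrices $K_h, U_h \in \mathbb{R}^{d_h \times d_{\text{int}}}$ and $V_h \in \mathbb{R}^{d_{\text{int}} \times d_h}$ as keys and values, and compute per-head outputs

$$O_h = \big(\mathrm{SiLU}(Q_h K_h) \odot (Q_h U_h)\big)\, V_h, \qquad \mathrm{MH\text{-}FFN}(X) = \mathrm{Concat}(O_1, \dots, O_H)\, W^{O},$$

aggregating head outputs via concatenation and an output projection. Without further architectural constraints, however, this approach rapidly incurs prohibitive memory and parameter inefficiencies as model width scales [(Zhang et al., 7 Dec 2025), Sect. 3.1].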
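A minimal NumPy sketch of such a naïve MH-FFN follows. The names `w_q`, `w_o`, and the per-head tuples `(K_h, U_h, V_h)` with intermediate width `d_int` are illustrative assumptions, not the paper's code. It also makes one parameter count concrete: the per-head matrices total $3\, d\, d_{\text{int}}$ parameters, so matching a SwiGLU FFN's $3\, d\, d_{ff}$ requires $d_{\text{int}} = d_{ff}$ in every head.

```python
import numpy as np

def silu(z):
    return z / (1.0 + np.exp(-z))

def naive_mh_ffn(x, w_q, heads, w_o):
    """Naive multi-head FFN sketch (hypothetical parameterization).

    x: (L, d); w_q: (d, d) input projection split into H heads of width d_h = d // H.
    heads: list of (K_h, U_h, V_h), K_h and U_h: (d_h, d_int), V_h: (d_int, d_h).
    w_o: (d, d) output projection applied to the concatenated head outputs.
    """
    H = len(heads)
    L, d = x.shape
    d_h = d // H
    q = (x @ w_q).reshape(L, H, d_h)  # per-head "queries" Q_h
    outs = []
    for h, (k_h, u_h, v_h) in enumerate(heads):
        q_h = q[:, h, :]  # (L, d_h)
        # Learned matrices play the role of keys (K_h, U_h) and values (V_h).
        outs.append((silu(q_h @ k_h) * (q_h @ u_h)) @ v_h)  # (L, d_h)
    return np.concatenate(outs, axis=-1) @ w_o  # (L, d)

rng = np.random.default_rng(0)
L, d, H, d_int = 5, 8, 2, 12
d_h = d // H
heads = [(rng.standard_normal((d_h, d_int)),
          rng.standard_normal((d_h, d_int)),
          rng.standard_normal((d_int, d_h))) for _ in range(H)]
x = rng.standard_normal((L, d))
y = naive_mh_ffn(x, rng.standard_normal((d, d)), heads, rng.standard_normal((d, d)))
print(y.shape)  # (5, 8)
```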
2. Fused SRAM-Oriented Kernel Design
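The core FlashMHF computation is a blockwise kernel structured to run within on-chip SRAM, minimizing reads and writes of large intermediate tensors to and from high-bandwidth memory (HBM). Consider one head and one of its sub-networks, with input $Q \in \mathbb{R}^{L \times d_h}$, learned parameters $K, U \in \mathbb{R}^{d_h \times d_e}$ and $V \in \mathbb{R}^{d_e \times d_h}$ (with $d_e$ the sub-network width), and the intermediate dimension $d_e$ chunked into $N = d_e / B$ blocks of size $B$. The operational sequence is:

$$O \leftarrow 0; \qquad \text{for } j = 1, \dots, N: \quad S_j = Q K_j, \quad G_j = Q U_j, \quad O \leftarrow O + \big(\mathrm{SiLU}(S_j) \odot G_j\big) V_j,$$

where $K_j, U_j \in \mathbb{R}^{d_h \times B}$ are the $j$-th column blocks of $K$ and $U$, and $V_j \in \mathbb{R}^{B \times d_h}$ is the $j$-th row block of $V$.

This structure (see Eq. 9 and Algorithm A.1 in (Zhang et al., 7 Dec 2025)) ensures that the full $L \times d_e$ intermediate activation is never materialized off-chip. Intermediate results are instead accumulated on chip, and only the final output $O \in \mathbb{R}^{L \times d_h}$ is written to slower memory, an approach analogous to FlashAttention's streaming computation. The forward-pass pseudocode iterates over the batch, head, and block dimensions as described, with the gating weights applied dynamically per block.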
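The blockwise recurrence can be checked numerically with a short NumPy sketch. This is plain Python, not the fused kernel, and the block size `B` and function names are illustrative. Because SiLU and the Hadamard product act elementwise over the intermediate dimension, and the product with $V$ sums over that dimension, the per-block partial outputs simply add up. Unlike FlashAttention, no running-max rescaling is needed.

```python
import numpy as np

def silu(z):
    return z / (1.0 + np.exp(-z))

def subnet_full(q, k, u, v):
    # Reference: materializes the full (L, d_e) intermediates.
    return (silu(q @ k) * (q @ u)) @ v

def subnet_blockwise(q, k, u, v, block):
    """Stream over the intermediate width d_e in chunks of `block` columns.

    Only (L, block) intermediates exist at any time; the (L, d_h) output is
    accumulated, mimicking what a fused kernel would keep on chip.
    """
    d_e = k.shape[1]
    assert d_e % block == 0, "this sketch assumes block divides d_e"
    out = np.zeros((q.shape[0], v.shape[1]))
    for start in range(0, d_e, block):
        sl = slice(start, start + block)
        s = q @ k[:, sl]                  # S_j: (L, block) gate pre-activations
        g = q @ u[:, sl]                  # G_j: (L, block) up-projection
        out += (silu(s) * g) @ v[sl, :]   # accumulate partial (L, d_h) output
    return out

rng = np.random.default_rng(0)
L, d_h, d_e, B = 6, 4, 16, 4
q = rng.standard_normal((L, d_h))
k, u = rng.standard_normal((d_h, d_e)), rng.standard_normal((d_h, d_e))
v = rng.standard_normal((d_e, d_h))
print(np.allclose(subnet_full(q, k, u, v), subnet_blockwise(q, k, u, v, B)))  # True
```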
3. Dynamically Weighted Parallel Sub-Networks
Naïve multi-head FFN splits suffer from scaling pathologies, such as the ratio $d_{ff}/d_h$ exploding as the model width $d$ grows at fixed head dimension. To address this, FlashMHF partitions each head's intermediate FFN dimension across $E$ parallel sub-networks, giving $d_{ff} = E\, d_e$ with $d_e / d_h \approx 8/3$ (the same ratio as $d_{ff}/d$ in SwiGLU). This keeps the $d_e / d_h$ ratio fixed at its optimal value independently of head count or model scale.
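The ratio argument reduces to simple arithmetic. The sketch below assumes a parameter-matched naïve design in which every head keeps the full intermediate width $d_{ff} = (8/3)\,d$, and a fixed head dimension $d_h = 128$. The model widths are hypothetical illustration values, not the paper's configurations.

```python
SWIGLU_RATIO = 8 / 3  # d_ff / d in a standard SwiGLU FFN

def naive_ratio(d, d_h):
    # Assumed parameter-matched naive MH-FFN: each head keeps the full width d_ff.
    d_ff = SWIGLU_RATIO * d
    return d_ff / d_h

def flashmhf_split(d, d_h):
    # Split d_ff into E sub-networks of width d_e = (8/3) * d_h.
    d_ff = SWIGLU_RATIO * d
    d_e = SWIGLU_RATIO * d_h
    E = round(d_ff / d_e)
    return E, d_e / d_h  # (number of sub-networks, per-sub-network ratio)

for d in (768, 1024, 2048):  # hypothetical model widths
    print(d, round(naive_ratio(d, 128), 2), flashmhf_split(d, 128))
```

The naïve ratio grows linearly with $d$, while the sub-network ratio stays at $8/3$ and only the count $E$ grows.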
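Each head $h$ learns a small gating matrix $W^{R}_h \in \mathbb{R}^{d_h \times E}$, which generates per-token, per-head sub-network selection weights:

$$R_h = \sigma\big(Q_h W^{R}_h\big) \in \mathbb{R}^{L \times E},$$

where $\sigma$ is the elementwise sigmoid (Eq. 8). For token $t$, head $h$, and sub-network $e$, each sub-network computes a gated SwiGLU-like output. These outputs are mixed according to the weights $R_h[t, e]$, and finally all per-head outputs are concatenated and projected back to $\mathbb{R}^{d}$, as in standard multi-head attention.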
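The gated mixture can be sketched in NumPy as follows. It assumes the gates are computed from each head's input `q_h` and mix the sub-network outputs as $\sum_e R_h[t,e]\, y_{h,e}[t]$. The parameter layout (`subnets`, `w_r`, `w_q`, `w_o`) and the toy sizes are hypothetical, and the sketch materializes intermediates that the fused kernel would keep on chip.

```python
import numpy as np

def silu(z):
    return z / (1.0 + np.exp(-z))

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gated_head(q_h, subnets, w_r):
    """One head of a dynamically gated multi-head FFN (sketch).

    q_h: (L, d_h); subnets: list of E tuples (K, U, V) with K, U: (d_h, d_e)
    and V: (d_e, d_h); w_r: (d_h, E) per-head gating matrix.
    """
    r = sigmoid(q_h @ w_r)  # (L, E) per-token sub-network weights
    out = np.zeros_like(q_h)
    for e, (k, u, v) in enumerate(subnets):
        y_e = (silu(q_h @ k) * (q_h @ u)) @ v  # SwiGLU-style sub-network output
        out += r[:, e:e + 1] * y_e             # scale each token's output by its gate
    return out

def gated_mh_ffn(x, w_q, heads, w_o):
    # heads: list of (subnets, w_r); head outputs are concatenated, then projected to d.
    L, d = x.shape
    H = len(heads)
    q = (x @ w_q).reshape(L, H, d // H)
    outs = [gated_head(q[:, h, :], subnets, w_r) for h, (subnets, w_r) in enumerate(heads)]
    return np.concatenate(outs, axis=-1) @ w_o

rng = np.random.default_rng(0)
L, d, H, E, d_e = 5, 8, 2, 3, 8  # toy sizes
d_h = d // H
heads = []
for _ in range(H):
    subnets = [(rng.standard_normal((d_h, d_e)), rng.standard_normal((d_h, d_e)),
                rng.standard_normal((d_e, d_h))) for _ in range(E)]
    heads.append((subnets, rng.standard_normal((d_h, E))))
x = rng.standard_normal((L, d))
y = gated_mh_ffn(x, rng.standard_normal((d, d)), heads, rng.standard_normal((d, d)))
print(y.shape)  # (5, 8)
```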
4. Theoretical and Practical Resource Analysis
FlashMHF keeps the overall compute complexity of the underlying FFN transformation but substantially reduces peak memory requirements. For batch size $b$ and sequence length $L$:
- Standard SwiGLU FFN: peak activation memory is $O(b L d_{ff})$, due to storing the intermediates $\mathrm{SiLU}(X W_{\text{gate}})$ and $X W_{\text{up}}$ of shape $b \times L \times d_{ff}$.
- FlashMHF: peak memory reduces to $O(b L d)$; no activation of size $b \times L \times d_{ff}$ is stored off-chip, only the current block-level inputs and outputs plus the gating terms.
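A back-of-the-envelope count of the dominant terms in the two bullets above, assuming bf16 (2 bytes per element) and a hypothetical single-sequence inference setting. It counts only the $b \times L \times d_{ff}$ versus $b \times L \times d$ tensors and is not meant to reproduce the measured 3–5× reduction.

```python
def activation_bytes(b, L, width, n_tensors, bytes_per_elem=2):
    # Bytes for n_tensors activations of shape (b, L, width); bf16 by default.
    return b * L * width * n_tensors * bytes_per_elem

b, L, d = 1, 16384, 2048  # hypothetical inference setting
d_ff = 8 * d // 3         # SwiGLU intermediate width (rounded down)

# SwiGLU: the gate and up intermediates of shape (b, L, d_ff) are written to HBM.
swiglu = activation_bytes(b, L, d_ff, n_tensors=2)
# Fused kernel: only the (b, L, d) input and output touch HBM (small gating terms ignored).
fused = activation_bytes(b, L, d, n_tensors=2)

print(f"SwiGLU intermediates: {swiglu / 2**20:.0f} MiB, fused in/out: {fused / 2**20:.0f} MiB")
```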
Empirically, on H100 (Hopper) GPUs, FlashMHF reduces peak HBM usage by 3–5× across a broad range of sequence lengths (192–16K tokens) and achieves up to 1.08× inference speedup on long contexts (average ~1.05×). The speedup is modest relative to FlashAttention because cuBLAS-based FFN implementations are already efficient and FFN output bandwidth remains a constraint, but the reduction in activation memory is substantial [(Zhang et al., 7 Dec 2025), Figs. 7a, 7b].
5. Empirical Evaluation
FlashMHF was evaluated on language models at three scales: ~128M, ~370M, and ~1.3B parameters. All models were trained on The Pile with context length 4096, batch size 64, and GPT-NeoX tokenization, against LLaMA-style baselines with standard attention and SwiGLU FFNs for direct comparison. Downstream evaluation used six common benchmarks (HellaSwag, Social IQA, Physical IQA, OpenBookQA, WinoGrande, RACE).
Key results include:
- Perplexity (PG19): At 370M, FlashMHF achieves 3.014 vs. 3.030 for the baseline; at 1.3B, 2.793 vs. 2.843. The reduction appears at both scales and grows with scale (0.016 at 370M, 0.050 at 1.3B).
- Downstream accuracy: FlashMHF-128hdim (head dimension 128) attains a 40.48% average at 370M and 43.35% at 1.3B, compared with baselines of 39.92% and 41.75%, respectively.
- Efficiency: Consistent memory reduction (3–5×) and up to 8% inference speedup on long-context, deep models.
- Robustness: Naïve MH-FFN variants fail to scale past 128M model size, and ablations (e.g., ParamKV) underperform, underscoring the necessity of both relational structure and parallel sub-network design [(Zhang et al., 7 Dec 2025), Table 2, Table 3, Fig. 6].
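The gaps quoted in the results list can be recomputed directly from the reported values. The dictionary below only restates those numbers and introduces no new data.

```python
# Values as reported above: (baseline, FlashMHF).
reported = {
    "370M": {"pg19": (3.030, 3.014), "avg_acc": (39.92, 40.48)},
    "1.3B": {"pg19": (2.843, 2.793), "avg_acc": (41.75, 43.35)},
}

def deltas(scale):
    # Positive numbers favor FlashMHF (lower PG19 value, higher average accuracy).
    (base_pg, flash_pg) = reported[scale]["pg19"]
    (base_acc, flash_acc) = reported[scale]["avg_acc"]
    return base_pg - flash_pg, flash_acc - base_acc

for scale in reported:
    pg, acc = deltas(scale)
    print(f"{scale}: PG19 -{pg:.3f}, average accuracy +{acc:.2f} points")
```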
6. Design Considerations, Limitations, and Extensions
FlashMHF introduces kernel and implementation complexity beyond that of standard cuBLAS FFN calls, particularly for Triton kernels targeting Hopper GPUs. For small models or short sequences, the overhead may outweigh the gains. The head dimension $d_h$ is a crucial hyperparameter: a small $d_h$ risks under-capacity, while a large $d_h$ reduces the head count and diminishes representational diversity.
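The head-dimension trade-off can be made concrete with a hypothetical helper that, for a fixed model width, reports the head count $H = d / d_h$ and a sub-network width $d_e \approx (8/3)\, d_h$. The width 2048 and the candidate head dimensions are illustrative, and $d_e$ is simply rounded here.

```python
def head_config(d, d_h, ratio=8 / 3):
    # Hypothetical helper: head count and per-sub-network width for a given head dimension.
    assert d % d_h == 0, "d must be divisible by d_h"
    H = d // d_h
    d_e = round(ratio * d_h)  # rounded for illustration
    return H, d_e

for d_h in (64, 128, 256):
    print(d_h, head_config(2048, d_h))
```

Smaller heads give more of them, each with a narrower sub-network. Larger heads give fewer, wider ones.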
Potential extensions include alternative gating functions (such as sparse top-$k$ MoE routing), fusing upstream operations (e.g., LayerNorm + FlashMHF) into a single kernel, and adapting the architecture to non-GPU hardware (TPU, SW). Jointly optimizing the head dimension $d_h$, the number of sub-networks $E$, and the sub-network width $d_e$ is also motivated for different scaling regimes.
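To contrast the dense sigmoid gating with a sparse alternative, the sketch below implements a common top-$k$ MoE-style router (keep the $k$ largest logits per token, softmax over them, zero the rest). This is a swapped-in, generic technique shown for illustration, not a method from the FlashMHF paper.

```python
import numpy as np

def sigmoid_gates(logits):
    # Dense gating: every sub-network gets an independent weight in (0, 1).
    return 1.0 / (1.0 + np.exp(-logits))

def topk_gates(logits, k):
    # Sparse gating: keep the k largest logits per token, softmax over them, zero elsewhere.
    idx = np.argpartition(logits, -k, axis=-1)[..., -k:]
    mask = np.zeros_like(logits, dtype=bool)
    np.put_along_axis(mask, idx, True, axis=-1)
    masked = np.where(mask, logits, -np.inf)
    masked -= masked.max(axis=-1, keepdims=True)  # numerical stability
    w = np.exp(masked)  # exp(-inf) == 0 for unselected sub-networks
    return w / w.sum(axis=-1, keepdims=True)

logits = np.random.default_rng(0).standard_normal((3, 4))  # (tokens, sub-networks)
print(sigmoid_gates(logits).round(2))
print(topk_gates(logits, k=2).round(2))
```

With top-$k$ gating, only $k$ of the $E$ sub-networks per token would need to be evaluated, at the cost of a data-dependent routing step inside the kernel.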
A plausible implication is that multi-head partitioning is emerging as a broadly superior architectural principle in both attention and feed-forward pathways of large-scale Transformers, as it enables improved parameter and computational efficiency, more stable scaling, and substantial reduction in memory footprints, all while preserving end-to-end training and inference semantics (Zhang et al., 7 Dec 2025).
7. Summary
Flash Multi-Head FFN (FlashMHF) replaces conventional SwiGLU FFNs in Transformer blocks with a multi-head, dynamically gated, blockwise computation framework. Its design maintains the optimal intermediate-to-head dimensional relationship via parallel sub-networks, fuses computation in on-chip SRAM to avoid writing intermediate tensors off-chip, and empirically improves both language modeling and downstream performance. FlashMHF achieves up to 1.08× inference acceleration and a 3–5× activation memory reduction, making it a scalable and efficient alternative for modern Transformer architectures (Zhang et al., 7 Dec 2025).