
FlashMHF: Multi-Head FFN in Transformers

Updated 11 June 2026
  • FlashMHF is a multi-head feed-forward network that integrates blockwise fused SRAM computation for improved scalability.
  • It employs dynamically gated parallel sub-networks to maintain optimal intermediate-to-head ratios across model scales.
  • Empirical evaluations show lower perplexity, higher downstream accuracy, a reduced memory footprint, and a modest inference speedup.

Multi-Head Feed-Forward Networks (MH-FFN), and in particular the Flash Multi-Head FFN (FlashMHF), are an evolution of the feed-forward architecture within Transformer models. FlashMHF is designed as a direct replacement for conventional point-wise SwiGLU FFNs in Transformer blocks, introducing a multi-head structure and blockwise fused computation motivated by the structural symmetry with multi-head self-attention. It has two central innovations. The first is an I/O-aware kernel that computes outputs online in on-chip SRAM, thereby reducing memory pressure. The second is the use of dynamically gated parallel sub-networks to preserve the optimal intermediate-to-head dimensionality ratio as models scale, which addresses the scalability and expressivity bottlenecks encountered by naïve multi-head FFN designs. FlashMHF demonstrates improvements in perplexity and downstream task accuracy, a 3–5× reduction in peak memory, and an inference speedup of up to 1.08×, while maintaining architectural compatibility with standard Transformer compositions (Zhang et al., 7 Dec 2025).

1. Integration into Transformer Architectures

FlashMHF is implemented as a drop-in replacement for traditional FFNs in Transformer blocks. The baseline structure,

$$\mathrm{LayerNorm} \rightarrow \text{Multi-Head Self-Attention} \rightarrow \mathrm{Add}\,\&\,\mathrm{Norm} \rightarrow \text{FFN} \rightarrow \mathrm{Add}\,\&\,\mathrm{Norm},$$

originally computes the FFN as

$$\mathrm{FFN}(X) = \left(X W_{\text{up}} \odot \mathrm{SiLU}(X W_{\text{gate}})\right) W_{\text{down}},$$

with input $X \in \mathbb{R}^{B \times N \times d_{\text{model}}}$. FlashMHF replaces this component, applying $O = \mathrm{FlashMHF}(X)$ with the same input-output dimensionality and identical placement of normalization and residual connections. This ensures seamless compatibility at all integration points.
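As a point of reference, the baseline SwiGLU formula above can be sketched in dependency-free Python. This is a minimal illustration rather than the authors' implementation: the helper names (`matmul`, `silu`, `swiglu_ffn`, `random_matrix`) and the tiny sizes are assumptions chosen for readability, and plain lists stand in for tensors with the batch dimension flattened into the token axis.

```python
import math
import random


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def silu(x):
    """SiLU activation, x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


def swiglu_ffn(x, w_up, w_gate, w_down):
    """Compute (X W_up * SiLU(X W_gate)) W_down for X of shape (tokens, d_model)."""
    up = matmul(x, w_up)
    gate = matmul(x, w_gate)
    hidden = [[u * silu(g) for u, g in zip(u_row, g_row)]
              for u_row, g_row in zip(up, gate)]
    return matmul(hidden, w_down)


def random_matrix(rows, cols, rng):
    """Small random matrix; the value range is arbitrary."""
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
n_tokens, d_model, d_ff = 4, 8, 16  # illustrative sizes, not taken from the paper
x = random_matrix(n_tokens, d_model, rng)
w_up = random_matrix(d_model, d_ff, rng)
w_gate = random_matrix(d_model, d_ff, rng)
w_down = random_matrix(d_ff, d_model, rng)

out = swiglu_ffn(x, w_up, w_gate, w_down)
print(len(out), len(out[0]))  # the output keeps the (tokens, d_model) shape
```

The `hidden` list is the full intermediate activation of width `d_ff`; it is this tensor that the fused kernel described in Section 2 avoids writing to off-chip memory.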

The multi-head structure mirrors multi-head attention, assigning $H$ heads each with head dimension $d_h$ such that $d_{\text{model}} = H \cdot d_h$. Where in attention heads receive per-head queries $Q_h$, keys $K_h$, and values $V_h$, a naïve MH-FFN would analogously compute per-head outputs

$$O_h = \left(X_h W_{\text{up}}^{(h)} \odot \mathrm{SiLU}(X_h W_{\text{gate}}^{(h)})\right) W_{\text{down}}^{(h)}, \quad h = 1, \dots, H,$$

aggregating head outputs via concatenation and output projection. However, without architectural constraints, this approach rapidly incurs prohibitive memory and parameter inefficiencies as model width scales [(Zhang et al., 7 Dec 2025), Sect. 3.1].
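The naïve multi-head variant described above can be sketched as follows. This is an illustrative reading of the prose, not the paper's code: how each head obtains its input (here a plain feature slice of width `d_h`), the per-head intermediate width `d_inter`, and all helper names are assumptions.

```python
import math
import random


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def silu(x):
    """SiLU activation, x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


def swiglu(x, w_up, w_gate, w_down):
    """One SwiGLU transform: (X W_up * SiLU(X W_gate)) W_down."""
    up, gate = matmul(x, w_up), matmul(x, w_gate)
    hidden = [[u * silu(g) for u, g in zip(ur, gr)] for ur, gr in zip(up, gate)]
    return matmul(hidden, w_down)


def rand(rows, cols, rng):
    """Small random matrix; the value range is arbitrary."""
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def naive_mh_ffn(x, head_params, w_out, d_h):
    """Slice features into heads, run one SwiGLU per head, concatenate, project."""
    per_head = []
    for h, (w_up, w_gate, w_down) in enumerate(head_params):
        x_h = [row[h * d_h:(h + 1) * d_h] for row in x]  # (tokens, d_h) slice
        per_head.append(swiglu(x_h, w_up, w_gate, w_down))
    # Concatenate head outputs along the feature axis, then apply the output projection.
    concat = [sum((head[t] for head in per_head), []) for t in range(len(x))]
    return matmul(concat, w_out)


rng = random.Random(0)
n_tokens, n_heads, d_h, d_inter = 3, 2, 4, 6  # hypothetical sizes
d_model = n_heads * d_h
x = rand(n_tokens, d_model, rng)
head_params = [
    (rand(d_h, d_inter, rng), rand(d_h, d_inter, rng), rand(d_inter, d_h, rng))
    for _ in range(n_heads)
]
out = naive_mh_ffn(x, head_params, rand(d_model, d_model, rng), d_h)

# Every head materializes its own (tokens, d_inter) hidden activation.
hidden_elements = n_tokens * n_heads * d_inter
print(len(out), len(out[0]), hidden_elements)
```

In this sketch the number of hidden elements grows with both the head count and the per-head intermediate width, which is one way to see the memory pressure that the section attributes to unconstrained multi-head FFN designs.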

2. Fused SRAM-Oriented Kernel Design

The FlashMHF core computation is dominated by a blockwise kernel intentionally structured to run within on-chip SRAM, minimizing reads and writes of large intermediate tensors to high-bandwidth memory (HBM). Consider one head and one of its sub-networks with input $X_h \in \mathbb{R}^{N \times d_h}$, learned parameters $W_{\text{up}}, W_{\text{gate}} \in \mathbb{R}^{d_h \times d_e}$ and $W_{\text{down}} \in \mathbb{R}^{d_e \times d_h}$ (with $d_e$ the sub-network width), and chunk $d_e$ into $M$ blocks of size $b$. The operational sequence is:

$$O_h = \sum_{j=1}^{M} \left(X_h W_{\text{up}}^{(j)} \odot \mathrm{SiLU}(X_h W_{\text{gate}}^{(j)})\right) W_{\text{down}}^{(j)},$$

where $W_{\text{up}}^{(j)}$ and $W_{\text{gate}}^{(j)}$ denote the $j$-th column blocks of width $b$ and $W_{\text{down}}^{(j)}$ the matching row block. This structure (see Eq. 9 and Algorithm A.1 in (Zhang et al., 7 Dec 2025)) ensures no full $N \times d_e$ intermediate activation is ever materialized off-chip. Rather, partial results are accumulated on-chip, with only the final output $O_h$ written to slower memory, a methodology analogous to FlashAttention's streaming approach. The forward pass pseudocode manages batch, head, and block dimensions as described, with gating weights dynamically introduced per block.
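The blockwise idea can be checked numerically with a short sketch: because the down-projection sums over the intermediate index, processing that index in chunks and accumulating the partial outputs gives the same result as the unfused computation. The code below is plain Python that only mimics the data flow; it is not a GPU kernel, it leaves out the per-block gating weights, and the function names and sizes are assumptions.

```python
import math
import random


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def silu(x):
    """SiLU activation, x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


def full_ffn(x, w_up, w_gate, w_down):
    """Reference path: builds the whole (tokens, d_e) hidden activation at once."""
    up, gate = matmul(x, w_up), matmul(x, w_gate)
    hidden = [[u * silu(g) for u, g in zip(ur, gr)] for ur, gr in zip(up, gate)]
    return matmul(hidden, w_down)


def blockwise_ffn(x, w_up, w_gate, w_down, block):
    """Streaming path: only a (tokens, block) hidden slice exists at any time."""
    d_e, d_out = len(w_down), len(w_down[0])
    acc = [[0.0] * d_out for _ in x]  # running output, stands in for on-chip state
    for start in range(0, d_e, block):
        stop = min(start + block, d_e)
        up_blk = [row[start:stop] for row in w_up]      # column block of W_up
        gate_blk = [row[start:stop] for row in w_gate]  # column block of W_gate
        down_blk = w_down[start:stop]                   # matching row block of W_down
        partial = full_ffn(x, up_blk, gate_blk, down_blk)
        acc = [[a + p for a, p in zip(a_row, p_row)]
               for a_row, p_row in zip(acc, partial)]
    return acc


def rand(rows, cols, rng):
    """Small random matrix; the value range is arbitrary."""
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
n_tokens, d_h, d_e, block = 3, 4, 10, 4  # hypothetical sizes; the last block is partial
x_h = rand(n_tokens, d_h, rng)
w_up, w_gate, w_down = rand(d_h, d_e, rng), rand(d_h, d_e, rng), rand(d_e, d_h, rng)

reference = full_ffn(x_h, w_up, w_gate, w_down)
streamed = blockwise_ffn(x_h, w_up, w_gate, w_down, block)
max_diff = max(abs(r - s) for rr, sr in zip(reference, streamed) for r, s in zip(rr, sr))
print(max_diff < 1e-9)
```

The two paths differ only in floating-point summation order, so the comparison uses a small tolerance rather than exact equality.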

3. Dynamically Weighted Parallel Sub-Networks

To address scaling pathologies inherent in naïve multi-head FFN splits, such as an exploding ratio $d_{\text{ff}} / d_h$ of intermediate width to head dimension as $d_{\text{model}}$ grows, FlashMHF partitions the intermediate FFN dimension across $E$ parallel sub-networks within each head, giving $d_{\text{ff}} = E \cdot d_e$ with sub-network width $d_e = \alpha\, d_h$ (the expansion factor $\alpha$ chosen as with SwiGLU). This enforces an optimal and consistent $d_e / d_h$ ratio independently of head count or model scale.
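A small arithmetic sketch makes the ratio argument concrete. All numbers are hypothetical: the head dimension of 128 echoes the "128hdim" variant mentioned in the evaluation, while the rule `d_ff = 4 * d_model`, the expansion factor `alpha = 4`, and the function names are assumptions made only for illustration.

```python
def naive_ratio(d_ff, d_h):
    """Intermediate-to-head ratio when each head spans the full intermediate width."""
    return d_ff / d_h


def split_into_subnetworks(d_ff, d_h, alpha):
    """Return (count, width) with width = alpha * d_h and count * width == d_ff."""
    d_e = alpha * d_h
    if d_ff % d_e != 0:
        raise ValueError("d_ff must be a multiple of alpha * d_h")
    return d_ff // d_e, d_e


d_h, alpha = 128, 4  # hypothetical head dimension and expansion factor
for d_model in (1024, 2048, 4096):
    d_ff = 4 * d_model  # hypothetical scaling rule for the intermediate width
    n_sub, d_e = split_into_subnetworks(d_ff, d_h, alpha)
    # The naive ratio grows with model width; the per-sub-network ratio stays fixed.
    print(d_model, naive_ratio(d_ff, d_h), n_sub, d_e / d_h)
```

Under these assumptions the naive ratio doubles each time the model width doubles, whereas the per-sub-network ratio stays at `alpha` and only the number of sub-networks changes.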

Each head $h$, with input $X_h$ of width $d_h$ and $E$ parallel sub-networks, learns a small gating matrix $W_{\text{sel}}^{(h)} \in \mathbb{R}^{d_h \times E}$, generating per-token, per-head sub-network selection weights:

$$G_h = \sigma\!\left(X_h W_{\text{sel}}^{(h)}\right) \in \mathbb{R}^{N \times E}$$

($\sigma$ is the elementwise sigmoid, Eq. 8). For token $t$, head $h$, and sub-network $e$, each sub-network computes a gated SwiGLU-like output which is then mixed according to $G_h[t, e]$, and finally all per-head outputs are concatenated and projected back to $d_{\text{model}}$ as standard.
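The gating step for a single head can be sketched as a sigmoid-weighted sum over parallel SwiGLU sub-networks. This is one reading of the description above under stated assumptions: the gate is computed from the per-head input, the sigmoid outputs are used without further normalization, and every name (`gated_head`, `w_sel`, `subnets`) and size is hypothetical.

```python
import math
import random


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def silu(x):
    """SiLU activation, x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


def sigmoid(x):
    """Logistic sigmoid."""
    return 1.0 / (1.0 + math.exp(-x))


def swiglu(x, w_up, w_gate, w_down):
    """One SwiGLU sub-network: (X W_up * SiLU(X W_gate)) W_down."""
    up, gate = matmul(x, w_up), matmul(x, w_gate)
    hidden = [[u * silu(g) for u, g in zip(ur, gr)] for ur, gr in zip(up, gate)]
    return matmul(hidden, w_down)


def gated_head(x_h, subnets, w_sel):
    """Mix parallel SwiGLU sub-networks with per-token sigmoid gates."""
    gates = [[sigmoid(s) for s in row] for row in matmul(x_h, w_sel)]  # (tokens, E)
    d_out = len(subnets[0][2][0])
    out = [[0.0] * d_out for _ in x_h]
    for e, (w_up, w_gate, w_down) in enumerate(subnets):
        y = swiglu(x_h, w_up, w_gate, w_down)  # (tokens, d_h)
        for t, y_row in enumerate(y):
            out[t] = [o + gates[t][e] * v for o, v in zip(out[t], y_row)]
    return out, gates


def rand(rows, cols, rng):
    """Small random matrix; the value range is arbitrary."""
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
n_tokens, d_h, d_e, n_subnets = 3, 4, 6, 2  # hypothetical sizes
x_h = rand(n_tokens, d_h, rng)
subnets = [(rand(d_h, d_e, rng), rand(d_h, d_e, rng), rand(d_e, d_h, rng))
           for _ in range(n_subnets)]
w_sel = rand(d_h, n_subnets, rng)

out, gates = gated_head(x_h, subnets, w_sel)
print(len(out), len(out[0]), len(gates[0]))
```

Because each gate is an independent sigmoid, the weights for a token lie strictly between 0 and 1 but need not sum to one, which distinguishes this kind of mixing from a softmax router.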

4. Theoretical and Practical Resource Analysis

FlashMHF maintains the overall compute complexity of the underlying FFN transformation but dramatically alters peak memory requirements. For batch size $B$, sequence length $N$, and intermediate width $d_{\text{ff}}$:

  • Standard SwiGLU FFN: peak activation memory is $O(B N d_{\text{ff}})$ due to storing the intermediate activations $X W_{\text{up}}$ and $X W_{\text{gate}}$.
  • FlashMHF: peak memory reduces to $O(B N d_{\text{model}})$; no activation of size $B \times N \times d_{\text{ff}}$ is stored off-chip, only current block-level inputs and outputs plus gating terms.
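A back-of-the-envelope count shows how the two cases differ in scale. The sketch compares one intermediate-width activation with one model-width activation; the sizes, the 2-byte element assumption, and the function name are hypothetical, and the calculation does not attempt to reproduce the measured reduction reported below.

```python
def activation_bytes(batch, seq_len, width, bytes_per_element=2):
    """Bytes needed to hold one (batch, seq_len, width) activation tensor."""
    return batch * seq_len * width * bytes_per_element


batch, seq_len = 1, 4096      # hypothetical batch size and sequence length
d_model, d_ff = 2048, 5632    # hypothetical model and intermediate widths

hidden = activation_bytes(batch, seq_len, d_ff)     # unfused intermediate tensor
output = activation_bytes(batch, seq_len, d_model)  # tensor the fused path writes out

print(hidden / 2**20, output / 2**20, hidden / output)  # MiB, MiB, ratio
```

An unfused SwiGLU typically holds more than one intermediate-width tensor at a time (the up and gate projections and their product), so the practical gap can be larger than this single-tensor ratio suggests.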

Empirically, on H100 (Hopper) GPUs, FlashMHF reduces peak HBM usage by 3–5× across a broad range of sequence lengths (192 to 16K tokens) and achieves up to 1.08× inference speedup on long contexts (about 1.05× on average). The speedup is modest relative to FlashAttention's because cuBLAS-based FFN implementations are already efficient and FFN output bandwidth remains a constraint, but the reduction in activation memory is substantial [(Zhang et al., 7 Dec 2025), Figs. 7a, 7b].

5. Empirical Evaluation

FlashMHF was evaluated across LLMs of three scales: ~128M, ~370M, and ~1.3B parameters. All models were trained on The Pile with context length 4096, batch size 64, and GPT-NeoX tokenization, using baseline LLaMA-style attention and SwiGLU FFNs for direct comparison. Downstream evaluations employed six common benchmarks (HellaSwag, Social IQA, Physical IQA, OpenBookQA, WinoGrande, RACE).

Key results include:

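  • Perplexity (PG19): At 370M, FlashMHF achieves 3.014 vs. 3.030 for the baseline; at 1.3B, 2.793 vs. 2.843. These values appear to be on a log scale (cross-entropy in nats): the 1.3B gap of 0.050 corresponds to roughly 0.84 perplexity points after exponentiation ($e^{2.843} - e^{2.793} \approx 17.17 - 16.33$), in line with the quoted ~0.85 ppl reduction. FlashMHF is ahead of the baseline at both scales.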
  • Downstream accuracy: FlashMHF-128hdim attains 40.48% average (370M) and 43.35% (1.3B) compared to baselines of 39.92% and 41.75%, respectively.
  • Efficiency: Consistent memory reduction (3–5×) and up to 8% inference speedup on long-context, deep models.
  • Robustness: Naïve MH-FFN variants fail to scale past 128M model size, and ablations (e.g., ParamKV) underperform, underscoring the necessity of both relational structure and parallel sub-network design [(Zhang et al., 7 Dec 2025), Table 2, Table 3, Fig. 6].
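The relationship between the reported PG19 values and the quoted perplexity reduction can be checked with a few lines of arithmetic. The sketch assumes the reported numbers are natural-log cross-entropy values, which is an interpretation rather than something the source states; the function name is hypothetical.

```python
import math


def perplexity_gap(loss_baseline, loss_model):
    """Difference in perplexity implied by two natural-log loss values."""
    return math.exp(loss_baseline) - math.exp(loss_model)


gap_370m = perplexity_gap(3.030, 3.014)  # 370M: baseline vs. FlashMHF
gap_1p3b = perplexity_gap(2.843, 2.793)  # 1.3B: baseline vs. FlashMHF

print(round(gap_370m, 2), round(gap_1p3b, 2))
```

Under that assumption the 1.3B gap comes out close to the ~0.85 figure quoted in the list, while the 370M gap is noticeably smaller.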

6. Design Considerations, Limitations, and Extensions

FlashMHF introduces kernel and implementation complexity exceeding that of standard cuBLAS FFN calls, particularly for Triton implementations and Hopper architectures. For small models or short sequence lengths, the overhead may outweigh the gains. The head dimension $d_h$ is a crucial hyperparameter: a small $d_h$ risks under-capacity, while a large $d_h$ reduces head count and diminishes representational diversity.

Potential extensions include exploring alternative gating functions (such as sparse top-$k$ MoE routing), fusing upstream operations (e.g., LayerNorm+FlashMHF) into a single kernel, and adapting the architecture to non-GPU hardware (TPU, SW). Joint optimization of the head dimension, the number of sub-networks, and the sub-network width is motivated for different scaling regimes.

A plausible implication is that multi-head partitioning is emerging as a broadly superior architectural principle in both attention and feed-forward pathways of large-scale Transformers, as it enables improved parameter and computational efficiency, more stable scaling, and substantial reduction in memory footprints, all while preserving end-to-end training and inference semantics (Zhang et al., 7 Dec 2025).

7. Summary

Flash Multi-Head FFN (FlashMHF) replaces conventional SwiGLU FFNs in Transformer blocks with a multi-head, dynamically gated, blockwise computation framework. Its design maintains the optimal intermediate-to-head dimensional relationship via parallel sub-networks, fuses computation to SRAM to eliminate off-chip intermediate tensors, and empirically improves both language modeling and downstream performance metrics. FlashMHF achieves up to 1.08× inference acceleration and 3–5× activation memory reduction, representing a scalable and efficient alternative for modern Transformer architectures (Zhang et al., 7 Dec 2025).
