
FlashMHF: Multi-Head FFN for Transformers

Updated 1 March 2026
  • FlashMHF is a multi-head feed-forward network that replaces standard FFNs in Transformers using dynamic routing and fused on-chip compute kernels.
  • It leverages dynamically weighted parallel sub-networks to balance intermediate and head dimensions, ensuring scalable expressivity and reduced memory usage.
  • Empirical evaluations show FlashMHF achieves lower perplexity, improved accuracy, faster inference, and greater hardware efficiency compared to standard SwiGLU.

Flash Multi-Head Feed-Forward Network (FlashMHF) is an architectural replacement for the position-wise feed-forward network (FFN) used in Transformer models. Motivated by the structural similarity between single-head attention and FFNs, and inspired by the benefits of multi-head mechanisms in increasing representational expressivity, FlashMHF introduces a multi-head design for FFNs. Two central innovations set FlashMHF apart: (1) an I/O-aware, fused compute kernel that computes outputs online in on-chip SRAM, and (2) a design utilizing dynamically weighted parallel sub-networks within each head to maintain a balanced ratio between intermediate and head dimensions as models scale. FlashMHF achieves improved perplexity, accuracy, and hardware efficiency while reducing peak activation memory and accelerating inference in practical settings (Zhang et al., 7 Dec 2025).

1. Motivation and Background

In the Transformer architecture, each layer includes an FFN after multi-head self-attention, usually expressed as

$$\mathrm{FFN}(X) = W_2\,\phi(W_1 X + b_1) + b_2,$$

where $X \in \mathbb{R}^{L \times d_\mathrm{model}}$, $\phi$ is a nonlinearity, $W_1 \in \mathbb{R}^{d_f \times d_\mathrm{model}}$, and $W_2 \in \mathbb{R}^{d_\mathrm{model} \times d_f}$. Geva et al. (2020) observed that this computation can be reinterpreted as attention over parameters. The analogy motivates extending feed-forward layers with a multi-head mechanism, with the expectation that decomposing the computation into parallel subspaces yields richer modeling power.
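As a concrete reference point, the standard position-wise FFN above can be sketched in NumPy. The toy sizes and the choice of ReLU for $\phi$ are illustrative only, not values from the paper:

```python
import numpy as np

def ffn(X, W1, b1, W2, b2):
    """Position-wise FFN: every token is transformed independently."""
    H = np.maximum(X @ W1.T + b1, 0.0)   # phi = ReLU here, for illustration
    return H @ W2.T + b2

rng = np.random.default_rng(0)
L, d_model, d_f = 4, 8, 32               # toy sizes
X  = rng.standard_normal((L, d_model))
W1 = rng.standard_normal((d_f, d_model))  # W1 in R^{d_f x d_model}
W2 = rng.standard_normal((d_model, d_f))  # W2 in R^{d_model x d_f}
Y  = ffn(X, W1, np.zeros(d_f), W2, np.zeros(d_model))
assert Y.shape == (L, d_model)
```

Note that the $(L, d_f)$ intermediate `H` is fully materialized here; this is exactly the activation that FlashMHF's fused kernel later avoids storing.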

However, a naïve multi-head FFN (MH-FFN) approach that replicates SwiGLU-style intermediate expansions headwise runs into two critical scaling issues:

  • Memory overhead scales with the product $H \times (L \cdot d_f)$, where $H$ is the number of heads and $d_f$ the intermediate size.
  • As model scale increases and $d_f$ grows (following scaling laws), the per-head dimension $d_h$ remains fixed, leading to an imbalanced and inefficient ratio $d_f / d_h$ that degrades both scalability and expressivity (Zhang et al., 7 Dec 2025).

2. Formal Construction and Mechanism

2.1 Standard FFN and Naïve MH-FFN

The classical FFN applies the transformation per token: $\mathrm{FFN}(x) = W_2\,\phi(W_1 x + b_1) + b_2$. A naïve MH-FFN first projects the input into $H$ heads (each of dimension $d_h = d_\mathrm{model}/H$), applies an independently parameterized intermediate expansion (e.g., SwiGLU) to each, and finally concatenates the outputs. Formally:

  • Input projection: $Q = \mathrm{split}_H(X W_\mathrm{in}) \in \mathbb{R}^{L \times H \times d_h}$.
  • Each head uses independent SwiGLU parameters $K^h, U^h, V^h \in \mathbb{R}^{d_f \times d_h}$.
  • Head output: $S_{\ell,h} = \big( \mathrm{SiLU}(Q_{\ell,h} K^{h\top}) \odot (Q_{\ell,h} U^{h\top}) \big) V^h$.
  • Outputs are concatenated across the $H$ heads and projected.

While expressive, this design requires storing $H$ independent $L \times d_f$ activations and exacerbates the intermediate-to-head imbalance as models scale, leading to excessive memory usage and reduced functional capacity in each head.
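The naïve MH-FFN above can be sketched directly from the formulas. This is a minimal NumPy illustration with toy sizes, not the paper's implementation; note that each head materializes its own $(L, d_f)$ activation, which is the memory problem FlashMHF targets:

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def naive_mh_ffn(X, W_in, W_out, K, U, V):
    """Naive MH-FFN: per-head SwiGLU; stores H separate (L, d_f) activations."""
    L, d_model = X.shape
    H, d_f, d_h = K.shape
    Q = (X @ W_in).reshape(L, H, d_h)                     # split into H heads
    heads = []
    for h in range(H):
        A = silu(Q[:, h] @ K[h].T) * (Q[:, h] @ U[h].T)   # (L, d_f) per head
        heads.append(A @ V[h])                            # back to (L, d_h)
    S = np.concatenate(heads, axis=-1)                    # (L, d_model)
    return S @ W_out                                      # output projection

rng = np.random.default_rng(0)
L, d_model, H, d_f = 4, 8, 2, 16
d_h = d_model // H
X = rng.standard_normal((L, d_model))
W_in  = rng.standard_normal((d_model, d_model))
W_out = rng.standard_normal((d_model, d_model))
K, U, V = (rng.standard_normal((H, d_f, d_h)) for _ in range(3))
Y = naive_mh_ffn(X, W_in, W_out, K, U, V)
assert Y.shape == (L, d_model)
```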

2.2 FlashMHF: Balanced and Dynamic Multi-Head Design

FlashMHF restructures each head’s intermediate pathway into $E$ smaller, dynamically weighted sub-networks of width $d_e \approx (8/3)\,d_h$, the SwiGLU-optimal ratio. Specifically, the intermediate size is partitioned as $d_f = E \cdot d_e$. Each head’s query $Q_{\ell,h}$ feeds into:

  • A gating matrix $W^h \in \mathbb{R}^{d_h \times E}$ that computes per-token logits: $P^h = Q_{:,h} W^h$.
  • Gating probabilities are normalized to sum to one via sigmoid and scaling:

$$R^h_{\ell,e} = \frac{\sigma(P^h_{\ell,e})}{\sum_{e'} \sigma(P^h_{\ell,e'}) + \epsilon}$$

  • Each sub-network $e$ has parameters $K^h_e, U^h_e, V^h_e \in \mathbb{R}^{d_e \times d_h}$.
  • The final output for each head is a weighted sum over its sub-networks:

$$S_{\ell,h} = \sum_{e=1}^{E} R^h_{\ell,e} \cdot \big( \mathrm{SiLU}(Q_{\ell,h} K^{h\top}_e) \odot (Q_{\ell,h} U^{h\top}_e) \big) V^h_e$$

Outputs from all heads are concatenated as usual and projected out.

This restores the optimal ratio between intermediate and head dimensions regardless of $H$, enabling scalable expressivity and memory efficiency at all model scales (Zhang et al., 7 Dec 2025).
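The per-head computation — sigmoid-normalized gating over $E$ SwiGLU sub-networks, then a weighted sum — can be written as a short unfused reference sketch. Sizes are toy values and the function name `flashmhf_head` is a label of convenience, not from the paper:

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def flashmhf_head(Qh, Wg, K, U, V, eps=1e-6):
    """One FlashMHF head: E dynamically weighted SwiGLU sub-networks.
    Qh: (L, d_h); Wg: (d_h, E); K, U, V: (E, d_e, d_h)."""
    R = sigmoid(Qh @ Wg)                              # per-token gate activations
    R = R / (R.sum(axis=-1, keepdims=True) + eps)     # normalize to sum ~ 1
    out = np.zeros_like(Qh)
    for e in range(K.shape[0]):
        A = silu(Qh @ K[e].T) * (Qh @ U[e].T)         # (L, d_e) SwiGLU branch
        out += R[:, e:e+1] * (A @ V[e])               # weighted sum, (L, d_h)
    return out

rng = np.random.default_rng(0)
L, d_h, E = 4, 6, 3
d_e = (8 * d_h) // 3                                  # d_e ~ (8/3) d_h
Qh = rng.standard_normal((L, d_h))
Wg = rng.standard_normal((d_h, E))
K, U, V = (rng.standard_normal((E, d_e, d_h)) for _ in range(3))
S = flashmhf_head(Qh, Wg, K, U, V)
assert S.shape == (L, d_h)
```

Because all $E$ sub-networks contribute to every token, the routing stays fully differentiable, in contrast to hard top-k MoE routing.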

3. Fused Kernel and I/O-Aware Computation

FlashMHF’s kernel (“FlashFFN”) is architected to avoid the memory bottleneck of materializing the full $L \times d_f$ activation tensor. Analogous to FlashAttention, it exploits on-chip SRAM by streaming computation over the $d_f$ dimension in tiled blocks of size $b$, with all operations for a tile (and its associated $K_m, U_m, V_m$) computed online before proceeding to the next. The pseudo-code logic is:

  • For each block $m = 1, \ldots, M$ (where $M = d_e E / b$):
    • Load the tile’s $K_m, U_m, V_m \in \mathbb{R}^{b \times d_h}$,
    • Compute $C_K = Q K_m^\top$ and $C_U = Q U_m^\top$,
    • Apply SiLU gating: $A = \mathrm{SiLU}(C_K) \odot C_U$,
    • Apply the router weights $R$,
    • Accumulate $O_\mathrm{acc} \mathrel{+}= A\, V_m$.
  • Final output is written only once per head.

Each tile fits entirely in SRAM, and peak layerwise activation memory drops from $\mathcal{O}(L \cdot d_f)$ (SwiGLU) to $\mathcal{O}(L \cdot d_\mathrm{model})$, a practical reduction of 3–5$\times$.
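The loop above can be checked against the unfused formulation in plain NumPy. This sketch streams one sub-network per tile (i.e., $b = d_e$) and, of course, only models the accumulation order, not SRAM residency or kernel fusion; the point is that the streamed result is exactly the all-at-once result while keeping only one $(L, b)$ activation live:

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def fused_like_loop(Q, K, U, V, R):
    """Streaming sketch of the FlashFFN loop: one tile at a time; only an
    (L, b) activation is ever live, never the full (L, d_f) tensor."""
    O = np.zeros((Q.shape[0], V.shape[-1]))        # O_acc, written once at end
    for m in range(K.shape[0]):                    # for each block m = 1..M
        Ck = Q @ K[m].T                            # "load" K_m, compute Q K_m^T
        Cu = Q @ U[m].T                            #                 Q U_m^T
        A = silu(Ck) * Cu                          # SiLU gating, (L, b)
        O += R[:, m:m+1] * (A @ V[m])              # router weight + accumulate
    return O

rng = np.random.default_rng(0)
L, d_h, E, d_e = 4, 6, 3, 16
Q = rng.standard_normal((L, d_h))
K, U, V = (rng.standard_normal((E, d_e, d_h)) for _ in range(3))
R = rng.random((L, E))
R /= R.sum(axis=-1, keepdims=True)

# the streamed result matches the unfused all-at-once computation
unfused = sum(R[:, e:e+1] * ((silu(Q @ K[e].T) * (Q @ U[e].T)) @ V[e])
              for e in range(E))
assert np.allclose(fused_like_loop(Q, K, U, V, R), unfused)
```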

Hardware implementations leverage kernel fusion in Triton and Hopper/TK, incorporating asynchronous producer–consumer staging, warp-group specialization, and multi-stage buffering for maximal bandwidth utilization on NVIDIA H100 GPUs (Zhang et al., 7 Dec 2025).

4. Dynamic Routing via Parallel Sub-Networks

A distinctive feature of FlashMHF is the dynamic, per-head, per-token routing afforded by small gating networks. For each input token, the gating mechanism determines a weight over EE sub-networks, analogous to an internal soft MoE, but without the capacity collapse or static routing issues frequently present in such modules. All sub-networks remain active and contribute to each output, ensuring full differentiability and the ability to emphasize different “reasoning sub-paths” at a fine granularity.

By tying the sub-network width $d_e$ directly to $d_h$ via $d_e \approx (8/3)\,d_h$, FlashMHF permits scalable parallelization and expressivity. Empirical ablations demonstrate that this tying is critical: dense routing without the multi-head decomposition (i.e., $H = 1$, $E > 1$) underperforms, affirming the necessity of the combined multi-head and dynamic sub-network design (Zhang et al., 7 Dec 2025).

5. Implementation and Integration in Transformers

FlashMHF serves as a drop-in replacement for SwiGLU or standard FFN layers within the Transformer block. The data flow, including LayerNorm, residual connections, and attention mechanisms, remains unchanged upstream and downstream. The hardware memory pattern involves:

  • Each head’s $Q$ is read into SRAM tile by tile.
  • Sub-network parameters ($K$, $U$, $V$) are double-buffered and streamed from high-bandwidth memory.
  • Activation accumulation and storage are performed entirely in SRAM.

This strategy further reduces memory traffic by avoiding the allocation of large intermediate tensors, particularly beneficial for long context lengths or large model deployments on memory-constrained GPUs.
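To make the saving concrete, here is a back-of-the-envelope comparison of peak activation bytes; all sizes are hypothetical, chosen only to illustrate the $\mathcal{O}(L \cdot d_f)$ versus $\mathcal{O}(L \cdot d_\mathrm{model})$ scaling:

```python
# Hypothetical sizes for illustration only (fp16 = 2 bytes/element).
L, d_model = 8192, 2048
d_f = 4 * d_model                   # a typical SwiGLU-style intermediate width
bytes_swiglu = L * d_f * 2          # materializes the full (L, d_f) activation
bytes_flash  = L * d_model * 2      # streams tiles; only (L, d_model) persists
print(bytes_swiglu / bytes_flash)   # -> 4.0, within the reported 3-5x range
```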

6. Empirical Evaluation

Benchmarks span models ranging from 128M to 1.3B parameters, with primary evaluation on PG19 validation (60–100B token pretraining) and downstream tasks including HellaSwag, SIQA, PIQA, OBQA, Winogrande, and RACE. Key results include:

  • Consistently improved perplexity versus SwiGLU:
    • 128M: FlashMHF-128hdim achieves lower eval loss.
    • 370M: SwiGLU loss 3.030, FlashMHF ($d_h = 128$) loss 3.014; naïve MH-FFN fails to scale.
    • 1.3B: SwiGLU loss 2.843, FlashMHF ($d_h = 128$) loss 2.793 ($\Delta = -0.050$).
  • FlashMHF attains the highest average and per-task accuracy, especially at $d_h = 128$.
  • Inference on NVIDIA H100 is up to 1.08$\times$ faster than SwiGLU (mean speedup $\sim 1.05\times$); per-layer memory footprint shrinks by 3–5$\times$ across sequence lengths.
  • Ablations indicate $d_h = 128$ is optimal: $d_h = 64$ underfits and $d_h = 256$ shows diminishing returns as head diversity decreases.

7. Limitations and Trade-Offs

While FlashMHF delivers substantial memory and modest inference time improvements, several caveats are observed:

  • Kernel engineering complexity is significant, with performance hinging on careful tiling, warp group design, and latency-sensitive staging.
  • The latency advantage is less pronounced than FlashAttention due to already highly optimized baseline SwiGLU kernels.
  • Optimal $H$ and $d_h$ must be tuned per compute regime, balancing head diversity against routing and memory overhead.
  • The approach is most advantageous for settings involving long input contexts or large models deployed on memory-constrained GPUs.
  • For very small models or scenarios where kernel launch overhead dominates, established baselines such as SwiGLU may remain preferable (Zhang et al., 7 Dec 2025).

FlashMHF establishes multi-head, dynamically routed FFNs as an expressive and hardware-efficient principle for Transformer design, providing a scalable alternative for next-generation architectures.
