FlashMHF: Multi-Head FFN for Transformers
- FlashMHF is a multi-head feed-forward network that replaces standard FFNs in Transformers using dynamic routing and fused on-chip compute kernels.
- It leverages dynamically weighted parallel sub-networks to balance intermediate and head dimensions, ensuring scalable expressivity and reduced memory usage.
- Empirical evaluations show FlashMHF achieves lower perplexity, improved accuracy, faster inference, and greater hardware efficiency compared to standard SwiGLU.
Flash Multi-Head Feed-Forward Network (FlashMHF) is an architectural replacement for the position-wise feed-forward network (FFN) used in Transformer models. Motivated by the structural similarity between single-head attention and FFNs, and inspired by the benefits of multi-head mechanisms in increasing representational expressivity, FlashMHF introduces a multi-head design for FFNs. Two central innovations set FlashMHF apart: (1) an I/O-aware, fused compute kernel that computes outputs online in on-chip SRAM, and (2) a design utilizing dynamically weighted parallel sub-networks within each head to maintain a balanced ratio between intermediate and head dimensions as models scale. FlashMHF achieves improved perplexity, accuracy, and hardware efficiency while reducing peak activation memory and accelerating inference in practical settings (Zhang et al., 7 Dec 2025).
1. Motivation and Background
In the Transformer architecture, each layer includes an FFN after multi-head self-attention, usually expressed as

$$\mathrm{FFN}(x) = \sigma(x W_1)\, W_2,$$

where $x \in \mathbb{R}^{d_{\text{model}}}$ is a token representation, $\sigma$ is a nonlinearity, $W_1 \in \mathbb{R}^{d_{\text{model}} \times d_{ff}}$, and $W_2 \in \mathbb{R}^{d_{ff} \times d_{\text{model}}}$. Geva et al. (2020) observed that this computation can be reinterpreted as attention over parameters, with the rows of $W_1$ acting as keys and the rows of $W_2$ as values. The analogy to attention motivates extending feed-forward layers with a multi-head mechanism to enhance expressivity, with the expectation that decomposing the computation into parallel subspaces yields richer modeling power.
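As a concrete reference point, the SwiGLU variant of this FFN, which serves as the baseline throughout, can be sketched in NumPy. The dimensions and parameter names below are illustrative choices, not values from the paper:

```python
import numpy as np

def silu(x):
    # SiLU (swish): x * sigmoid(x)
    return x / (1.0 + np.exp(-x))

def swiglu_ffn(x, w_gate, w_up, w_down):
    # Gated FFN: elementwise product of a SiLU-gated branch and a linear
    # branch, followed by a down-projection back to the model dimension.
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down

# Illustrative sizes (not from the paper).
rng = np.random.default_rng(0)
d_model, d_ff = 16, 48
x = rng.standard_normal(d_model)
w_gate = 0.1 * rng.standard_normal((d_model, d_ff))
w_up = 0.1 * rng.standard_normal((d_model, d_ff))
w_down = 0.1 * rng.standard_normal((d_ff, d_model))
y = swiglu_ffn(x, w_gate, w_up, w_down)  # shape (d_model,)
```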
However, a naïve multi-head FFN (MH-FFN) approach that replicates SwiGLU-style intermediate expansions headwise runs into two critical scaling issues:
- Memory overhead scales with the product $H \cdot d_{ff}$, where $H$ is the number of heads and $d_{ff}$ the per-head intermediate size.
- As model scale increases and $d_{ff}$ grows (following scaling laws), the per-head dimension $d_h$ remains fixed, leading to an imbalanced and inefficient $d_{ff}/d_h$ ratio that degrades both scalability and expressivity (Zhang et al., 7 Dec 2025).
2. Formal Construction and Mechanism
2.1 Standard FFN and Naïve MH-FFN
The classical FFN applies the transformation $\mathrm{FFN}(x) = \sigma(x W_1)\, W_2$ per token. A naïve MH-FFN first projects the input to $H$ heads (each of dimension $d_h = d_{\text{model}}/H$), then applies independently parameterized intermediate expansions (e.g., SwiGLU) to each, and finally concatenates the outputs. Formally:
- Input projection: $q_i = x W_i^{Q} \in \mathbb{R}^{d_h}$ for each head $i = 1, \dots, H$.
- Each head uses independent SwiGLU parameters $W_i^{\text{gate}}, W_i^{\text{up}} \in \mathbb{R}^{d_h \times d_{ff}}$ and $W_i^{\text{down}} \in \mathbb{R}^{d_{ff} \times d_h}$.
- Head output: $o_i = \big(\mathrm{SiLU}(q_i W_i^{\text{gate}}) \odot (q_i W_i^{\text{up}})\big)\, W_i^{\text{down}}$.
- Outputs are concatenated across heads and projected: $y = \mathrm{concat}(o_1, \dots, o_H)\, W^{O}$.
While expressive, this design requires storing $H$ independent $d_{ff}$-sized activations per token and exacerbates the intermediate-to-head imbalance as models scale, leading to excessive memory usage and reduced functional capacity in each head.
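A minimal sketch of this naïve MH-FFN, assuming per-head SwiGLU expansions and a final output projection (all names and sizes below are illustrative, not from the paper):

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def naive_mh_ffn(x, w_q, heads, w_out):
    # heads: one (w_gate, w_up, w_down) tuple per head.
    # Note the cost: every head materializes its own d_ff-sized activation.
    h = len(heads)
    q = (x @ w_q).reshape(h, -1)  # split the projection into per-head queries
    outs = [(silu(q_i @ wg) * (q_i @ wu)) @ wd
            for q_i, (wg, wu, wd) in zip(q, heads)]
    return np.concatenate(outs) @ w_out

# Illustrative sizes (not from the paper).
rng = np.random.default_rng(0)
d_model, n_heads, d_h, d_ff = 16, 4, 4, 24
x = rng.standard_normal(d_model)
w_q = 0.1 * rng.standard_normal((d_model, n_heads * d_h))
heads = [tuple(0.1 * rng.standard_normal(s)
               for s in [(d_h, d_ff), (d_h, d_ff), (d_ff, d_h)])
         for _ in range(n_heads)]
w_out = 0.1 * rng.standard_normal((n_heads * d_h, d_model))
y = naive_mh_ffn(x, w_q, heads, w_out)  # shape (d_model,)
```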
2.2 FlashMHF: Balanced and Dynamic Multi-Head Design
FlashMHF restructures each head’s intermediate pathway into $n$ smaller, dynamically weighted sub-networks of width $d_s = \tfrac{8}{3} d_h$, the SwiGLU-optimal ratio. Specifically, the intermediate size is partitioned as $d_{ff} = n \cdot d_s$. Each head’s query $q_i$ feeds into:
- A gating matrix $W_i^{r} \in \mathbb{R}^{d_h \times n}$ that computes logits per token: $z_i = q_i W_i^{r} \in \mathbb{R}^{n}$.
- Gating probabilities normalized to sum to one via sigmoid and scaling: $g_{i,j} = \mathrm{sigmoid}(z_{i,j}) \big/ \sum_{k=1}^{n} \mathrm{sigmoid}(z_{i,k})$.
- Sub-network parameters $W_{i,j}^{\text{gate}}, W_{i,j}^{\text{up}} \in \mathbb{R}^{d_h \times d_s}$ and $W_{i,j}^{\text{down}} \in \mathbb{R}^{d_s \times d_h}$ for each sub-network $j$.
- A weighted sum over sub-networks as the final head output: $o_i = \sum_{j=1}^{n} g_{i,j} \big(\mathrm{SiLU}(q_i W_{i,j}^{\text{gate}}) \odot (q_i W_{i,j}^{\text{up}})\big)\, W_{i,j}^{\text{down}}$.
Outputs from all heads are concatenated as usual and projected out.
This restores the optimal ratio between intermediate and head dimensions regardless of model width, enabling scalable expressivity and memory efficiency at all model scales (Zhang et al., 7 Dec 2025).
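Under the definitions above, a single FlashMHF head can be sketched as follows. The sub-network width here is an arbitrary small value rather than an exact $\tfrac{8}{3} d_h$ multiple, and all names are illustrative, not from the paper:

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def flashmhf_head(q, w_router, subnets):
    # Sigmoid-normalized routing over parallel sub-networks: all sub-networks
    # stay active, and their outputs are mixed by the per-token gate weights.
    s = 1.0 / (1.0 + np.exp(-(q @ w_router)))  # sigmoid of router logits
    g = s / s.sum()                            # normalize to sum to one
    out = np.zeros_like(q)
    for g_j, (wg, wu, wd) in zip(g, subnets):
        out += g_j * ((silu(q @ wg) * (q @ wu)) @ wd)
    return out

# Illustrative sizes (not from the paper).
rng = np.random.default_rng(0)
d_h, n_sub, d_s = 6, 3, 16
q = rng.standard_normal(d_h)
w_router = rng.standard_normal((d_h, n_sub))
subnets = [tuple(0.1 * rng.standard_normal(s)
                 for s in [(d_h, d_s), (d_h, d_s), (d_s, d_h)])
           for _ in range(n_sub)]
o = flashmhf_head(q, w_router, subnets)  # shape (d_h,)
```

Because the gate is a normalized sigmoid rather than a hard top-k selection, every sub-network receives gradient on every token, which is what avoids the capacity-collapse issues of sparse MoE routing.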
3. Fused Kernel and I/O-Aware Computation
FlashMHF’s kernel (“FlashFFN”) is architected to avoid the memory bottleneck of materializing the full intermediate activation tensor. Analogous to FlashAttention, it exploits SRAM by streaming computation over the sub-network dimension in tiled blocks of width $d_s$, with all operations for a tile (associated with sub-network $j$) computed online before proceeding to the next. The pseudo-code logic is:
- For each sub-network block $j = 1, \dots, n$:
  - Load $W_{i,j}^{\text{gate}}, W_{i,j}^{\text{up}}, W_{i,j}^{\text{down}}$ for the tile,
  - Compute $a_j = q_i W_{i,j}^{\text{gate}}$ and $b_j = q_i W_{i,j}^{\text{up}}$,
  - Apply SiLU and gating: $h_j = \mathrm{SiLU}(a_j) \odot b_j$,
  - Apply the router weight $g_{i,j}$,
  - Accumulate $o_i \leftarrow o_i + g_{i,j}\, h_j W_{i,j}^{\text{down}}$.
- Final output is written only once per head.
Each tile fits entirely in SRAM, and peak activation memory per token is reduced from $O(d_{ff})$ (SwiGLU) to $O(d_s)$, a practical reduction of 3–5×.
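The correctness of this streaming strategy rests on the fact that the elementwise gate and the down-projection both decompose over intermediate columns. A minimal NumPy model (illustrative only; no actual SRAM management or routing) shows that tile-wise accumulation reproduces the unfused result:

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def swiglu_full(q, w_gate, w_up, w_down):
    # Reference: materializes the full d_ff-sized intermediate activation.
    return (silu(q @ w_gate) * (q @ w_up)) @ w_down

def swiglu_tiled(q, w_gate, w_up, w_down, block):
    # Streamed: process `block` intermediate columns at a time, accumulating
    # into the output so only a block-sized activation is ever live.
    d_ff = w_gate.shape[1]
    out = np.zeros(w_down.shape[1])
    for j in range(0, d_ff, block):
        a = q @ w_gate[:, j:j + block]
        b = q @ w_up[:, j:j + block]
        out += (silu(a) * b) @ w_down[j:j + block, :]
    return out

# Illustrative sizes (not from the paper).
rng = np.random.default_rng(0)
d_h, d_ff = 8, 32
q = rng.standard_normal(d_h)
w_gate, w_up = rng.standard_normal((2, d_h, d_ff))
w_down = rng.standard_normal((d_ff, d_h))
full = swiglu_full(q, w_gate, w_up, w_down)
tiled = swiglu_tiled(q, w_gate, w_up, w_down, block=8)
assert np.allclose(full, tiled)
```

Because each tile’s contribution to the output is independent, the kernel can discard a tile’s activations as soon as they are accumulated, which is exactly what keeps the live working set at $O(d_s)$.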
Hardware implementations leverage kernel fusion in Triton and Hopper/TK, incorporating asynchronous producer–consumer staging, warp-group specialization, and multi-stage buffering for maximal bandwidth utilization on NVIDIA H100 GPUs (Zhang et al., 7 Dec 2025).
4. Dynamic Routing via Parallel Sub-Networks
A distinctive feature of FlashMHF is the dynamic, per-head, per-token routing afforded by small gating networks. For each input token, the gating mechanism determines a weight over sub-networks, analogous to an internal soft MoE, but without the capacity collapse or static routing issues frequently present in such modules. All sub-networks remain active and contribute to each output, ensuring full differentiability and the ability to emphasize different “reasoning sub-paths” at a fine granularity.
By tethering the sub-network width directly to the head dimension via $d_s = \tfrac{8}{3} d_h$, FlashMHF permits scalable parallelization and expressivity. Empirical ablations demonstrate that tying $d_s$ to $d_h$ is critical: dense routing without the multi-head decomposition (i.e., $H = 1$) underperforms, affirming the necessity of the combined multi-head and dynamic sub-network approach (Zhang et al., 7 Dec 2025).
5. Implementation and Integration in Transformers
FlashMHF serves as a drop-in replacement for SwiGLU or standard FFN layers within the Transformer block. The data flow, including LayerNorm, residual connections, and attention mechanisms, remains unchanged upstream and downstream. The hardware memory pattern involves:
- Each head’s query $q_i$ is read into SRAM per tile.
- Sub-network parameters (for sub-networks $j = 1, \dots, n$) are double-buffered and streamed from high-bandwidth memory.
- Activation accumulation and storage performed entirely in SRAM.
This strategy further reduces memory traffic by avoiding the allocation of large intermediate tensors, particularly beneficial for long context lengths or large model deployments on memory-constrained GPUs.
6. Empirical Evaluation
Benchmarks span models ranging from 128M to 1.3B parameters, with primary evaluation on PG19 validation (60–100B token pretraining) and downstream tasks including HellaSwag, SIQA, PIQA, OBQA, Winogrande, and RACE. Key results include:
- Consistently improved perplexity versus SwiGLU:
- 128M: FlashMHF-128hdim achieves lower eval loss.
- 370M: SwiGLU loss 3.030, FlashMHF (d_h=128) loss 3.014; naïve MH-FFN fails to scale.
- 1.3B: SwiGLU loss 2.843, FlashMHF (d_h=128) loss 2.793 (Δ = –0.050).
- The highest average and per-task accuracy occurs with FlashMHF, especially for d_h=128.
- Inference latency on NVIDIA H100 is up to 1.08× faster than SwiGLU, with a mean speedup of 1.05×; the per-layer memory footprint is reduced by 3–5× across varying sequence lengths.
- Ablations indicate d_h=128 as optimal, with smaller head dimensions underfitting and larger ones exhibiting diminishing returns due to decreasing head diversity.
7. Limitations and Trade-Offs
While FlashMHF delivers substantial memory and modest inference time improvements, several caveats are observed:
- Kernel engineering complexity is significant, with performance hinging on careful tiling, warp group design, and latency-sensitive staging.
- The latency advantage is less pronounced than that of FlashAttention, because baseline SwiGLU kernels are already highly optimized.
- The optimal head dimension and sub-network count must be tuned per compute regime, balancing head diversity against routing and memory overhead.
- The approach is most advantageous for settings involving long input contexts or large models deployed on memory-constrained GPUs.
- For very small models or scenarios where kernel launch overhead dominates, established baselines such as SwiGLU may remain preferable (Zhang et al., 7 Dec 2025).
FlashMHF establishes multi-head, dynamically routed FFNs as an expressive and hardware-efficient principle for Transformer design, providing a scalable alternative for next-generation architectures.