
Spark FFN: Sparse Transformer Efficiency

Updated 13 February 2026
  • Spark FFN is a variant of feed-forward networks that employs top-k masking, statistical thresholding, and parameter-efficient gating to enforce activation sparsity in Transformers.
  • The design systematically reduces per-token FLOPs and wall-time by leveraging the lazy neuron phenomenon without compromising model quality or standard training dynamics.
  • It integrates a linear-time differentiable top-k approximation and predictor-value split to achieve hardware-friendly sparse computations in both training and inference.

Spark FFN is a feed-forward network (FFN) variant for Transformers that explicitly exploits activation sparsity by means of top-$k$ masking, statistical thresholding, and parameter-efficient gating. Developed in the context of the Spark Transformer architecture, Spark FFN systematically reduces computational cost in both training and inference, achieving significant FLOPs and wall-time reductions while preserving model quality and training dynamics. The core innovation lies in scalable, hardware-friendly sparsification of FFN activations and a low-cost, differentiable predictor for selecting active neurons, thereby reactivating “lazy neuron” sparsity in modern Transformer models (You et al., 7 Jun 2025).

1. Background: Standard Transformer FFN and the Lazy Neuron Phenomenon

In the canonical Transformer layer, the FFN processes each token embedding $x \in \mathbb{R}^{d_{\text{model}}}$ using a two-layer MLP:

$$y = \sigma(x W_1)\, W_2$$

with $W_1 \in \mathbb{R}^{d_{\text{model}} \times d_{\text{ff}}}$, $W_2 \in \mathbb{R}^{d_{\text{ff}} \times d_{\text{model}}}$, and a nonlinearity $\sigma(\cdot)$ such as GELU or ReLU. This requires approximately $4 d_{\text{model}} d_{\text{ff}}$ FLOPs per token.
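As a concrete reference point, the dense baseline can be sketched in a few lines of NumPy; the dimensions below are illustrative rather than the paper's, and ReLU stands in for a generic $\sigma$:

```python
import numpy as np

def dense_ffn(x, W1, W2):
    """Standard two-layer FFN: y = sigma(x W1) W2, with sigma = ReLU here."""
    h = np.maximum(x @ W1, 0.0)     # pre-activation x W1, then ReLU
    return h @ W2

rng = np.random.default_rng(0)
d_model, d_ff = 64, 256             # illustrative sizes, not the paper's
x = rng.standard_normal(d_model)
W1 = rng.standard_normal((d_model, d_ff))
W2 = rng.standard_normal((d_ff, d_model))

y = dense_ffn(x, W1, W2)
flops = 4 * d_model * d_ff          # ~2*d_model*d_ff per matmul, two matmuls
```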

Li et al. (2022) identified the “lazy neuron” phenomenon: when $\sigma = $ ReLU, only a small subset of the $d_{\text{ff}}$ hidden units have nonzero activations per token. This intrinsic sparsity allows, in principle, for skipping $W_2$ multiplications on inactive neurons, reducing some computation but not the initial $x W_1$ product. The challenge is to efficiently and explicitly harness this per-token sparsity without degrading model quality or increasing parameter count (You et al., 7 Jun 2025).
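A quick way to observe activation sparsity is to count nonzero ReLU outputs per token. In this sketch the weights are random, so roughly half the units fire; the lazy neuron observation is that in trained Transformers the per-token fraction is far smaller:

```python
import numpy as np

rng = np.random.default_rng(1)
d_model, d_ff = 64, 1024
x = rng.standard_normal(d_model)
W1 = rng.standard_normal((d_model, d_ff)) / np.sqrt(d_model)

h = np.maximum(x @ W1, 0.0)     # ReLU hidden activations for one token
active_frac = np.mean(h > 0)    # fraction of "awake" neurons
# With random weights this sits near 50%; trained models exhibit much
# lower per-token fractions (the lazy neuron phenomenon).
```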

2. Explicit Sparsification via Top-$k$ Masking

Spark FFN enforces sparsity by selecting the top-$k$ activations from the pre-nonlinearity score vector $h = x W_1$. The masking process is:

  • Compute $\mathrm{mask} = \mathrm{Top}_k(h)$, where $\mathrm{mask} \in \{0,1\}^{d_{\text{ff}}}$,
  • Compute $h_{\text{sparse}} = h \odot \mathrm{mask}$,
  • Propagate via $y = h_{\text{sparse}} W_2$.

Here, $\mathrm{Top}_k(h)$ retains only the $k$ largest elements (per token), zeroing out the rest. If $k \ll d_{\text{ff}}$, the final multiplication $h_{\text{sparse}} W_2$ drops from $2 d_{\text{ff}} d_{\text{model}}$ to $2 k d_{\text{model}}$ FLOPs. However, exact $\mathrm{Top}_k$ selection by sorting is $O(d_{\text{ff}} \log d_{\text{ff}})$ and non-differentiable, necessitating an efficient relaxation (You et al., 7 Jun 2025).
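An exact Top-$k$ mask can be built with a selection routine rather than a full sort (e.g., NumPy's `argpartition`); the hard selection is still non-differentiable, which motivates the statistical relaxation described next. A minimal sketch:

```python
import numpy as np

def topk_mask(h, k):
    """Exact Top-k: keep the k largest entries of h, zero out the rest."""
    idx = np.argpartition(h, -k)[-k:]   # indices of the k largest entries
    mask = np.zeros_like(h)
    mask[idx] = 1.0
    return mask

h = np.array([0.5, -1.0, 2.0, 0.1, 3.0, -0.2])
m = topk_mask(h, 2)
h_sparse = h * m                         # only the top-2 entries survive
```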

3. Statistical Top-$k$: Linear-Time Differentiable Masking

To overcome the inefficiency of exact Top-$k$, Spark FFN introduces the “statistical Top-$k$” operator, which approximates Top-$k$ selection in linear time and is differentiable almost everywhere. For a vector $h \in \mathbb{R}^d$ and target $k$:

  • Compute $\theta(h, k) = \mathrm{mean}(h) + \mathrm{std}(h) \cdot Q\left(1 - \frac{k}{d}\right)$, where $Q$ is the standard Gaussian quantile function,
  • Apply soft-thresholding: $h_{\text{mask}} = \max(h - \theta, 0)$.

This sets elements below the threshold to zero and subtracts $\theta$ from the rest. Since the mean and standard deviation cost $O(d)$ and soft-thresholding is elementwise, the total computation is $O(d)$. Empirically, the approach yields $\approx k$ surviving entries under a Gaussian fit assumption, which holds well for FFN pre-activations in practice (You et al., 7 Jun 2025).
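The two steps above can be sketched directly; `NormalDist().inv_cdf` from the Python standard library serves as the Gaussian quantile $Q$, and the sizes are illustrative:

```python
import numpy as np
from statistics import NormalDist

def statistical_topk(h, k):
    """Soft-threshold h at a Gaussian-quantile estimate of its k-th largest value."""
    theta = h.mean() + h.std() * NormalDist().inv_cdf(1 - k / h.size)
    return np.maximum(h - theta, 0.0)   # O(d); differentiable except at h == theta

rng = np.random.default_rng(2)
d, k = 4096, 256                         # illustrative sizes
h = rng.standard_normal(d)
h_sparse = statistical_topk(h, k)
survivors = int(np.count_nonzero(h_sparse))  # roughly k under a Gaussian fit
```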

4. Predictor-Value Decomposition and Efficient Sparse Computation

Spark FFN partitions the input and the first FFN layer for further efficiency. The weight matrix $W_1$ and input $x$ are split:

  • Predictor block: $W_{1,\mathrm{pred}} \in \mathbb{R}^{r \times d_{\text{ff}}}$, operating on $x_{\mathrm{pred}} \in \mathbb{R}^{r}$,
  • Value block: $W_{1,\mathrm{rest}} \in \mathbb{R}^{(d_{\text{model}}-r) \times d_{\text{ff}}}$, operating on $x_{\mathrm{rest}} \in \mathbb{R}^{d_{\text{model}}-r}$.

The mechanism is:

  • Compute predictor scores: $\mathrm{scores} = x_{\mathrm{pred}} W_{1,\mathrm{pred}}$,
  • Build the mask: $\mathrm{mask} = \text{Statistical-Top}_k(\mathrm{scores})$,
  • Compute values: $\mathrm{vals} = W_{1,\mathrm{rest}}^{\top} x_{\mathrm{rest}}$,
  • Select and activate: $\mathrm{hidden} = \mathrm{GELU}(\mathrm{mask}) \odot \mathrm{vals}$,
  • Output: $y = \mathrm{hidden}\, W_2$.

Because $\mathrm{mask}$ is $k$-sparse, the expensive projections can be computed as sparse vector-matrix products that touch only the $k$ selected columns of the value block and the corresponding rows of $W_2$. The predictor block's cost is $2 r d_{\text{ff}}$; the value block and output projection cost $2k(d_{\text{model}}-r)$ and $2k d_{\text{model}}$, respectively. Optimal empirical performance occurs at $r \approx d_{\text{model}}/2$ (You et al., 7 Jun 2025).
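Putting the pieces together, a minimal single-token forward pass might look as follows; the tanh-approximate GELU and all dimensions are illustrative assumptions, and the NumPy gather over selected indices stands in for a real sparse kernel:

```python
import numpy as np
from statistics import NormalDist

def gelu(z):
    # tanh approximation of GELU (illustrative choice)
    return 0.5 * z * (1 + np.tanh(np.sqrt(2 / np.pi) * (z + 0.044715 * z**3)))

def spark_ffn(x, W1_pred, W1_rest, W2, k, r):
    x_pred, x_rest = x[:r], x[r:]
    scores = x_pred @ W1_pred                    # cheap predictor pass
    theta = scores.mean() + scores.std() * NormalDist().inv_cdf(1 - k / scores.size)
    mask = np.maximum(scores - theta, 0.0)       # statistical Top-k, ~k nonzeros
    idx = np.flatnonzero(mask)                   # selected neuron indices
    vals = x_rest @ W1_rest[:, idx]              # value block, selected columns only
    hidden = gelu(mask[idx]) * vals
    return hidden @ W2[idx, :]                   # selected rows of W2

rng = np.random.default_rng(3)
d_model, d_ff = 128, 1024                        # illustrative sizes
r, k = d_model // 2, 64
W1_pred = rng.standard_normal((r, d_ff))
W1_rest = rng.standard_normal((d_model - r, d_ff))
W2 = rng.standard_normal((d_ff, d_model))
y = spark_ffn(rng.standard_normal(d_model), W1_pred, W1_rest, W2, k, r)
```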

5. Measured Efficiency, Sparsity, and Model Quality

On a 2B-parameter Transformer pretrained according to the Gemma-2 recipe, Spark FFN achieves:

  • Activation sparsity: $k/d_{\text{ff}} \approx 8\%$ (e.g., $k \approx 1106$ with $d_{\text{ff}} = 13824$),
  • End-to-end per-token FLOPs reduction: $\approx 2.5\times$ (72% in the FFN, 75% in attention dot-products),
  • Decoding speedup: up to $1.79\times$ on a 4-core CPU (prefill $1.64\times$, decode $1.86\times$), and up to $1.40\times$ on an NVIDIA T4 GPU,
  • No change to the optimizer, learning rate schedule, or pretraining curriculum.

Empirically, this procedure yields near-zero impact on pretraining loss or downstream quality, and enables hardware-efficient implementations using specialized sparse kernels (You et al., 7 Jun 2025).
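As a sanity check on the cost formulas, the reported $d_{\text{ff}}$ and $k$ can be plugged in directly; the model width $d_{\text{model}}$ is not stated in this summary, so a hypothetical value is assumed below purely for illustration:

```python
# Per-token FFN cost comparison using the formulas from the text. d_ff and k
# come from the reported configuration; d_model is a hypothetical placeholder.
d_ff, k = 13824, 1106
d_model = 2048                      # assumed, NOT from the paper
r = d_model // 2                    # predictor width at the reported optimum

dense = 4 * d_model * d_ff          # standard FFN FLOPs per token
sparse = 2 * r * d_ff + 2 * (d_model - r) * k + 2 * k * d_model
sparsity = k / d_ff                 # fraction of active neurons (~8%)
reduction = dense / sparse          # FFN-only reduction under these assumptions
```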

6. Architectures, Hyper-Parameters, and Implementation

Key configurable aspects include:

  • $k$: active neurons per token ($\approx 8\%$ of $d_{\text{ff}}$ in Gemma-2 models),
  • $r$: predictor width, optimal at $r \approx d_{\text{model}}/2$ (a multiple of $256$ per the Gemma-2 constraint),
  • Thresholding: statistical Top-$k$ is parameter-free, almost-everywhere differentiable, and $O(d)$,
  • Training: No modifications to standard Transformer training pipeline.

In attention, a similar predictor-value split is employed, with $d_{\text{attn}}$ halved ($r_{\text{attn}} = 128$), and per-token attention restricted to the top $256$ keys.

7. Pseudocode and Integration in Transformer Layers

The Spark FFN forward pass for a single token $x \in \mathbb{R}^{d_{\text{model}}}$ involves:

```
# Parameters: W1_pred (r, d_ff), W1_rest (d_model - r, d_ff), W2 (d_ff, d_model)
# Sparsity target: k
x_pred, x_rest = x[:r], x[r:]               # partition input

scores = x_pred @ W1_pred                   # predictor scores, shape (d_ff,)
theta = mean(scores) + std(scores) * Q(1 - k / d_ff)   # Q = Gaussian quantile
mask = maximum(scores - theta, 0)           # soft thresholding, ~k nonzeros

idx = nonzero(mask)                         # indices of selected neurons
vals = x_rest @ W1_rest[:, idx]             # value block, selected columns only
hidden = GELU(mask[idx]) * vals
y = hidden @ W2[idx, :]                     # selected rows of W2
```

Gradients flow through the statistical Top-$k$ everywhere except at zero crossings. Inference reuses the same forward pass, with per-token computation dropping from $4 d_{\text{model}} d_{\text{ff}}$ FLOPs to $2 r d_{\text{ff}} + 2(d_{\text{model}}-r)k + 2 k d_{\text{model}}$. Practical implementations employ specialized SIMD/tiling/CUDA kernels for memory and compute efficiency (You et al., 7 Jun 2025).


Spark FFN reactivates latent activation sparsity in Transformer FFNs by combining explicit top-$k$ masking, scalable thresholding, and efficient parameter reuse, resulting in substantial computational savings and wall-time improvements without compromising model quality or standard training dynamics.
