
RevFFN: Memory-Efficient Fine-Tuning for MoE LLMs

Updated 16 January 2026
  • The paper presents RevFFN, a memory-efficient paradigm that reduces peak VRAM by ~49% using reversible Transformer blocks.
  • It details a reversible MoE construction that enables exact input reconstruction during backpropagation, eliminating the need to store intermediate activations.
  • Empirical results demonstrate that RevFFN maintains or slightly improves task performance while lowering peak VRAM from 65.4GB to 39.5GB.

RevFFN is a memory-efficient paradigm for full-parameter fine-tuning of Mixture-of-Experts (MoE) LLMs utilizing reversible Transformer block architectures. It addresses the activation memory bottleneck inherent in conventional fine-tuning approaches by enabling input reconstruction from outputs during the backward pass, thereby eliminating the need to store intermediate activations. This mechanism significantly reduces peak VRAM requirements and enables single-GPU training for large-scale MoE LLMs without sacrificing expressive capacity or downstream performance (Liu et al., 24 Dec 2025).

1. Architectural Principles

1.1 Standard MoE Transformer Layer

Traditional Transformer decoder layers consist of two sublayers: multi-head self-attention with residual connections and a feed-forward network (FFN). The residual formulation is:

  • Self-attention: $H' = H + \mathrm{Attn}(\mathrm{LN}(H), \mathrm{LN}(H), \mathrm{LN}(H))$
  • Feed-forward: $H_{out} = H' + \mathrm{FFN}(\mathrm{LN}(H'))$, with $\mathrm{FFN}(x) = W_2\, \sigma(W_1 x)$

The MoE variant replaces the FFN with a sparsely-gated expert layer. A gating network $g(x) = \mathrm{softmax}(W_g x) \in \mathbb{R}^E$ assigns each token to its top-$k$ experts, each an individual two-layer MLP $E_e(x) = W_2^{(e)} \sigma(W_1^{(e)} x)$, aggregated as $F_{\mathrm{MoE}}(x) = \sum_{e=1}^{E} g_e(x)\, E_e(x)$. All computation occurs in the full model dimension $d_{\mathrm{model}}$.
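The routing-and-aggregation pipeline above can be sketched in PyTorch. This is a minimal dense-loop illustration, not the paper's implementation; the function name `moe_ffn`, the expert list, and all dimensions are assumptions:

```python
import torch
import torch.nn.functional as F

def moe_ffn(x, W_g, experts, k=2):
    """Sparsely-gated MoE feed-forward over a batch of token vectors.

    x:       (T, d_model) token activations
    W_g:     (E, d_model) gating weights
    experts: list of E callables, each a two-layer MLP in d_model
    """
    gate = F.softmax(x @ W_g.t(), dim=-1)   # (T, E) routing probabilities
    topv, topi = gate.topk(k, dim=-1)       # top-k expert probs/indices per token
    out = torch.zeros_like(x)
    for e, expert in enumerate(experts):
        for slot in range(k):
            mask = topi[:, slot] == e       # tokens whose slot routes to expert e
            if mask.any():
                out[mask] += topv[mask, slot, None] * expert(x[mask])
    return out
```

Production MoE layers dispatch tokens to experts in parallel rather than looping, but the loop makes the formula $F_{\mathrm{MoE}}(x) = \sum_e g_e(x) E_e(x)$ explicit.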

1.2 Reversibility in Residual Blocks

Conventional residual blocks require caching input activations for backpropagation, incurring $O(L\,B\,S\,d_{\mathrm{model}})$ memory for $L$ layers, batch size $B$, sequence length $S$, and model dimension $d_{\mathrm{model}}$. A reversible block implements a bijective mapping $(H_1, H_2) \mapsto (Y_1, Y_2)$, enabling input reconstruction during the backward pass:

  • Forward: $Y_1 = H_1 + F(H_2)$, $Y_2 = H_2 + G(Y_1)$
  • Inverse: $H_2 = Y_2 - G(Y_1)$, $H_1 = Y_1 - F(H_2)$

RevFFN employs a two-stream coupling method, splitting activations into halves and ensuring exact invertibility.

2. Reversible MoE Block Construction

2.1 Formulation

The hidden tensor $H \in \mathbb{R}^{B \times S \times d_{\mathrm{model}}}$ is partitioned as $H = [H_1; H_2]$ with $H_1, H_2 \in \mathbb{R}^{B \times S \times d_{\mathrm{model}}/2}$. The reversible update equations for a decoder layer are:

  1. $Y_1 = H_1 + \mathrm{Attn}(\mathrm{LN}(H_2))$
  2. $Y_2 = H_2 + F_{\mathrm{MoE}}(\mathrm{LN}(Y_1))$
  3. $H_{out} = [Y_1; Y_2]$

Inverse mapping is defined as:

  1. $H_2 = Y_2 - F_{\mathrm{MoE}}(\mathrm{LN}(Y_1))$
  2. $H_1 = Y_1 - \mathrm{Attn}(\mathrm{LN}(H_2))$

A single fixed-point iteration initialized at the block outputs achieves machine-precision reconstruction.
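Exact reconstruction under this additive coupling can be checked numerically. The sketch below uses arbitrary stand-in functions for the attention and MoE sublayers (assumptions for illustration, not the paper's modules):

```python
import torch

def rev_forward(h1, h2, attn, moe):
    """Forward coupling: y1 = h1 + attn(h2); y2 = h2 + moe(y1)."""
    y1 = h1 + attn(h2)
    y2 = h2 + moe(y1)
    return y1, y2

def rev_inverse(y1, y2, attn, moe):
    """Exact inverse: recover (h1, h2) from (y1, y2) by subtraction,
    in the reverse order of the forward updates."""
    h2 = y2 - moe(y1)
    h1 = y1 - attn(h2)
    return h1, h2
```

Because each update only adds a function of the *other* stream, the inverse needs no iterative solve: the subtractions undo the forward pass exactly, up to floating-point rounding.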

2.2 MoE Feed-Forward Layer Structure

For $F_{\mathrm{MoE}}(x)$:

  • Routing: $g(x) = \mathrm{softmax}(W_g x) \in \mathbb{R}^E$, with each token dispatched to its top-$k$ experts
  • Experts: $E_e(x) = W_2^{(e)} \sigma(W_1^{(e)} x)$
  • Aggregation: $F_{\mathrm{MoE}}(x) = \sum_{e \in \mathrm{TopK}(g(x),\,k)} g_e(x)\, E_e(x)$

To maintain compatibility with pre-trained MoE modules, inputs are projected via adapter matrices $W_{\mathrm{in}}$ and $W_{\mathrm{out}}$, yielding $F(x) = W_{\mathrm{out}}\, F_{\mathrm{MoE}}(W_{\mathrm{in}}\, x)$.
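One way to realize this adapter projection around a frozen pretrained module. The names `W_in`/`W_out`, the stand-in `pretrained_moe`, and all dimensions are assumptions for illustration:

```python
import torch
import torch.nn as nn

d_model, half = 16, 8  # the two reversible streams live in d_model/2

# Hypothetical adapters bridging the half-width stream and the
# pretrained experts' full model dimension.
W_in = nn.Linear(half, d_model, bias=False)
W_out = nn.Linear(d_model, half, bias=False)
pretrained_moe = nn.Linear(d_model, d_model)  # stand-in for the frozen MoE FFN

def adapted_moe(x_half):
    """F(x) = W_out * F_MoE(W_in * x): run the full-dimension MoE on a
    half-width stream and project the result back."""
    return W_out(pretrained_moe(W_in(x_half)))
```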

3. Memory Savings and Activation Reconstruction

3.1 Back-Propagation Strategy

Standard layers require storing $H$, $H'$, attention and expert intermediates, and LayerNorm inputs for gradient calculations. In RevFFN, the backward pass proceeds as:

  1. Reconstruct $(H_1, H_2)$ from $(Y_1, Y_2)$ using the inverse mapping.
  2. Re-execute LayerNorm, Attention, and MoE blocks to recreate required intermediates.
  3. Compute gradients with respect to parameters and inputs using chain-rule.

Each sublayer is executed twice per step (forward and backward), trading memory savings for compute overhead.
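The reconstruct-then-recompute backward can be sketched as a custom `torch.autograd.Function`. This is a simplified single-coupling version under the additive-coupling assumption; a production implementation must additionally handle RNG state, dropout, and mixed precision:

```python
import torch

class ReversibleCoupling(torch.autograd.Function):
    """y1 = x1 + f(x2); y2 = x2 + g(y1). Inputs are not saved: the
    backward reconstructs them from the outputs and re-runs f and g."""

    @staticmethod
    def forward(ctx, x1, x2, f, g):
        ctx.f, ctx.g = f, g
        with torch.no_grad():                 # no graph kept for the forward
            y1 = x1 + f(x2)
            y2 = x2 + g(y1)
        ctx.save_for_backward(y1.detach(), y2.detach())
        return y1, y2

    @staticmethod
    def backward(ctx, dy1, dy2):
        f, g = ctx.f, ctx.g
        y1, y2 = ctx.saved_tensors
        # 1) Reconstruct the inputs via the exact inverse.
        with torch.no_grad():
            x2 = y2 - g(y1)
            x1 = y1 - f(x2)
        # 2) Re-execute the sublayers with grad enabled, then 3) apply
        #    the chain rule through the recomputed graph.
        with torch.enable_grad():
            x1 = x1.requires_grad_()
            x2 = x2.requires_grad_()
            y1_ = x1 + f(x2)
            y2_ = x2 + g(y1_)
            torch.autograd.backward([y1_, y2_], [dy1, dy2])
        return x1.grad, x2.grad, None, None
```

Parameter gradients for `f` and `g` accumulate during the re-execution, so each sublayer indeed runs twice per step, matching the memory-for-compute trade described above.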

3.2 Memory Complexity

Method Memory Complexity
Standard fine-tuning $O(L\,B\,S\,d_{\mathrm{model}})$
RevFFN reversible architecture $O(B\,S\,d_{\mathrm{model}})$

RevFFN eliminates the dependence on layer count $L$: activation memory scales only with batch size, sequence length, and model dimension.
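A back-of-the-envelope comparison of the two complexities, using a hypothetical fp16 configuration (all numbers are illustrative, not the paper's measurements):

```python
# Hypothetical configuration: L layers, batch B, sequence S, width d.
L, B, S, d = 48, 4, 4096, 4096
bytes_per_elem = 2  # fp16

standard = L * B * S * d * bytes_per_elem   # O(L*B*S*d): cache every layer's input
reversible = B * S * d * bytes_per_elem     # O(B*S*d): only the current block's streams

print(f"standard:   {standard / 2**30:.2f} GiB")
print(f"reversible: {reversible / 2**30:.3f} GiB")  # L-fold smaller
```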

4. Training Modifications and Performance

4.1 Backward Hook Implementation

RevFFN requires a custom backward hook:

  • At each reversible block, output activations $(Y_1, Y_2)$ are popped.
  • The inverse mapping reconstructs $(H_1, H_2)$.
  • Forward sublayers are re-executed to materialize intermediates (LayerNorm, Attn, MoE).
  • Autograd applies the chain rule for gradients with respect to model parameters and inputs.

Only expert parameters $\{W_1^{(e)}, W_2^{(e)}\}$ and the adapter matrices are updated; the gating network remains frozen during fine-tuning.
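Freezing the gating network while leaving experts and adapters trainable might look like the following (the `"gate"` name filter is an assumed naming convention, not the paper's code):

```python
import torch.nn as nn

def freeze_gating(model: nn.Module, gate_keyword: str = "gate"):
    """Disable gradients for gating-network parameters; everything else
    (experts, adapters) stays trainable."""
    for name, param in model.named_parameters():
        if gate_keyword in name:
            param.requires_grad_(False)
```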

4.2 Computational Overhead

Each reversible layer incurs roughly $2\times$ the forward FLOPs of a standard layer (due to re-execution during the backward pass). In practice, MoE computation dominates, resulting in roughly 20–30% training overhead:

  • Throughput drops from 31.0 to 24.6 samples/s on an NVIDIA H800.
Method Peak VRAM (GB) Throughput (samples/s)
SFT + Checkpointing 65.4 19.7
GaLore 45.1 35.2
RevFFN 39.5 24.6

5. Empirical Validation

Downstream task performance is evaluated on MMLU, GSM8K, MT-Bench, and a multilingual benchmark:

Method MMLU GSM8K Multilingual MT-Bench
SFT + Checkpointing 66.1% 74.8% 39.5% 7.52
RevFFN 66.7% 75.1% 38.8% 7.65

RevFFN provides a roughly 49% reduction in peak memory (vs. SFT + Checkpointing), with task accuracy matching or slightly exceeding the baselines. Ablation studies confirm that both stages of the two-stage training schedule are essential for stability and optimal performance.

6. Usage Scenarios and Implementation Recommendations

6.1 Application Context

RevFFN is indicated in settings with VRAM budgets under roughly 80 GB and model scales from several billion to tens of billions of parameters. It suits scenarios where full-parameter adaptation is required and multi-GPU training or CPU offloading is infeasible or suboptimal.

6.2 PyTorch Implementation

The reversible block wraps the attention and MoE sublayers in a two-stream additive coupling module that exposes both the forward update and its exact inverse.
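A minimal sketch of such a block, assuming the two-stream formulation of Section 2.1; the attention and MoE sublayers are placeholder `nn.Linear` modules, where a real block would plug in the pretrained sublayers:

```python
import torch
import torch.nn as nn

class ReversibleMoEBlock(nn.Module):
    """Two-stream reversible decoder block (sketch). Each sublayer maps
    d_model/2 -> d_model/2 on its half of the hidden state."""

    def __init__(self, half_dim):
        super().__init__()
        self.ln1 = nn.LayerNorm(half_dim)
        self.ln2 = nn.LayerNorm(half_dim)
        self.attn = nn.Linear(half_dim, half_dim)  # placeholder for self-attention
        self.moe = nn.Linear(half_dim, half_dim)   # placeholder for the MoE FFN

    def forward(self, h1, h2):
        y1 = h1 + self.attn(self.ln1(h2))
        y2 = h2 + self.moe(self.ln2(y1))
        return y1, y2

    @torch.no_grad()
    def inverse(self, y1, y2):
        """Recover (h1, h2) from (y1, y2) by undoing the updates in reverse."""
        h2 = y2 - self.moe(self.ln2(y1))
        h1 = y1 - self.attn(self.ln1(h2))
        return h1, h2
```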

Backward hooks must be registered to:

  1. Pop output activations,
  2. Run inverse mapping,
  3. Re-execute submodules to recover intermediates,
  4. Apply autograd for parameter gradients.

7. Summary and Implications

RevFFN refactors Transformer decoder layers into reversible two-stream blocks, maintaining full MoE routing and expert computation across the model dimension. It achieves downstream task performance comparable to standard full fine-tuning while substantially reducing peak memory usage, enabling practical single-GPU training of billion-parameter MoE LLMs by trading a moderate computational overhead (roughly $2\times$ forward FLOPs per block) for activation memory that no longer grows with depth, i.e. $O(L\,B\,S\,d_{\mathrm{model}})$ reduced to $O(B\,S\,d_{\mathrm{model}})$ (Liu et al., 24 Dec 2025). This suggests broader applicability of reversible computation techniques for efficient adaptation of modern LLM architectures where distributed infrastructure is unavailable.
