MemFine: Memory-Aware MoE Scheduling

Updated 21 April 2026
  • The paper introduces MemFine, a memory-aware fine-grained scheduling framework that mitigates token route imbalances in MoE models.
  • It employs fine-grained chunked recomputation and dynamic memory-aware tuning (MACT) to significantly reduce peak memory usage and avoid OOM errors.
  • Empirical results demonstrate up to 48% reduction in activation memory and throughput improvements, enabling efficient and scalable training on memory-limited GPUs.

MemFine is a memory-aware fine-grained scheduling framework for efficient large-scale Mixture of Experts (MoE) model training under GPU memory constraints. The approach addresses the memory bottleneck induced by severe token route imbalance in MoE architectures, enabling stable, scalable training on memory-limited hardware without sacrificing throughput or accuracy (Zhao et al., 26 Nov 2025).

1. Memory Model and Constraint Analysis

MemFine begins with a precise theoretical memory model that separates static and activated memory components. Static memory ($M_{\mathrm{sta}}$) aggregates storage for parameters, gradients, and optimizer buffers:

$$M_{\mathrm{sta}} = D_t^{\mathrm{para}}\sum_{i=1}^{N}S_i + D_t^{\mathrm{grad}}\sum_{i=1}^{N}S_i + 4 D_t^{\mathrm{opt}}\sum_{i=1}^{N}S_i$$

with $D_t$ denoting bytes per tensor element and $S_i$ the $i$th tensor's size.
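As a rough illustration of the static-memory formula, the sketch below evaluates it for a list of tensor sizes. The function name and the default byte widths (2 bytes for parameters and gradients, 4 bytes for optimizer state) are assumptions chosen for the example, not values taken from the paper.

```python
def static_memory_bytes(tensor_sizes, bytes_param=2, bytes_grad=2, bytes_opt=4):
    """Evaluate M_sta = D^para * sum(S_i) + D^grad * sum(S_i) + 4 * D^opt * sum(S_i).

    tensor_sizes: element counts S_i of the N tensors.
    The byte widths are illustrative assumptions (e.g. 16-bit weights).
    """
    total_elements = sum(tensor_sizes)
    return (bytes_param + bytes_grad + 4 * bytes_opt) * total_elements


# Two hypothetical tensors with 1M and 2M elements.
example_static = static_memory_bytes([1_000_000, 2_000_000])
```

Because every term shares the same sum over $S_i$, static memory is linear in the total element count and does not depend on how tokens are routed.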

The peak activated (intermediate) memory per transformer MoE layer is

$$M_{\mathrm{act}} = \frac{m_g}{t c} D_t b \left[ s (5h + a h_d + 2 k_a h_d + e_n ) + s' (2h + 2g_e) \right]$$

where $s$ is the input-sequence token count, $s'$ the number of tokens routed to the local expert GPU, $h$ the hidden size, $a$ the attention head count, $t$ the tensor-parallel width, $c$ the context-parallel size, $b$ the micro-batch size, and $m_g$ the activations factor. The feasibility constraint for a run free of out-of-memory (OOM) errors is that static plus activated memory stays within a fractional memory budget of the device capacity.

This analytical model enables rigorous “predict-then-tune” scheduling during training.
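The "predict" half of this loop can be sketched by evaluating the activation formula directly and comparing against a budget. In the sketch below, the function names, the default arguments, and the `budget_fraction` parameter are assumptions made for illustration. The symbols `h_d`, `k_a`, `e_n`, and `g_e` are passed through as plain numbers exactly as they appear in the formula.

```python
def activation_memory_bytes(s, s_routed, h, a, h_d, k_a, e_n, g_e,
                            t=1, c=1, b=1, m_g=1.0, d_t=2):
    """Evaluate M_act = m_g/(t*c) * D_t * b * [s*(5h + a*h_d + 2*k_a*h_d + e_n)
                                              + s'*(2h + 2*g_e)]."""
    sequence_term = s * (5 * h + a * h_d + 2 * k_a * h_d + e_n)
    routed_term = s_routed * (2 * h + 2 * g_e)
    return m_g / (t * c) * d_t * b * (sequence_term + routed_term)


def fits(static_bytes, activation_bytes, capacity_bytes, budget_fraction=0.9):
    """Feasibility check: static + activated memory within a fraction of capacity.

    budget_fraction is a hypothetical knob standing in for the fractional budget.
    """
    return static_bytes + activation_bytes <= budget_fraction * capacity_bytes


# Tiny made-up configuration, only to exercise the formula.
m_act = activation_memory_bytes(s=4, s_routed=2, h=8, a=2, h_d=4, k_a=1, e_n=3, g_e=16)
```

Only the second bracketed term depends on $s'$, which is why routing imbalance shows up as a spike in activated memory while static memory stays fixed.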

2. Fine-Grained Chunked Recomputation

To mitigate GPU OOM risks from token imbalance—where certain experts momentarily receive disproportionately large token batches—MemFine introduces a chunked computation paradigm via the Fine-grained Chunk Distribution Algorithm (FCDA).

Instead of processing the entire batch in a monolithic fashion, the input sequence is partitioned into several chunks along the sequence dimension. Forward computation runs chunk by chunk, each chunk producing its own output tensor. In the backward pass, recomputation is likewise performed chunk-wise, so only activations for the current chunk reside in memory at any time.

The per-chunk activation requirement thus scales as roughly one over the chunk count of the total, significantly lowering peak memory.
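The idea can be sketched with a toy scalar "expert" in pure Python. This is not the FCDA implementation: the expert function, the helper names, and the `peak_live` counter are all hypothetical, chosen only to show that chunk-wise recomputation yields the same gradients while holding fewer intermediate activations at once.

```python
def expert_forward(x, w1, w2):
    """Toy scalar expert: y = w2 * relu(w1 * x). Returns (y, hidden)."""
    hidden = max(0.0, w1 * x)
    return w2 * hidden, hidden


def split_chunks(seq, num_chunks):
    """Split seq into at most num_chunks contiguous pieces."""
    size = max(1, -(-len(seq) // num_chunks))  # ceiling division
    return [seq[i:i + size] for i in range(0, len(seq), size)]


def forward_chunked(tokens, w1, w2, num_chunks):
    outputs = []
    for chunk in split_chunks(tokens, num_chunks):
        # Hidden activations are dropped as soon as the chunk's outputs exist.
        outputs.extend(expert_forward(x, w1, w2)[0] for x in chunk)
    return outputs


def backward_chunked(tokens, grad_out, w1, w2, num_chunks):
    grad_w1 = grad_w2 = 0.0
    peak_live = 0  # largest number of hidden activations held at once
    pairs = zip(split_chunks(tokens, num_chunks), split_chunks(grad_out, num_chunks))
    for xs, gs in pairs:
        hidden = [expert_forward(x, w1, w2)[1] for x in xs]  # recompute one chunk
        peak_live = max(peak_live, len(hidden))
        for x, h, g in zip(xs, hidden, gs):
            grad_w2 += g * h
            grad_w1 += g * w2 * x if h > 0 else 0.0
    return grad_w1, grad_w2, peak_live


tokens = [1.0, -2.0, 3.0, 4.0]
ones = [1.0] * len(tokens)
monolithic = backward_chunked(tokens, ones, 0.5, 2.0, num_chunks=1)
chunked = backward_chunked(tokens, ones, 0.5, 2.0, num_chunks=4)
```

In this toy run the gradients agree between the monolithic and chunked passes, while the number of simultaneously live hidden activations drops from the full token count to the chunk size.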

3. Dynamic Memory-Aware Chunk Scheduling

Central to MemFine is MACT (Memory-Aware Chunk Tuning), a dynamic algorithm that selects the minimal chunk count satisfying the memory constraint while preserving throughput, with the per-chunk footprint taking the place of the full activation term in that constraint. Solving the constraint yields a closed-form bound on the maximal number of routed tokens a single chunk can hold, and the theoretical chunk count follows from the observed routed-token count $s'$ and this per-chunk capacity. To limit scheduling complexity, a small set of preferred chunk bin sizes (e.g., 1, 2, 4, 8, …) is used.
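One plausible way to carry out this selection is sketched below. It assumes the capacity is obtained by solving the feasibility constraint of the memory model for $s'$, that the theoretical chunk count is the observed token count divided by that capacity and rounded up, and that the result is snapped to the next preferred bin. The function names, argument names, and rounding choices are assumptions for illustration and may differ from the paper's exact expressions.

```python
import math


def max_routed_tokens(budget_bytes, static_bytes, scale, seq_bytes, routed_bytes_per_token):
    """Largest s' such that static + scale * (seq_bytes + s' * per_token) <= budget.

    scale stands for m_g / (t * c) * D_t * b; seq_bytes for the s-dependent
    bracket term; routed_bytes_per_token for (2h + 2*g_e). All hypothetical names.
    """
    free = budget_bytes - static_bytes - scale * seq_bytes
    return max(0, math.floor(free / (scale * routed_bytes_per_token)))


def choose_chunks(observed_tokens, capacity, bins=(1, 2, 4, 8)):
    """Smallest preferred bin that keeps each chunk within capacity."""
    if capacity <= 0:
        raise ValueError("no memory left for routed tokens")
    needed = max(1, math.ceil(observed_tokens / capacity))
    for candidate in bins:
        if candidate >= needed:
            return candidate
    raise ValueError("observed routing exceeds the largest chunk bin")


capacity = max_routed_tokens(budget_bytes=1000, static_bytes=400, scale=2,
                             seq_bytes=50, routed_bytes_per_token=5)
chunks_for_spike = choose_chunks(observed_tokens=120, capacity=capacity)
```

Snapping to a small set of bins means a spike needing three chunks is served with four, trading a little extra recomputation for fewer distinct schedules.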

4. MemFine Training Algorithm

The practical scheduling mechanism re-optimizes the chunk count at every iteration based on the observed routing, providing robust OOM avoidance.
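A per-iteration loop of this kind might look like the following sketch. It is not the paper's algorithm listing: `route`, `train_step`, `pick_bin`, and the fixed per-chunk `capacity` are hypothetical placeholders for whatever the real system uses to observe routing, run a chunked step, and predict capacity.

```python
import math


def pick_bin(observed, capacity, bins=(1, 2, 4, 8)):
    """Smallest preferred chunk count covering the observed tokens, or None."""
    needed = max(1, math.ceil(observed / capacity))
    return next((b for b in bins if b >= needed), None)


def train(batches, route, train_step, capacity):
    """Re-select the chunk count every iteration from the observed routing."""
    history = []
    for batch in batches:
        routed = route(batch)              # tokens sent to the local expert
        chunks = pick_bin(routed, capacity)
        if chunks is None:
            raise MemoryError("routing spike exceeds the largest chunk bin")
        train_step(batch, chunks)          # chunked forward/backward (stubbed)
        history.append(chunks)
    return history


# Stub run: the batch length stands in for the routed-token count.
demo_batches = [[0] * 30, [0] * 120, [0] * 60]
demo_history = train(demo_batches, route=len,
                     train_step=lambda batch, chunks: None, capacity=50)
```

In the stub run the chunk count rises for the iteration with the routing spike and falls back afterwards, which is the adaptive behaviour described above.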

5. Complexity, Trade-offs, and System Dynamics

Space complexity in the baseline is set by the full activation footprint; MemFine divides the chunked portion by the chunk count, subject to diminishing returns for large chunk counts. Time complexity grows with the chunk count due to chunk recomputation overhead, but empirical results indicate only minor throughput losses at small chunk counts. The memory–throughput Pareto curve reveals a sharp knee-point selectable by the MACT heuristic.

Dynamic chunk count adaptation is frequently most conservative in early training (when route imbalance is volatile), then relaxes as token distribution stabilizes, optimizing for throughput.
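The diminishing returns can be illustrated with a toy cost model. The sketch assumes, purely for illustration, that peak activation splits into a part chunking cannot reduce and a routed part divided by the chunk count. The numbers are made up and the function name is hypothetical.

```python
def peak_activation(unchunked, routed, chunks):
    """Toy model: only the routed part shrinks with the chunk count."""
    return unchunked + routed / chunks


bins = (1, 2, 4, 8)
peaks = [peak_activation(10.0, 80.0, c) for c in bins]
# Memory saved by each doubling of the chunk count.
marginal_savings = [earlier - later for earlier, later in zip(peaks, peaks[1:])]
```

Each doubling of the chunk count halves the saving obtained from the previous doubling in this model, while the recomputation overhead keeps growing, which is one way to picture the knee in the memory–throughput curve.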

6. Empirical Evaluation and Impact

Experimental results on DeepSeek-V3–based MoE models with 32 × 64 GB NVIDIA GPUs demonstrate:

Method                        Static (GB)   Activation (GB)   Total (GB)   OOM?
Full Recomputation            43.0          22.9              65.9         Yes
MemFine (fixed chunk count)   43.0          3.7               46.7         No
MemFine+MACT                  43.0          11.9              54.9         No
  • Activation memory is reduced by 48.03% (MemFine+MACT vs. baseline).
  • Throughput improves by 4.42% compared to full recomputation and by 18.26% versus fixed chunking.
  • This enables scaling to models otherwise infeasible on target hardware, avoiding the accuracy degradation from aggressive capacity capping seen in prior load balancing schemes.
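The reported figures can be cross-checked with a few lines of arithmetic. The sketch below uses only the numbers from the table and the 64 GB device capacity stated above; the variable names are illustrative.

```python
CAPACITY_GB = 64.0

# (static GB, activation GB) per method, copied from the table.
rows = {
    "Full Recomputation": (43.0, 22.9),
    "MemFine (fixed chunk count)": (43.0, 3.7),
    "MemFine+MACT": (43.0, 11.9),
}

totals = {name: round(static + act, 1) for name, (static, act) in rows.items()}
oom = {name: total > CAPACITY_GB for name, total in totals.items()}

baseline_act = rows["Full Recomputation"][1]
mact_act = rows["MemFine+MACT"][1]
reduction_pct = round((baseline_act - mact_act) / baseline_act * 100, 2)
```

The totals reproduce the table, only the full-recomputation baseline exceeds 64 GB, and the activation reduction of MemFine+MACT relative to the baseline comes out at 48.03%.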

7. Relation to Broader Memory-Efficient LLM Training

MemFine’s chunked recomputation and MACT dynamic tuning are complementary to methods for activation checkpointing, low-rank tuning (e.g., LoRA), and CPU-offloading paradigms (e.g., MEFT). Unlike parameter-efficient fine-tuning or zeroth-order optimization, MemFine’s design is tightly specialized for the token routing–induced volatility of large MoE models under hardware constraints. Its empirical advantage lies in maximizing hardware utility and maintaining training fidelity on GPU clusters with limited memory budgets (Zhao et al., 26 Nov 2025).
