
TetraJet-v2: Efficient 4-Bit LLM Training

Updated 2 May 2026
  • TetraJet-v2 is an end-to-end fully-quantized training method for LLMs that leverages the NVFP4 4-bit format to quantize activations, weights, and gradients.
  • It introduces an unbiased double-block quantization scheme alongside OsciReset and OutControl to mitigate weight oscillation and manage structural outliers.
  • Empirical evaluations show that TetraJet-v2 substantially closes the performance gap with full precision while outperforming previous FP4 methods.

TetraJet-v2 is an end-to-end fully-quantized training (FQT) method for LLMs that achieves near-lossless accuracy using the NVFP4 4-bit floating-point format for activations, weights, and gradients across all linear layers. The method addresses key obstacles in low-precision training—specifically, weight oscillation and structural outliers—by introducing an unbiased double-block quantization scheme, an oscillation suppression mechanism (OsciReset), and an outlier retention strategy (OutControl) (Chen et al., 31 Oct 2025).

1. NVFP4 Numerical Format and Double-Block Quantization

The NVFP4 format encodes each value in four bits using the E2M1 floating-point scheme: 1 sign bit, 2 exponent bits, and 1 mantissa bit, with the set of representable values $\{0, \pm0.5, \pm1, \pm1.5, \pm2, \pm3, \pm4, \pm6\}$. Data are grouped into inner blocks of 16 elements, each associated with an E4M3 (8-bit) scaling factor whose representable range spans roughly $[-448, 448]$. Encoding operates as follows:

  • Given an outer block of 128 elements, compute a global scale $S_{\text{global}} = \max_i |X_i| / (448 \cdot 6)$.
  • For each inner sub-block $k$ (elements $16k$ to $16k+15$), compute $S_{\text{block}_k} = \max_{i \in [16k,\, 16k+15]} |X_i / S_{\text{global}}| / 6$.
  • Each element $X_i$ is quantized as $P_i = \text{round}_{\text{FP4}}\!\bigl(X_i / (S_{\text{global}} \cdot S_{\text{block}_{\lfloor i/16 \rfloor}})\bigr)$.
  • Decoding reconstructs $X_i \approx P_i \cdot S_{\text{global}} \cdot S_{\text{block}_{\lfloor i/16 \rfloor}}$.
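The double-block encoding above can be sketched in NumPy as follows. This is an illustrative reference implementation, not the hardware kernel; function names and the zero-block guard are assumptions.

```python
import numpy as np

# FP4 E2M1 representable magnitudes (plus sign).
FP4_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def round_fp4(x):
    """Deterministic round-to-nearest onto the signed E2M1 grid."""
    idx = np.argmin(np.abs(np.abs(x)[..., None] - FP4_GRID), axis=-1)
    return np.sign(x) * FP4_GRID[idx]

def quantize_double_block(x):
    """Double-block NVFP4 quantization of one 128-element outer block.

    Returns (codes, s_global, s_block) such that
    x ≈ codes * s_global * s_block[i // 16].
    """
    assert x.size == 128
    s_global = np.max(np.abs(x)) / (448.0 * 6.0)          # global scale
    blocks = (x / s_global).reshape(8, 16)                # 8 inner blocks of 16
    s_block = np.max(np.abs(blocks), axis=1) / 6.0        # per-block E4M3 scale
    s_block = np.where(s_block == 0, 1.0, s_block)        # guard all-zero blocks
    codes = round_fp4(blocks / s_block[:, None])          # FP4 codes
    return codes, s_global, s_block

def dequantize_double_block(codes, s_global, s_block):
    """Reconstruct X_i ≈ P_i * S_global * S_block[i // 16]."""
    return (codes * s_block[:, None] * s_global).reshape(-1)
```

The per-block scale divides by 6 so the largest element of each inner block maps exactly onto the top of the FP4 grid.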

The rounding strategy is deterministic round-to-nearest (RTN) in the forward path for activations and weights; for gradients in the backward pass, stochastic rounding is used, ensuring unbiasedness: $\mathbb{E}[Q(X)] = X$.
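To illustrate why stochastic rounding is unbiased, the sketch below (a simplification, not the production kernel) rounds each value to one of its two neighboring FP4 levels with probability proportional to proximity, so the expectation recovers the input exactly:

```python
import numpy as np

# Full signed E2M1 level set, sorted ascending.
FP4_LEVELS = np.array([-6, -4, -3, -2, -1.5, -1, -0.5, 0,
                       0.5, 1, 1.5, 2, 3, 4, 6], dtype=float)

def stochastic_round_fp4(x, rng):
    """Round each value to a neighboring FP4 level so that E[Q(x)] = x."""
    x = np.clip(x, FP4_LEVELS[0], FP4_LEVELS[-1])
    hi_idx = np.clip(np.searchsorted(FP4_LEVELS, x), 1, len(FP4_LEVELS) - 1)
    lo, hi = FP4_LEVELS[hi_idx - 1], FP4_LEVELS[hi_idx]
    p_up = (x - lo) / (hi - lo)            # probability of rounding up
    return np.where(rng.random(x.shape) < p_up, hi, lo)
```

For example, 0.7 sits between levels 0.5 and 1.0 and rounds up with probability 0.4, so the mean over many draws is 0.5 + 0.4 · 0.5 = 0.7.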

2. Quantization Across Linear Layers

TetraJet-v2 applies unbiased double-block NVFP4 quantization to every operand in the three matrix multiplications (MMs) per Transformer linear layer (writing $Q$ for deterministic and $Q_s$ for stochastic quantization):

  • Forward: $Y = Q(X)\,Q(W)^\top$, with both $X$ and $W$ deterministically quantized (RTN).
  • Backward w.r.t. the input: $\nabla X = Q_s(\nabla Y)\,Q_s(W)$, using stochastically quantized gradients and weights.
  • Backward w.r.t. the weights: $\nabla W = Q_s(\nabla Y)^\top\,Q_s(X)$, with $\nabla Y$ and $X$ stochastically quantized.

Stochastic quantization in the backward path guarantees that gradient estimates remain unbiased in expectation, ensuring stability of stochastic gradient descent.
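The three quantized MMs can be sketched as below. For brevity this uses per-tensor fake quantization as a stand-in for the double-block scheme; the function names are illustrative.

```python
import numpy as np

LEVELS = np.array([-6, -4, -3, -2, -1.5, -1, -0.5, 0,
                   0.5, 1, 1.5, 2, 3, 4, 6], dtype=float)

def fake_fp4(x, rng=None):
    """Per-tensor fake FP4 quantization (stand-in for double-block NVFP4).
    rng=None -> deterministic round-to-nearest; otherwise stochastic rounding."""
    s = max(np.max(np.abs(x)), 1e-12) / 6.0     # scale values into [-6, 6]
    v = x / s
    if rng is None:
        q = LEVELS[np.argmin(np.abs(v[..., None] - LEVELS), axis=-1)]
    else:
        hi = np.clip(np.searchsorted(LEVELS, v), 1, len(LEVELS) - 1)
        lo_v, hi_v = LEVELS[hi - 1], LEVELS[hi]
        q = np.where(rng.random(v.shape) < (v - lo_v) / (hi_v - lo_v), hi_v, lo_v)
    return q * s

def linear_fwd(x, w):
    """Forward MM: Y = Q(X) Q(W)^T, deterministic rounding."""
    return fake_fp4(x) @ fake_fp4(w).T

def linear_bwd(x, w, gy, rng):
    """Backward MMs with stochastically quantized operands:
    dX = Qs(dY) Qs(W), dW = Qs(dY)^T Qs(X)."""
    gx = fake_fp4(gy, rng) @ fake_fp4(w, rng)
    gw = fake_fp4(gy, rng).T @ fake_fp4(x, rng)
    return gx, gw
```

In a real integration these would back a custom autograd function so the quantization is invisible to the rest of the training loop.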

3. Oscillation Detection and OsciReset

Weight oscillation arises when quantized weights “hop” excessively between bins despite negligible true weight movement, impairing convergence. TetraJet-v2 defines:

  • the net movement of the master (full-precision) weight over a statistics window
  • the accumulated movement of the corresponding quantized weight over the same window
  • an oscillation ratio comparing the two

Whenever this ratio exceeds a preset threshold, the master weight is reset to the nearest representable quantized value, eliminating accumulated drift. OsciReset operates periodically (e.g., every 200 steps, after accumulating statistics over ~50 steps), and only after the learning rate decays below a threshold, so it activates only in the later portion of training.
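A minimal sketch of an OsciReset-style check is given below. The ratio formulation, threshold value, and function names are illustrative assumptions, not the paper's exact definitions:

```python
import numpy as np

def oscireset(master_w, quant_fn, w_hist, ratio_thresh=5.0):
    """Reset oscillating master weights to their nearest quantized value.

    master_w:     current master weights (modified in place)
    quant_fn:     maps weights onto the quantization grid
    w_hist:       master-weight snapshots over the statistics window
    ratio_thresh: illustrative threshold on quantized-vs-master movement
    """
    net_move = np.abs(w_hist[-1] - w_hist[0])             # master movement
    hop = sum(np.abs(quant_fn(a) - quant_fn(b))           # quantized movement
              for a, b in zip(w_hist[:-1], w_hist[1:]))
    osc = hop > ratio_thresh * np.maximum(net_move, 1e-12)
    master_w[osc] = quant_fn(master_w)[osc]               # snap back to grid
    return osc
```

A weight hovering just below a bin boundary makes the quantized value hop every step while the master weight barely moves, which is exactly the pattern this flags.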

4. Outlier Channel Handling Via OutControl

A critical insight in TetraJet-v2 is the presence of persistent structural outliers: 5–10% of activation channels exhibiting anomalously large magnitudes. OutControl statically identifies outlier channels by ranking channelwise norms over a small calibration set and selecting the top 5–10% as the outlier set.

  • Forward: The activation matrix $X$ is split into outlier and non-outlier channels, $X = [X_O, X_{\bar{O}}]$, with the matching split applied to $W$. Non-outlier activations are quantized and matrix-multiplied, while outliers are kept at higher precision (FP8/BF16):

$Y = Q(X_{\bar{O}})\,Q(W_{\bar{O}})^\top + X_O W_O^\top$

  • Backward: The same channel split is applied in both backward MMs; non-outlier regions are stochastically quantized, and outliers are retained at high precision. Optionally, a Random Hadamard Transform can be applied to further reduce group-level outlier variance.

OutControl preserves the accuracy of critical channels at a negligible cost due to their small proportion.
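The channel selection and split matmul can be sketched as follows; the norm choice ($\ell_2$ over a calibration batch) and function names are assumptions:

```python
import numpy as np

def select_outlier_channels(calib_x, frac=0.05):
    """Rank channels by L2 norm over a calibration batch; keep the top `frac`."""
    norms = np.linalg.norm(calib_x, axis=0)
    k = max(1, int(frac * calib_x.shape[1]))
    return np.argsort(norms)[-k:]

def outcontrol_matmul(x, w, outliers, quant_fn):
    """Y = Q(X_nonout) Q(W_nonout)^T + X_out W_out^T (outliers high precision)."""
    mask = np.zeros(x.shape[1], dtype=bool)
    mask[outliers] = True
    y_quant = quant_fn(x[:, ~mask]) @ quant_fn(w[:, ~mask]).T  # FP4 path
    y_hp = x[:, mask] @ w[:, mask].T                           # FP8/BF16 path
    return y_quant + y_hp
```

Because the outlier set is small and fixed at calibration time, the high-precision path adds little compute and no dynamic bookkeeping.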

5. Empirical Evaluation

TetraJet-v2 was evaluated on the OLMo-2 LLM family with 70M, 150M, and 370M non-embedding parameters, using the OLMo-2-Mix-1124 dataset (C4 and The Pile), up to 212B tokens (batch size 1024, sequence length 4096). AdamW optimizer and cosine decay with linear LR warm-up were employed. All activations, weights, and gradients in every linear layer are quantized to NVFP4.

Baselines included Quartet (MXFP4) and the NVIDIA NVFP4 recipe (which retains partial BF16 precision).

| Method | Train PPL (70M/150M/370M) | Val PPL (70M/150M/370M) |
|---|---|---|
| BF16 (full precision) | 35.95 / 26.38 / 18.70 | 45.27 / 33.49 / 23.70 |
| Quartet (MXFP4) | 40.77 / 29.25 / 20.76 | 51.23 / 36.89 / 26.16 |
| NVIDIA recipe | 40.50 / 29.18 / 20.75 | 50.94 / 36.73 / 26.20 |
| TetraJet-v2-base | 39.26 / 28.39 / 20.23 | 49.33 / 35.88 / 25.50 |
| TetraJet-v2-full | 38.08 / 27.58 / 19.89 | 47.75 / 34.95 / 25.11 |

On downstream zero-shot tasks (370M, 200B tokens), TetraJet-v2 yields average accuracy 43.60, WikiText-103 PPL 18.06, and Pile PPL 12.81, closing 51.3% of the FP4→BF16 performance gap relative to prior state of the art.

6. Best Practices for Implementation

  • Quantize all three matrix multiplications in each linear layer, using 16-element block grouping along rows and columns as appropriate for each operand, to match NVIDIA hardware.
  • Apply deterministic rounding in the forward path (activations, weights); use stochastic rounding for all gradients in backward passes.
  • Limit the application of the Random Hadamard Transform to the backward matrix multiplies (the input-gradient and weight-gradient MMs) to reduce group-level outlier effects.
  • Deploy OsciReset after substantial learning rate decay, accumulating oscillation statistics over ~50 steps and resetting every ~200 steps, with an appropriately tuned oscillation threshold.
  • Predetermine the outlier channel set (5–10%) at initialization, retaining only these channels in FP8/BF16 for both forward and backward computation.
  • Integrate TetraJet-v2 as drop-in wrappers for standard linear modules in PyTorch or TensorFlow.
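A framework-agnostic sketch of the drop-in-wrapper idea is shown below; a real integration would subclass `torch.nn.Linear` and quantize the operands of all three MMs, whereas this illustration (names assumed) covers only the forward MM:

```python
import numpy as np

class FakeQuantLinear:
    """Illustrative drop-in wrapper for a linear layer.

    quant_fn stands in for double-block NVFP4 quantization; it defaults
    to the identity so the wrapper is a no-op until a quantizer is supplied.
    """
    def __init__(self, weight, bias=None, quant_fn=lambda t: t):
        self.weight = weight          # shape (out_features, in_features)
        self.bias = bias
        self.quant = quant_fn

    def __call__(self, x):
        # Forward MM: Y = Q(X) Q(W)^T (+ bias, kept in high precision)
        y = self.quant(x) @ self.quant(self.weight).T
        return y if self.bias is None else y + self.bias
```

Wrapping each linear module this way keeps the rest of the training loop unchanged, which is what makes the method practical to adopt.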

7. Synthesis and Significance

TetraJet-v2 integrates unbiased double-block NVFP4 quantization, weight oscillation suppression, and selective high-precision outlier retention to enable nearly lossless 4-bit pretraining for LLMs across all linear layers. This approach halves the performance gap between full-precision and prior FP4 methods, establishing a new baseline for scalable, efficient, low-precision LLM training (Chen et al., 31 Oct 2025).

References (1)
