TetraJet-v2: Efficient 4-Bit LLM Training
- TetraJet-v2 is an end-to-end fully-quantized training method for LLMs that leverages the NVFP4 4-bit format to quantize activations, weights, and gradients.
- It introduces an unbiased double-block quantization scheme alongside OsciReset and OutControl to mitigate weight oscillation and manage structural outliers.
- Empirical evaluations show that TetraJet-v2 substantially closes the performance gap with full precision while outperforming previous FP4 methods.
TetraJet-v2 is an end-to-end fully-quantized training (FQT) method for LLMs that achieves near-lossless accuracy using the NVFP4 4-bit floating-point format for activations, weights, and gradients across all linear layers. The method addresses key obstacles in low-precision training—specifically, weight oscillation and structural outliers—by introducing an unbiased double-block quantization scheme, an oscillation suppression mechanism (OsciReset), and an outlier retention strategy (OutControl) (Chen et al., 31 Oct 2025).
1. NVFP4 Numerical Format and Double-Block Quantization
The NVFP4 format encodes each value in four bits using the E2M1 floating-point scheme: 1 sign bit, 2 exponent bits, and 1 mantissa bit, with the set of representable values $\{0, \pm 0.5, \pm 1, \pm 1.5, \pm 2, \pm 3, \pm 4, \pm 6\}$. Data are grouped into inner blocks of 16 elements, each associated with an E4M3 (8-bit) scaling factor whose magnitude spans roughly $[2^{-9}, 448]$. Encoding operates as follows:
- Given an outer block of 128 elements, compute a global scale $S = \max_i |x_i| / (6 \times 448)$, stored in higher precision.
- For each inner sub-block (elements $16k$ to $16k+15$), compute a local scale $s_k = \max_j |x_j| / (6S)$, cast to E4M3.
- Each element is quantized as $q_j = \mathcal{Q}_{\mathrm{E2M1}}\!\left(x_j / (S\, s_k)\right)$.
- Decoding reconstructs $\hat{x}_j = q_j \, s_k \, S$.
The rounding strategy is deterministic round-to-nearest (RTN) in the forward path for activations and weights; for gradients in the backward pass, stochastic rounding is used, ensuring unbiasedness: $\mathbb{E}[\mathcal{Q}_{\mathrm{SR}}(x)] = x$.
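To make the double-block encode/decode and the two rounding modes concrete, here is a minimal PyTorch sketch. The function name, the zero-scale clamping, and the simplified E4M3 handling (only the scale's dynamic range is emulated, not its mantissa quantization) are illustrative assumptions, not the paper's implementation:

```python
import torch

# E2M1 magnitude grid of NVFP4 (sign handled separately).
E2M1_LEVELS = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def nvfp4_double_block_quant(x, outer=128, inner=16, stochastic=False):
    """Simulated double-block NVFP4 quantization of a 1-D float tensor whose
    length is a multiple of `outer`; returns the dequantized values."""
    x = x.float().reshape(-1, outer)                          # outer blocks of 128
    S = (x.abs().amax(dim=1, keepdim=True) / (6.0 * 448.0)).clamp(min=1e-12)

    xb = x.reshape(x.shape[0], outer // inner, inner)         # inner blocks of 16
    s = xb.abs().amax(dim=2, keepdim=True) / (6.0 * S.unsqueeze(2))
    s = s.clamp(min=2.0 ** -9, max=448.0)                     # emulate the E4M3 range only

    v = xb / (s * S.unsqueeze(2))                             # value to map onto the E2M1 grid
    sign, mag = v.sign(), v.abs().clamp(max=6.0)

    # Neighbouring E2M1 levels below / above each magnitude.
    hi_idx = torch.searchsorted(E2M1_LEVELS, mag.flatten()).clamp(max=7)
    lo_idx = (hi_idx - 1).clamp(min=0)
    lo, hi = E2M1_LEVELS[lo_idx], E2M1_LEVELS[hi_idx]
    m = mag.flatten()

    if stochastic:
        # Stochastic rounding: round up with probability (m - lo) / (hi - lo),
        # so the quantized value equals m in expectation.
        p = torch.where(hi > lo, (m - lo) / (hi - lo), torch.zeros_like(m))
        q = torch.where(torch.rand_like(p) < p, hi, lo)
    else:
        # Deterministic round-to-nearest.
        q = torch.where(m - lo <= hi - m, lo, hi)

    q = sign * q.reshape(mag.shape)
    return (q * s * S.unsqueeze(2)).reshape(-1)               # dequantize
```

Deterministic rounding (`stochastic=False`) corresponds to the forward path; `stochastic=True` corresponds to the unbiased gradient quantization used in the backward pass.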
2. Quantization Across Linear Layers
TetraJet-v2 applies unbiased double-block NVFP4 quantization to every operand in the three matrix multiplications (MMs) per Transformer linear layer:
- Forward: $Y = \mathcal{Q}(X)\,\mathcal{Q}(W)^{\top}$, with both the activation $X$ and the weight $W$ deterministically quantized (denoted $\mathcal{Q}$).
- Backward w.r.t. the input $X$: $\nabla X = \mathcal{Q}_{\mathrm{SR}}(\nabla Y)\,\mathcal{Q}_{\mathrm{SR}}(W)$, using stochastically quantized gradients and weights (denoted $\mathcal{Q}_{\mathrm{SR}}$).
- Backward w.r.t. the weight $W$: $\nabla W = \mathcal{Q}_{\mathrm{SR}}(\nabla Y)^{\top}\,\mathcal{Q}_{\mathrm{SR}}(X)$, with both $\nabla Y$ and $X$ stochastically quantized, reusing quantities already computed in the forward pass where possible.
Stochastic quantization in the backward path guarantees that gradient estimates remain unbiased in expectation, ensuring stability of stochastic gradient descent.
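A minimal autograd sketch of this pattern follows, reusing the `nvfp4_double_block_quant` helper from the previous snippet; `fake_quant`, `FP4Linear`, the 2-D input assumption, and the independent re-quantization of each backward operand are illustrative choices, not the paper's fused kernels:

```python
import torch

def fake_quant(t, stochastic):
    """Double-block NVFP4 fake-quantization of an arbitrary-shaped tensor
    (hypothetical helper built on nvfp4_double_block_quant above)."""
    flat = t.reshape(-1)
    pad = (-flat.numel()) % 128                     # pad to a full outer block
    if pad:
        flat = torch.cat([flat, flat.new_zeros(pad)])
    q = nvfp4_double_block_quant(flat, stochastic=stochastic)
    return q[: t.numel()].reshape(t.shape).to(t.dtype)

class FP4Linear(torch.autograd.Function):
    """Y = Q(X) Q(W)^T with RTN in the forward pass and stochastic rounding
    for every backward-pass quantization; expects 2-D x of shape (N, in)."""

    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return fake_quant(x, stochastic=False) @ fake_quant(w, stochastic=False).t()

    @staticmethod
    def backward(ctx, grad_y):
        x, w = ctx.saved_tensors
        # dX = Q_SR(dY) Q_SR(W)
        grad_x = fake_quant(grad_y, stochastic=True) @ fake_quant(w, stochastic=True)
        # dW = Q_SR(dY)^T Q_SR(X); each operand is re-quantized per MM in this sketch.
        grad_w = fake_quant(grad_y, stochastic=True).t() @ fake_quant(x, stochastic=True)
        return grad_x, grad_w
```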
3. Oscillation Detection and OsciReset
Weight oscillation arises when quantized weights “hop” excessively between bins despite negligible true weight movement, impairing convergence. TetraJet-v2 defines:
- the master-weight movement over a window of $T$ steps, $d^{\text{master}} = |w_t - w_{t-T}|$ (per element);
- the accumulated movement of the quantized weight over the same window, $d^{\text{quant}} = \sum_{i=t-T+1}^{t} |\mathcal{Q}(w_i) - \mathcal{Q}(w_{i-1})|$;
- the oscillation ratio $r = d^{\text{quant}} / d^{\text{master}}$.

Whenever $r$ exceeds a threshold, the master weight is reset to the nearest representable quantized value, eliminating accumulated drift. OsciReset operates periodically (e.g., every 200 steps, after accumulating statistics over ~50 steps) and is activated only in the later portion of training, once the learning rate has decayed below a threshold.
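A schematic tracker for the oscillation statistics and the reset, again using the `fake_quant` helper from above, is sketched below; the window length, period, and default threshold value are illustrative assumptions:

```python
import torch

class OsciResetTracker:
    """Accumulates master- vs. quantized-weight movement over a short window
    and snaps oscillating master weights to their quantized values."""

    def __init__(self, weight, window=50, period=200, threshold=2.0):
        self.window, self.period, self.threshold = window, period, threshold
        self.master_start = weight.detach().clone()
        self.prev_quant = fake_quant(weight.detach(), stochastic=False)
        self.quant_movement = torch.zeros_like(weight)
        self.collecting = False
        self.step = 0

    @torch.no_grad()
    def update(self, weight):
        """Call once per optimizer step (only after the LR has decayed)."""
        self.step += 1
        phase = self.step % self.period

        # Open a fresh statistics window `window` steps before each reset point.
        if phase == self.period - self.window:
            self.master_start.copy_(weight)
            self.prev_quant = fake_quant(weight, stochastic=False)
            self.quant_movement.zero_()
            self.collecting = True

        if self.collecting:
            q = fake_quant(weight, stochastic=False)
            self.quant_movement += (q - self.prev_quant).abs()   # accumulated bin hops
            self.prev_quant = q

        if phase == 0 and self.collecting:
            master_movement = (weight - self.master_start).abs()
            ratio = self.quant_movement / (master_movement + 1e-12)
            oscillating = ratio > self.threshold
            weight[oscillating] = self.prev_quant[oscillating]   # reset oscillating weights
            self.collecting = False
```

In a training loop, one tracker per quantized weight tensor would be created and `tracker.update(module.weight.data)` called after each optimizer step once the learning-rate condition is met.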
4. Outlier Channel Handling Via OutControl
A critical insight in TetraJet-v2 is the presence of persistent structural outliers: 5–10% of activation channels exhibiting anomalously large magnitudes. OutControl statically identifies outlier channels by ranking the channelwise norms over a small calibration set and selecting the top 5–10% as the outlier set $\mathcal{O}$.
- Forward: The activation matrix $X$ is split by channel into outlier and non-outlier parts, $X = [X_{\mathcal{O}},\, X_{\bar{\mathcal{O}}}]$ (and $W$ correspondingly). Non-outlier activations are quantized and matrix-multiplied, while outliers are kept at higher precision (FP8/BF16):
  $$Y = \mathcal{Q}(X_{\bar{\mathcal{O}}})\,\mathcal{Q}(W_{\bar{\mathcal{O}}})^{\top} + X_{\mathcal{O}}\,W_{\mathcal{O}}^{\top}$$
- Backward: The same channel split is applied to the gradient and weight operands; non-outlier regions are stochastically quantized, and outliers are retained at high precision. Optionally, a Random Hadamard Transform can be applied to further reduce group-level outlier variance.
OutControl preserves the accuracy of critical channels at a negligible cost due to their small proportion.
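A sketch of the calibration-time channel selection and the split forward matmul follows; the norm used for ranking, the default fraction, and keeping the outlier path in the input dtype rather than a true FP8 path are illustrative assumptions:

```python
import torch

def select_outlier_channels(calib_acts, frac=0.05):
    """Rank channels by their norm over a calibration set and return the
    indices of the top `frac` fraction as the fixed outlier set."""
    norms = calib_acts.float().norm(dim=0)            # (hidden,) channelwise norms
    k = max(1, int(frac * norms.numel()))
    return torch.topk(norms, k).indices

def outcontrol_forward(x, w, outlier_idx):
    """Y = Q(X_non) Q(W_non)^T + X_out W_out^T for 2-D x (N, in) and w (out, in);
    the outlier channels bypass quantization entirely."""
    mask = torch.zeros(x.shape[-1], dtype=torch.bool, device=x.device)
    mask[outlier_idx] = True

    y_non = fake_quant(x[:, ~mask], stochastic=False) @ fake_quant(w[:, ~mask], stochastic=False).t()
    y_out = x[:, mask] @ w[:, mask].t()               # high-precision (FP8/BF16) path
    return y_non + y_out
```

The outlier index set would be computed once from a small calibration batch at initialization and then frozen for the rest of training.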
5. Empirical Evaluation
TetraJet-v2 was evaluated on the OLMo-2 LLM family with 70M, 150M, and 370M non-embedding parameters, using the OLMo-2-Mix-1124 dataset (C4 and The Pile), up to 212B tokens (batch size 1024, sequence length 4096). AdamW optimizer and cosine decay with linear LR warm-up were employed. All activations, weights, and gradients in every linear layer are quantized to NVFP4.
Baselines included Quartet (MXFP4) and the NVIDIA NVFP4 recipe (which retains partial BF16 precision).
| Method | Train PPL (70M/150M/370M) | Val PPL (70M/150M/370M) |
|---|---|---|
| BF16 (full prec) | 35.95 / 26.38 / 18.70 | 45.27 / 33.49 / 23.70 |
| Quartet (MXFP4) | 40.77 / 29.25 / 20.76 | 51.23 / 36.89 / 26.16 |
| NVIDIA recipe | 40.50 / 29.18 / 20.75 | 50.94 / 36.73 / 26.20 |
| TetraJet-v2-base | 39.26 / 28.39 / 20.23 | 49.33 / 35.88 / 25.50 |
| TetraJet-v2-full | 38.08 / 27.58 / 19.89 | 47.75 / 34.95 / 25.11 |
On downstream zero-shot tasks (370M, 200B tokens), TetraJet-v2 yields average accuracy 43.60, WikiText-103 PPL 18.06, and Pile PPL 12.81, closing 51.3% of the FP4→BF16 performance gap relative to prior state of the art.
6. Best Practices for Implementation
- Quantize all three matrix multiplications in each linear layer, grouping each operand into 16-element blocks along the appropriate row or column dimension so the layout matches NVIDIA hardware.
- Apply deterministic rounding in the forward path (activations, weights); use stochastic rounding for all gradients in backward passes.
- Limit the application of the Random Hadamard Transform to the two backward matrix multiplies (the gradients w.r.t. activations and weights) to reduce group-level outlier effects.
- Deploy OsciReset only after substantial learning rate decay, accumulating oscillation statistics over roughly 50 steps and resetting roughly every 200 steps, with an appropriately tuned oscillation threshold.
- Predetermine the outlier channel set (5–10%) at initialization, retaining only these channels in FP8/BF16 for both forward and backward computation.
- Integrate TetraJet-v2 as drop-in wrappers for standard linear modules in PyTorch or TensorFlow.
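A hypothetical drop-in conversion pattern in PyTorch, combining the sketches above (`TetraJetLinear`, `convert_linears`, and the reliance on the `FP4Linear` function are illustrative, not the released API):

```python
import torch
import torch.nn as nn

class TetraJetLinear(nn.Module):
    """nn.Linear replacement routing the matmul through the fake-quantized
    FP4Linear autograd function sketched earlier."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.weight = linear.weight          # master weights stay in BF16/FP32
        self.bias = linear.bias

    def forward(self, x):
        shape = x.shape
        y = FP4Linear.apply(x.reshape(-1, shape[-1]), self.weight)
        y = y.reshape(*shape[:-1], y.shape[-1])
        return y if self.bias is None else y + self.bias

def convert_linears(model: nn.Module) -> nn.Module:
    """Recursively swap every nn.Linear for a TetraJetLinear wrapper."""
    for name, child in model.named_children():
        if isinstance(child, nn.Linear):
            setattr(model, name, TetraJetLinear(child))
        else:
            convert_linears(child)
    return model
```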
7. Synthesis and Significance
TetraJet-v2 integrates unbiased double-block NVFP4 quantization, weight oscillation suppression, and selective high-precision outlier retention to enable nearly lossless 4-bit pretraining for LLMs across all linear layers. The approach closes roughly half of the remaining gap between prior FP4 methods and full precision, establishing a new baseline for scalable, efficient, low-precision LLM training (Chen et al., 31 Oct 2025).