Dual Learnable Ternarization for LLMs
- Dual Learnable Ternarization (DLT) is a quantization method that maps LLM weights to ternary values using learnable scale and shift parameters to address asymmetry and non-zero means.
- By partitioning weight groups and adapting thresholds based on the mean absolute value, DLT reduces clamp and rounding errors, thereby improving model compression and performance.
- Empirical evidence shows that DLT, especially when combined with Outlier-Friendly Feature Knowledge Distillation, lowers perplexity and increases accuracy on NLP benchmarks.
Dual Learnable Ternarization (DLT) is a quantization technique developed for LLMs, enabling extreme weight compression by mapping weights to ternary values while adaptively correcting both magnitude and mean per quantization group. Motivated by the presence of asymmetric outliers and non-zero means in LLM weights, DLT builds upon classic ternary weight networks by introducing a learnable scale and shift per group, achieving stronger alignment with real-world LLM weight distributions. Empirical evidence demonstrates that DLT, especially when combined with Outlier-Friendly Feature Knowledge Distillation (OFF), improves performance across standard NLP benchmarks relative to previous low-bit quantization methods (Chen et al., 11 Jun 2024).
1. Mathematical Formulation
DLT operates on partitioned groups of floating-point weights $W_g = \{w_1, \dots, w_n\}$, typically corresponding to rows, per-channel blocks, or other groupings in LLMs. For each group:
- The threshold is computed from the mean absolute value (following the TWN convention):
$$\Delta = 0.7 \cdot \frac{1}{n} \sum_{i=1}^{n} |w_i|$$
- Ternary codes are assigned:
$$t_i = \begin{cases} +1, & w_i > \Delta \\ 0, & |w_i| \le \Delta \\ -1, & w_i < -\Delta \end{cases}$$
- Each group is then equipped with two trainable parameters:
  - Scale $\alpha$
  - Shift $\beta$
- Quantized weights are given by:
$$\hat{w}_i = \alpha \, t_i + \beta$$
This dual-parameter scheme enables the quantized ternary palette to represent both the correct groupwise scale and the groupwise mean, directly addressing the non-zero-mean phenomenon often observed in LLM weight groups. A minimal code sketch of the groupwise procedure is given below.
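The following PyTorch sketch illustrates the mapping described above under stated assumptions: the 0.7 threshold constant follows the TWN convention, and the function name `dlt_quantize` and the zero initialization of the shift are illustrative choices rather than details taken verbatim from the source.

```python
import torch

def dlt_quantize(w_group: torch.Tensor, alpha: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
    """Ternarize one weight group with a learnable scale (alpha) and shift (beta)."""
    delta = 0.7 * w_group.abs().mean()      # groupwise threshold from the mean |w| (TWN convention)
    t = torch.zeros_like(w_group)
    t[w_group > delta] = 1.0                # large positive weights -> +1
    t[w_group < -delta] = -1.0              # large negative weights -> -1
    return alpha * t + beta                 # reconstructed ("dequantized") weights

# Example: an asymmetric, non-zero-mean weight group
w = torch.randn(128) * 0.02 + 0.01
delta0 = 0.7 * w.abs().mean()
alpha0 = w[w.abs() > delta0].abs().mean()   # TWN-style closed-form initialization of the scale
beta0 = torch.tensor(0.0)                   # shift initialized at zero (assumed)
w_hat = dlt_quantize(w, alpha0, beta0)
```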
2. Training Objectives and Parameter Learning
DLT integrates quantization into the training process via quantization-aware fine-tuning, optimizing $\alpha$ and $\beta$ along with the base model parameters. The total loss function is composed of three terms:
- Label loss ($\mathcal{L}_{\text{label}}$): cross-entropy between the student logits and the one-hot labels $y$.
- Logits distillation ($\mathcal{L}_{\text{logits}}$): cross-entropy between the student and full-precision teacher logits, downscaled by a weighting coefficient $\lambda_1$.
- Outlier-Friendly Feature Distillation ($\mathcal{L}_{\text{OFF}}$): the negated sum of pairwise cosine similarities between corresponding student and teacher hidden states at each layer $l$ and token position $t$,
$$\mathcal{L}_{\text{OFF}} = -\sum_{l}\sum_{t} \cos\!\big(h^{S}_{l,t},\, h^{T}_{l,t}\big)$$
This term, weighted by a coefficient $\lambda_2$, is designed to be insensitive to outliers, as substantiated by Theorem 1 in the source.
The combined objective is
$$\mathcal{L} = \mathcal{L}_{\text{label}} + \lambda_1\,\mathcal{L}_{\text{logits}} + \lambda_2\,\mathcal{L}_{\text{OFF}}.$$
The scale $\alpha$ is initialized by a TWN-style closed-form solution,
$$\alpha_0 = \frac{1}{|I_\Delta|}\sum_{i \in I_\Delta} |w_i|, \qquad I_\Delta = \{\, i : |w_i| > \Delta \,\},$$
and the shift is initialized at zero, matching the shift-free TWN baseline. Both parameters are trained using AdamW with zero weight decay and a learning rate set separately from that of the main network weights.
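A minimal sketch of the three-term objective, assuming PyTorch; the coefficients `lam1`/`lam2` and the soft-target cross-entropy form of the logits distillation are illustrative placeholders, not the paper's exact hyperparameters:

```python
import torch
import torch.nn.functional as F

def off_loss(student_feats, teacher_feats):
    """Negated sum of cosine similarities over layers and token positions (OFF term)."""
    loss = 0.0
    for h_s, h_t in zip(student_feats, teacher_feats):        # per layer: (tokens, hidden_dim)
        loss = loss - F.cosine_similarity(h_s, h_t, dim=-1).sum()
    return loss

def total_loss(student_logits, teacher_logits, labels,
               student_feats, teacher_feats, lam1=1.0, lam2=1.0):
    l_label = F.cross_entropy(student_logits, labels)                       # CE against one-hot labels
    l_logits = F.cross_entropy(student_logits, teacher_logits.softmax(-1))  # CE against teacher soft targets
    l_off = off_loss(student_feats, teacher_feats)                          # outlier-friendly feature term
    return l_label + lam1 * l_logits + lam2 * l_off
```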
3. Addressing Asymmetry and Non-Zero Means
Conventional ternarization, such as the Ternary Weight Network (TWN), employs a single scale factor and assigns zero to elements whose absolute value falls beneath a fixed threshold, implicitly assuming symmetric, zero-mean weight distributions within groups. In contrast, LLM weight groups often exhibit asymmetric outliers and non-zero means. DLT keeps the groupwise threshold proportional to the mean absolute value ($\Delta \propto \mathbb{E}[|w|]$), so that only small-magnitude weights are set to zero. By learning both $\alpha$ (scale) and $\beta$ (shift), DLT dynamically corrects the groupwise magnitude and offset, substantially reducing clamp error (for the zero bin) and rounding error (for the $\pm 1$ tails), and thus mitigating the bias introduced by asymmetric distributions.
A plausible implication is that DLT generalizes more robustly to the heterogeneity found in modern LLM architectures compared to strictly symmetric ternarization schemes.
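A toy experiment, entirely illustrative, makes the point concrete: on a deliberately non-zero-mean weight group, a scale-plus-shift reconstruction (here fit in closed form by least squares, which differs from the paper's gradient-based learning of $\alpha$ and $\beta$) achieves lower mean-squared error than the symmetric TWN scale alone.

```python
import torch

torch.manual_seed(0)
w = torch.randn(1024) * 0.02 + 0.015            # asymmetric group: mean shifted above zero
delta = 0.7 * w.abs().mean()
t = torch.sign(w) * (w.abs() > delta)           # ternary codes in {-1, 0, +1}

alpha_twn = w[w.abs() > delta].abs().mean()
err_twn = (w - alpha_twn * t).pow(2).mean()     # TWN: scale only, zero-mean assumption

# Scale + shift fit by least squares for this group (illustrative only)
A = torch.stack([t, torch.ones_like(t)], dim=1)
sol = torch.linalg.lstsq(A, w.unsqueeze(1)).solution.squeeze()
err_dlt = (w - (sol[0] * t + sol[1])).pow(2).mean()
print(f"TWN MSE {err_twn:.6f}  vs  scale+shift MSE {err_dlt:.6f}")
```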
4. Gradient Propagation and Optimization
Gradient computation for the DLT parameters follows from the chain rule applied to the quantized representation $\hat{w}_i = \alpha\, t_i + \beta$:
$$\frac{\partial \mathcal{L}}{\partial \alpha} = \sum_i \frac{\partial \mathcal{L}}{\partial \hat{w}_i}\, t_i, \qquad \frac{\partial \mathcal{L}}{\partial \beta} = \sum_i \frac{\partial \mathcal{L}}{\partial \hat{w}_i}.$$
Gradients through the non-differentiable ternary assignment are propagated using the Straight-Through Estimator (STE), which treats the assignment as the identity in the backward pass ($\partial t_i / \partial w_i \approx 1$), giving
$$\frac{\partial \mathcal{L}}{\partial w_i} \approx \alpha\, \frac{\partial \mathcal{L}}{\partial \hat{w}_i}.$$
All parameters, including $W$, $\alpha$, and $\beta$, are updated via AdamW using the specified learning rates.
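The sketch below shows one common way to realize this gradient flow in PyTorch via the detach trick; the module name `DLTLinear`, the per-row grouping, and the 0.7 threshold constant are assumptions for illustration.

```python
import torch
import torch.nn as nn

class DLTLinear(nn.Module):
    def __init__(self, out_features, in_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
        self.alpha = nn.Parameter(torch.ones(out_features, 1))    # per-row (group) scale
        self.beta = nn.Parameter(torch.zeros(out_features, 1))    # per-row (group) shift

    def forward(self, x):
        delta = 0.7 * self.weight.abs().mean(dim=1, keepdim=True)
        t = torch.sign(self.weight) * (self.weight.abs() > delta)
        # STE: forward uses the ternary codes, backward routes gradients straight to self.weight;
        # alpha and beta receive exact gradients since w_hat is linear in them.
        t = self.weight + (t - self.weight).detach()
        w_hat = self.alpha * t + self.beta
        return x @ w_hat.t()
```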
5. Empirical Performance and Ablations
DLT has demonstrated substantial empirical gains over prior quantization-aware schemes. Key results include:
| Model / Metric | DLT + OFF (W1.58A16) | Prior Art (2-bit weights) | Gain |
|---|---|---|---|
| LLaMA-3, C4 PPL | 13.4 | 19.2 (DB-LLM, W2A16) | -5.8 (lower is better) |
| LLaMA-3, avg. zero-shot acc. | 60.0% | 51.8% (DB-LLM, W2A16) | +8.2% (absolute) |
| OPT-1.3B, C4 PPL | 18.01 | 27.34 (AWQ, 2-bit) | -9.33 |
| OPT-1.3B, C4 PPL | 18.01 | 31.31 (GPTQ, 2-bit) | -13.30 |
When replacing TWN with DLT in ablations:
- On OPT-1.3B, PPL is reduced from 22.32 to 20.83 (−1.49)
- On LLaMA-1-7B, PPL decreases from 10.10 to 9.21 (−0.89)
Further reduction is observed when combining DLT with OFF, which yields an additional 0.77 PPL reduction relative to logits-only distillation.
6. Integration with Outlier-Friendly Feature Knowledge Distillation
DLT is synergistically paired with OFF, which leverages cosine similarity between hidden-state features of student (ternarized) and teacher (full-precision) models. This approach is robust to outliers, allowing semantic and distributional information to be transferred despite aggressive ternarization. The addition of OFF to DLT further improves both perplexity and downstream task accuracy.
In summary, Dual Learnable Ternarization introduces a minimal but powerful extension of groupwise ternary quantization by learning both scale and shift per group, directly addressing real-world LLM weight asymmetries. This results in lower quantization error, superior empirical performance across LLM families, and compatibility with quantization-aware fine-tuning protocols utilizing STE and knowledge distillation (Chen et al., 11 Jun 2024).