TorchAO & Brevitas: Quantization & Sparsity
- TorchAO and Brevitas are quantization and sparsity frameworks that optimize large language models through low-precision and sparse computations in PyTorch.
- TorchAO offers an end-to-end workflow from pre-training to serving with advanced features like FP8 training and integrated sparsity, while Brevitas focuses on early quantization-aware training.
- Both frameworks employ specialized techniques to boost inference speed and reduce memory load, achieving significant performance gains with minimal accuracy loss.
TorchAO and Brevitas are quantization and sparsity frameworks for deep learning model optimization, focused on compressing, accelerating, and enabling efficient deployment of LLMs and similar architectures. TorchAO is a PyTorch-native, end-to-end solution covering pre-training through serving, tightly integrated with PyTorch APIs and ecosystem tools, offering innovations in low-precision tensor representations, sparsity, and workflow orchestration. Brevitas, in contrast, was an early quantization-aware training solution for PyTorch, supporting INT8/INT4 workflows but lacking many of TorchAO’s extensibility and workflow features. This article presents a rigorous account of TorchAO’s architecture, quantization and sparsity techniques, performance characteristics, and an in-depth comparison to Brevitas, based exclusively on reported details in (Or et al., 21 Jul 2025).
1. Architectural Overview and Workflow Integration
TorchAO implements a native, unified training-to-serving pipeline for LLM optimization, integrating quantization and sparsity at every model lifecycle stage (Or et al., 21 Jul 2025). Its architecture organizes optimization as modular steps for pre-training, fine-tuning, and inference:
- Pre-Training: FP8 low-precision training is enabled via calls to torch.compile(model) and convert_to_float8_training(model). TorchTitan manages distributed training across FSDP2, tensor parallelism, and asynchronous tensor parallel backends, leveraging FP8 GEMM kernels on supported hardware.
- Fine-Tuning: Quantization-Aware Training (QAT) is instrumented by calling quantize_(model, IntXQuantizationAwareTrainingConfig(...)) to insert fake quantization operations. After training, quantized ops are materialized for deployment using quantize_(model, FromIntXQuantizationAwareTrainingConfig()). Tools such as TorchTune and Axolotl are integrated at this stage.
- Serving: Post-Training Quantization (PTQ) is performed by quantize_(model, <PTQConfig>). TorchAO supports server-side backends (HuggingFace Transformers via TorchAoConfig, vLLM, SGLang with GemLite Triton kernels) and edge deployment through ExecuTorch exports to XNNPACK and custom sub-8-bit ARM/Metal kernels.
The model therefore moves through three stages in order: pre-training, fine-tuning, and serving.
One-line transformations such as quantize_(model, <config>) facilitate in-place quantization and backend transition. TorchAO thus unifies an otherwise fragmented space across model development, optimization, and deployment (Or et al., 21 Jul 2025).
2. Quantization Techniques and Supported Data Types
TorchAO supports both QAT and PTQ, including training and inference in novel low-precision formats beyond common INT8 and INT4 (Or et al., 21 Jul 2025).
Quantization Mapping: For a tensor element $x$, quantization to lower precision is given by
$q(x) = \mathrm{clip}\bigl(\mathrm{round}(x/s),\, q_{\min},\, q_{\max}\bigr)\times s$
where $s$ is a learned or computed scale and $q_{\min}$, $q_{\max}$ are format-specific bounds (e.g., for signed integers, INT4: $[-8, 7]$; INT8: $[-128, 127]$). FP8 in the common E4M3 layout uses a sign bit, a 4-bit exponent, and a 3-bit mantissa.
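The mapping above can be sketched in plain Python. This is a minimal illustration, assuming a symmetric scale derived from the largest magnitude ($s = \max|x| / q_{\max}$), which the source does not specify. The helper names `fake_quantize` and `symmetric_scale` are illustrative and not part of TorchAO.

```python
def symmetric_scale(values, q_max):
    """Assumed scale rule: map the largest magnitude onto q_max."""
    return max(abs(v) for v in values) / q_max


def fake_quantize(x, scale, q_min, q_max):
    """q(x) = clip(round(x / s), q_min, q_max) * s."""
    q = round(x / scale)              # Python rounds exact ties to the nearest even integer
    q = max(q_min, min(q_max, q))     # clip to the representable integer range
    return q * scale                  # map back to the real-valued grid


# Signed INT4 range [-8, 7]; values chosen so the scale is exactly 0.25.
weights = [0.9, -0.35, 0.05, -1.75]
s = symmetric_scale(weights, q_max=7)
quantized = [fake_quantize(w, s, -8, 7) for w in weights]
print(s, quantized)   # 0.25 [1.0, -0.25, 0.0, -1.75]

# A value outside the representable range saturates at q_max * s.
print(fake_quantize(3.0, s, -8, 7))   # 1.75
```

The output stays in floating point, which is why this form is usually called fake quantization: the values are snapped to the low-precision grid but not stored as integers.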
Supported Data Types:
- Integer: INT4, INT8
- Floating: FP8 (rowwise, tensorwise), MXFP4, MXFP6, MXFP8 (prototypes)
- Mixed-precision: NF4 (for QLoRA)
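To make the FP8 entry concrete, the sketch below decodes a single FP8 byte. It assumes the E4M3 "fn" variant (the layout behind `torch.float8_e4m3fn`): 1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, no infinities, and a single NaN pattern per sign. The source does not name the variant, and `decode_e4m3fn` is a hypothetical helper, not a TorchAO function.

```python
import math


def decode_e4m3fn(byte):
    """Decode one FP8 E4M3 (fn variant) byte into a Python float."""
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> 3) & 0xF
    mantissa = byte & 0x7
    if exponent == 0xF and mantissa == 0x7:
        return math.nan                                   # the only NaN encoding
    if exponent == 0:
        return sign * (mantissa / 8) * 2.0 ** (1 - 7)     # subnormal numbers
    return sign * (1 + mantissa / 8) * 2.0 ** (exponent - 7)


print(decode_e4m3fn(0x38))   # 1.0
print(decode_e4m3fn(0x7E))   # 448.0, the largest finite value
print(decode_e4m3fn(0x01))   # 0.001953125, the smallest positive subnormal
```

The narrow range (a maximum of 448) is the reason FP8 workflows pair the format with tensorwise or rowwise scale factors.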
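PTQ and QAT APIs: PTQ is performed by calls such as quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32)). QAT workflows use the two quantize_ calls described in Section 1: IntXQuantizationAwareTrainingConfig(...) inserts fake quantization before fine-tuning, and FromIntXQuantizationAwareTrainingConfig() converts the model to real quantized ops afterwards.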
Group size, per-tensor and per-channel scaling, and flexible hardware-aligned quantization schemas are supported.
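The effect of the group size can be illustrated with a small plain-Python sketch. It assumes symmetric signed INT4 with one scale per contiguous group, computed from the group's largest magnitude; the function name `groupwise_fake_quantize` and the toy group size of 4 are illustrative choices, not TorchAO defaults.

```python
def groupwise_fake_quantize(weights, group_size, q_min=-8, q_max=7):
    """Fake-quantize a flat weight list with one scale per group."""
    if len(weights) % group_size != 0:
        raise ValueError("length must be a multiple of group_size")
    dequantized, scales = [], []
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        max_abs = max(abs(w) for w in group)
        scale = max_abs / q_max if max_abs > 0 else 1.0   # avoid dividing by zero
        scales.append(scale)
        for w in group:
            q = max(q_min, min(q_max, round(w / scale)))
            dequantized.append(q * scale)
    return dequantized, scales


weights = [1.75, -0.5, 0.25, 0.1, 0.4375, -0.125, 0.0625, 0.03]
values, scales = groupwise_fake_quantize(weights, group_size=4)
print(scales)   # [0.25, 0.0625]
print(values)   # [1.75, -0.5, 0.25, 0.0, 0.4375, -0.125, 0.0625, 0.0]
```

The second group gets a four times finer grid than the first, so its small weights keep more resolution than they would under a single per-tensor scale of 0.25. Smaller groups improve accuracy at the cost of storing more scales.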
TorchAO supports QAT and PTQ for INT4/8, FP8, and experimental MXFP types, with tunable group sizes, dynamic and static quantization, and hardware-targeted quantization for efficient inference (Or et al., 21 Jul 2025).
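The difference between static and dynamic quantization can be sketched as follows. This is an illustration under assumptions: symmetric INT8, a static scale fixed from calibration data ahead of time, and a dynamic scale recomputed from each runtime input. None of the helper names come from TorchAO.

```python
def quantize_int8(xs, scale):
    """Integer codes clipped to the signed INT8 range."""
    return [max(-128, min(127, round(x / scale))) for x in xs]


def static_scale(calibration_batches):
    """One scale fixed ahead of time from calibration data."""
    return max(abs(x) for batch in calibration_batches for x in batch) / 127


def dynamic_scale(xs):
    """A fresh scale computed from the input seen at runtime."""
    return max(abs(x) for x in xs) / 127


calibration = [[0.5, -1.984375], [1.0, 0.25]]
runtime_input = [3.96875, 0.5]        # contains a value larger than anything calibrated

s_static = static_scale(calibration)
s_dynamic = dynamic_scale(runtime_input)

print(quantize_int8(runtime_input, s_static))    # [127, 32]: the large value saturates
print(quantize_int8(runtime_input, s_dynamic))   # [127, 16]: the range adapts to the input
print(127 * s_static, 127 * s_dynamic)           # 1.984375 3.96875
```

Static scales avoid the runtime reduction but clip inputs that exceed the calibrated range; dynamic scales track each input at the cost of extra work per call.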
3. Sparsity: Techniques and Implementation
TorchAO provides built-in sparsity support, primarily implementing 2:4 semi-structured sparsity, where each four-element group contains at least two zeros. Given a weight vector $w$, partition it into blocks $w^{(b)}$ of size 4; for each block require $\lVert w^{(b)} \rVert_0 \le 2$, i.e., at most two nonzero entries.
APIs include sparsify_(model, SemiSparseWeightConfig()) for 2:4, or sparsify_(model, BlockSparseWeightConfig()) for block patterns. Internally, sparse layouts are encoded via tensor subclasses that store block masks and quantized values together, allowing fusion with backend sparse GEMM kernels (notably for Tensor Core 2:4 support). A plausible implication is that these internal representations facilitate efficient, hardware-aligned execution without requiring model source code changes (Or et al., 21 Jul 2025).
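A minimal sketch of the 2:4 pattern in plain Python follows. It assumes magnitude-based selection (keep the two largest-magnitude entries per block), which the source does not state, and it stores kept values together with their in-block positions as a simplified stand-in for the compressed layout. The function names are hypothetical and do not reflect TorchAO internals.

```python
def compress_2_of_4(weights):
    """Keep the two largest-magnitude entries of every block of four.

    Returns (values, positions): per block, the kept values and their
    indices within the block (each index fits in 2 bits).
    """
    if len(weights) % 4 != 0:
        raise ValueError("length must be a multiple of 4")
    values, positions = [], []
    for start in range(0, len(weights), 4):
        block = weights[start:start + 4]
        by_magnitude = sorted(range(4), key=lambda i: abs(block[i]), reverse=True)
        keep = sorted(by_magnitude[:2])               # restore positional order
        positions.append(tuple(keep))
        values.append(tuple(block[i] for i in keep))
    return values, positions


def decompress_2_of_4(values, positions):
    """Rebuild the dense vector, with zeros in the pruned slots."""
    dense = []
    for block_values, block_positions in zip(values, positions):
        block = [0.0] * 4
        for v, i in zip(block_values, block_positions):
            block[i] = v
        dense.extend(block)
    return dense


weights = [0.1, -0.9, 0.4, 0.05, 1.0, 2.0, -3.0, 0.5]
values, positions = compress_2_of_4(weights)
print(values)      # [(-0.9, 0.4), (2.0, -3.0)]
print(positions)   # [(1, 2), (1, 2)]
print(decompress_2_of_4(values, positions))
# [0.0, -0.9, 0.4, 0.0, 0.0, 2.0, -3.0, 0.0]
```

Storing half the values plus small position metadata is what lets sparse GEMM kernels skip the pruned entries.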
4. Tensor Subclass Abstraction: QuantizedTensor and Modularity
A foundational innovation of TorchAO is the use of torch.Tensor subclassing, allowing creation of QuantizedTensor objects that encode all quantization metadata—scale, zero_point, bit_width, and layout—while supporting autograd and distributed semantics. In the forward pass, modules such as FakeQuantizedLinear wrap standard nn.Linear layers and produce QuantizedTensors as output for activations and weights.
Upon conversion for backend execution, QuantizedTensor buffers can be directly passed to CUDA, ARM, or Metal kernels, with operation dispatch managed by the __torch_dispatch__ override. This model permits hardware-agnostic quantization and sparsity, seamless interoperability with PyTorch model APIs, and avoids intrusive graph rewrites. A simplified class hierarchy is:
| Parent Class | Subclass | Key Attributes |
|---|---|---|
| torch.Tensor | QuantizedTensor | scale, zero_point, bit_width |
This abstraction supports lossless transitions between fake and real quantization during QAT and PTQ, and underpins the extensibility of TorchAO’s quantization kernel interface (Or et al., 21 Jul 2025).
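The idea of bundling integer data with its quantization metadata can be shown with a plain-Python analogy. This toy class is not a torch.Tensor subclass and has no autograd or dispatch support; it only illustrates how scale, zero_point, and bit_width travel with the stored codes. It assumes affine quantization with unsigned codes in $[0, 2^b - 1]$, and every name in it is hypothetical.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyQuantizedTensor:
    codes: tuple        # integer codes actually stored
    scale: float
    zero_point: int
    bit_width: int

    @classmethod
    def from_floats(cls, xs, scale, zero_point, bit_width):
        """Affine quantization: q = clip(round(x / s) + z, 0, 2^b - 1)."""
        q_max = 2 ** bit_width - 1
        codes = tuple(
            max(0, min(q_max, round(x / scale) + zero_point)) for x in xs
        )
        return cls(codes, scale, zero_point, bit_width)

    def dequantize(self):
        """Recover real values: x_hat = (q - z) * s."""
        return [(q - self.zero_point) * self.scale for q in self.codes]


t = ToyQuantizedTensor.from_floats(
    [-1.0, 0.0, 0.5, 10.0], scale=0.25, zero_point=8, bit_width=4
)
print(t.codes)          # (4, 8, 10, 15)
print(t.dequantize())   # [-1.0, 0.0, 0.5, 1.75]
```

Because the metadata lives on the object, any consumer can dequantize or hand the codes to a low-precision kernel without a side table. This is the property the tensor-subclass design provides inside PyTorch.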
5. Performance Characterization and Evaluation
TorchAO provides empirical results demonstrating low-precision model optimization benefits for Llama3-8B and related models (Or et al., 21 Jul 2025).
FP8 Training (Table 1):
| Scaling | Peak Mem (GB) | Median tok/s | Speedup |
|---|---|---|---|
| BF16 (baseline) | 47.65 | 6150 | 1.00× |
| tensorwise+FP8 | 47.77 | 7689 | 1.25× |
QAT on Llama3-8B (Table 2):
| Model | quant hellaswag | quant wikitext ppl | train tok/s | peak mem (GB) |
|---|---|---|---|---|
| PTQ | 47.0% | 26.27 | 480.3 | 17.6 |
| QAT | 52.8% (+69.8%) | 12.31 (+82.8%) | 323.0 (−33%) | 32.9 (+87%) |
PTQ for Inference (Table 3):
| Quant Technique | Accuracy | ppl | tok/s | size (GB) |
|---|---|---|---|---|
| None (BF16) | 60.01 | 7.33 | 132.41 | 15.01 |
| int4wo-64 | 58.10 | 8.25 | 268.88 | 4.76 |
| float8wo | 59.83 | 7.37 | 213.88 | 8.03 |
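The trade-offs in Table 3 can be recomputed directly from the reported figures. The short script below uses only the numbers in the table; the dictionary layout and the `tradeoff` helper are illustrative.

```python
# Figures copied from Table 3 (accuracy, perplexity, tokens/s, model size in GB).
rows = {
    "bf16":      {"acc": 60.01, "ppl": 7.33, "tok_s": 132.41, "size_gb": 15.01},
    "int4wo-64": {"acc": 58.10, "ppl": 8.25, "tok_s": 268.88, "size_gb": 4.76},
    "float8wo":  {"acc": 59.83, "ppl": 7.37, "tok_s": 213.88, "size_gb": 8.03},
}


def tradeoff(name, baseline="bf16"):
    """Speedup, compression, and quality change relative to the baseline row."""
    base, row = rows[baseline], rows[name]
    return {
        "speedup": round(row["tok_s"] / base["tok_s"], 2),
        "compression": round(base["size_gb"] / row["size_gb"], 2),
        "acc_drop": round(base["acc"] - row["acc"], 2),
        "ppl_increase": round(row["ppl"] - base["ppl"], 2),
    }


print(tradeoff("int4wo-64"))
# {'speedup': 2.03, 'compression': 3.15, 'acc_drop': 1.91, 'ppl_increase': 0.92}
print(tradeoff("float8wo"))
# {'speedup': 1.62, 'compression': 1.87, 'acc_drop': 0.18, 'ppl_increase': 0.04}
```

Read this way, int4wo-64 trades about two accuracy points for a 2× speedup and a 3× smaller model, while float8wo is nearly lossless with smaller gains.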
Serving on vLLM (Table 4):
| Precision | tok/s | per-token latency (ms) | inter-token latency (ms) |
|---|---|---|---|
| BF16 | 103.6 | 9.50 | 9.47 |
| FP8 tensorwise | 132.8 | 7.48 (−21%) | 7.47 (−21%) |
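These results indicate that low-precision formats yield substantial throughput and memory gains with limited accuracy loss (Or et al., 21 Jul 2025). FP8 tensorwise training raises throughput by 1.25× at essentially unchanged peak memory (Table 1). Relative to PTQ, QAT improves quantized hellaswag accuracy from 47.0% to 52.8% and wikitext perplexity from 26.27 to 12.31, but trains 33% slower and uses 87% more peak memory (Table 2). For inference, int4wo-64 roughly doubles throughput (132.41 to 268.88 tok/s) and shrinks the model from 15.01 GB to 4.76 GB at a cost of 1.91 accuracy points, while float8wo reaches 213.88 tok/s with a 0.18-point drop (Table 3). On vLLM, FP8 tensorwise lowers both latency metrics by about 21% (Table 4). Sparsity is not covered by these tables. Overall, this suggests TorchAO's techniques are viable for production LLM deployment.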
6. Comparison of TorchAO and Brevitas
The following summarizes key differentiators between TorchAO and Brevitas as reported in (Or et al., 21 Jul 2025):
| Aspect | TorchAO | Brevitas |
|---|---|---|
| API Design | 1-line in-place transforms (quantize_, sparsify_, convert_to_float8_training) | Explicit quantizer modules, graph rewrites |
| HF Transformers Integration | TorchAoConfig enables seamless save_pretrained/push_to_hub | Manual export of quantized parameters |
| Bit-Precision Flexibility | Supports FP8, custom MXFP, NF4, INT4/8; flexible groupings/scaling | Primarily INT8/INT4 (QAT/PTQ), no FP8/hardware-tied formats |
| Backend Support | CUDA, Triton (GemLite), XNNPACK, ARM, Metal | CUDA, CPU reference kernels |
| Workflow Coverage | Pre-training, fine-tuning, serving spanning entire LLM lifecycle | QAT/PTQ for model compression; training/serving left to user |
| Tensor Representation | torch.Tensor subclassing (preserves autograd, distributed semantics) | Module hooks, FakeQuantize ops |
| Sparsity Patterns | 2:4/block/Marlin layouts integrated in kernels | No built-in support |
| Prototyped Advances | AutoRound, AWQ, ParQ, SmoothQuant, SpinQuant | Not present |
TorchAO’s close integration with the PyTorch and HuggingFace ecosystems, broad data type support—including FP8 and custom formats—hardware-targeted quantization, sparsity, and extensibility in both research and production distinguish it from Brevitas. A plausible implication is that TorchAO serves as a preferred optimization substrate for LLM training and deployment where full-lifecycle workflow and ecosystem integration are operationally critical (Or et al., 21 Jul 2025).