PyTorch native INT8 quantization API for TorchAO
Abstract
This report presents the design and implementation of Int8Tensor, a quantized tensor subclass for PyTorch's TorchAO library that enables INT8 inference for neural networks. By representing weights and activations in 8-bit integers instead of 16/32-bit floats, INT8 quantization reduces memory footprint by up to 4× and can significantly accelerate inference on hardware with native INT8 support. The implementation supports both dynamic activation quantization (INT8×INT8 matrix multiplication) and weight-only quantization (FP16/BF16×INT8), providing flexibility for different deployment scenarios. Through five major design iterations, we evolved from direct block-size specification to high-level Granularity abstractions, optimized kernel implementations, and self-contained utility functions. This report documents the technical decisions, lessons learned, and future research directions for low-latency inference optimization.
1. Introduction
Int8Tensor is a quantized tensor subclass in TorchAO that enables INT8 inference for neural networks. By representing weights and activations in 8-bit integers instead of 16/32-bit floats, INT8 quantization reduces memory footprint and can accelerate inference on hardware with INT8 support. See Section 3.1 for how asymmetric integer quantization works mathematically.
The implementation has been contributed to PyTorch's TorchAO library through PR #3391, bringing efficient INT8 quantization capabilities to the broader PyTorch ecosystem.
1.1 Quantization Modes
Int8Tensor supports two complementary quantization strategies:
- Dynamic activation quantization (INT8×INT8): Both weights and activations are quantized to INT8. Activations are quantized at runtime during inference. This mode maximizes memory savings and computational efficiency on hardware with native INT8 matrix multiplication support.
- Weight-only quantization (FP16/BF16×INT8): Only weights are pre-quantized to INT8, while activations remain in floating point. This mode trades some memory savings for simpler deployment and broader hardware compatibility.
1.2 Design Objectives
The implementation was designed with three core requirements:
- Seamless PyTorch integration: Via the __torch_dispatch__ override, allowing F.linear to work transparently with quantized weights without modifying user code (see the usage sketch after this list).
- Flexible quantization granularity: Support for per-tensor and per-row quantization to balance accuracy and efficiency across different use cases.
- Compatibility with tensor operations: Enable standard operations like slice and select for model parallelism and dynamic batching scenarios.
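A minimal usage sketch of the intended integration, assuming TorchAO's quantize_ entry point and an Int8WeightOnlyConfig-style configuration; whether this Int8Tensor is reached through exactly these config names is an assumption, not something the PR guarantees:

```python
import torch
import torch.nn as nn
from torchao.quantization import quantize_, Int8WeightOnlyConfig  # assumed entry points

model = nn.Sequential(nn.Linear(4096, 4096)).to(torch.bfloat16).eval()

# Replace each Linear weight with a quantized tensor subclass in place.
quantize_(model, Int8WeightOnlyConfig())

# F.linear now routes through the subclass's __torch_dispatch__, so inference
# code does not change.
x = torch.randn(1, 4096, dtype=torch.bfloat16)
with torch.no_grad():
    y = model(x)
```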
2. Design Evolution
This section chronicles the major architectural changes across five design iterations, explaining the problems encountered and the reasoning behind each evolution.
2.1 Granularity API Evolution
Quantization granularity determines how scale factors are shared across tensor elements. Per-tensor quantization uses a single scale for the entire tensor, while per-row quantization uses one scale per output channel, typically providing better accuracy at the cost of slightly increased memory for scale storage.
Initial Approach (V1-V3): Direct Block Size Specification
The early API took a block_size argument directly, which required users to understand the relationship between block_size and quantization granularity. For per-row quantization of a [N, K] weight matrix, the block size is [1, K], meaning each row is quantized independently.
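An illustrative sketch of the burden this placed on callers; the scale computation is symmetric and simplified here, purely to show what a [1, K] block implies:

```python
import torch

N, K = 256, 512
w = torch.randn(N, K, dtype=torch.bfloat16)

# V1-V3 burden: callers derive block_size from the granularity they want.
per_tensor_block = [N, K]   # one scale shared by the whole tensor
per_row_block = [1, K]      # one scale per output row

# For per-row quantization, scales are computed over each [1, K] block
# (symmetric scales here, purely for illustration):
scales = w.abs().amax(dim=1, keepdim=True).float() / 127.0   # shape [N, 1]
qw = torch.clamp(torch.round(w.float() / scales), -128, 127).to(torch.int8)
```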
Final Approach (V4-V5): Abstract Granularity Objects
Users now specify an abstract Granularity object (per-tensor or per-row) instead of a raw block_size; the block_size is computed internally via get_block_size(tensor.shape, granularity), hiding implementation details from users. The abstraction reduces user error and aligns with TorchAO's API conventions. Storing both the granularity and the computed block_size avoids repeated computation during slice operations while maintaining API clarity.
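A self-contained sketch of the granularity-to-block_size mapping, using stand-in classes rather than the real torchao.quantization.granularity objects:

```python
from dataclasses import dataclass

# Stand-ins for torchao.quantization.granularity classes, to keep the sketch self-contained.
@dataclass
class PerTensor:
    pass

@dataclass
class PerRow:
    pass

def get_block_size_sketch(shape, granularity):
    """Illustrative reimplementation of the granularity -> block_size mapping."""
    if isinstance(granularity, PerTensor):
        return list(shape)                              # one scale for the whole tensor
    if isinstance(granularity, PerRow):
        return [1] * (len(shape) - 1) + [shape[-1]]     # one scale per row
    raise NotImplementedError(f"unsupported granularity: {granularity}")

assert get_block_size_sketch((256, 512), PerTensor()) == [256, 512]
assert get_block_size_sketch((256, 512), PerRow()) == [1, 512]
```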
2.2 Dequantization Implementation
Dequantization converts INT8 data back to floating point by multiplying with scale factors. The scale tensor shape depends on granularity: scalar for per-tensor, [N] or [N, 1] for per-row.
Initial Approach (V1-V3): Manual Scale Broadcasting
This approach manually handled scale tensor broadcasting, introducing edge cases across different granularities and tensor dimensions.
Final Approach (V4-V5): TorchAO Primitive Reuse
dequantize_affine is a validated TorchAO primitive that correctly handles all granularity cases. Reusing it reduces maintenance burden and ensures consistency with other quantization methods in the library.
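For per-row scales, the semantics reduce to a broadcasted multiply; the following is a rough reference sketch, not dequantize_affine's actual signature:

```python
import torch

# Rough semantics of per-row dequantization (not dequantize_affine's actual signature).
def dequant_per_row_sketch(qdata: torch.Tensor, scale: torch.Tensor,
                           out_dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    # qdata: [N, K] int8; scale: [N] or [N, 1] float
    scale = scale.reshape(-1, 1)                                  # normalize to [N, 1]
    return (qdata.to(torch.float32) * scale.to(torch.float32)).to(out_dtype)

qdata = torch.randint(-128, 128, (4, 8), dtype=torch.int8)
scale = torch.rand(4) * 0.01
w_hat = dequant_per_row_sketch(qdata, scale)
```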
2.3 Linear Operation Optimization
The F.linear operation computes Y = X @ W.T + bias. For quantized inference, we need to handle integer matrix multiplication and scale application efficiently.
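Both paths rest on the same factorization: with X ≈ s_x · Qx and W ≈ s_w · Qw (zero points omitted here, assuming symmetric quantization for brevity), the output factors as

Y ≈ (Qx @ Qw.T) · (s_x · s_w) + bias

so the integer matmul can run first and the combined scale can be applied to its accumulator afterwards.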
INT8×INT8 Path Evolution
| Version | Implementation | Issue |
|---|---|---|
| V1 | Float conversion before matmul | Loses benefit of integer arithmetic, increases memory |
| V2 | int64 matmul attempt | Works on CPU, fails on CUDA ("addmm_cuda" not implemented for 'Int') |
| V3-V5 | Optimized int_scaled_matmul kernel | ✓ Works on both CPU and CUDA |
int_scaled_matmul is a TorchAO kernel that handles INT8 matmul with fused scale application and works on both CPU and CUDA, avoiding the overhead of manual dtype conversions.
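A reference sketch of the path's semantics (not the kernel itself; the matmul runs in int64 on CPU purely so the example executes anywhere):

```python
import torch

# Reference semantics of the dynamic INT8xINT8 path (a sketch, not int_scaled_matmul).
def int8_linear_reference(x_int8, x_scale, w_int8, w_scale, bias=None):
    # x_int8: [M, K] int8, x_scale: [M, 1]; w_int8: [N, K] int8, w_scale: [N] or [N, 1]
    acc = torch.matmul(x_int8.to(torch.int64), w_int8.to(torch.int64).t())  # integer accumulate
    y = acc.to(torch.float32) * x_scale.to(torch.float32) * w_scale.reshape(1, -1).to(torch.float32)
    return y if bias is None else y + bias.to(torch.float32)

x_int8 = torch.randint(-128, 128, (2, 16), dtype=torch.int8)
w_int8 = torch.randint(-128, 128, (8, 16), dtype=torch.int8)
y = int8_linear_reference(x_int8, torch.rand(2, 1) * 0.01, w_int8, torch.rand(8) * 0.01)
```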
Weight-Only Path Optimization
By deferring scale multiplication until after the matmul, we avoid allocating a full floating-point weight tensor. The INT8 weight is cast to the activation dtype element-wise during the matmul, reducing peak memory usage.
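A reference sketch of the deferred-scale math (the production path fuses the dtype cast into the matmul, which this simplified version does not attempt):

```python
import torch
import torch.nn.functional as F

# Weight-only reference math: matmul on INT8 values cast to the activation dtype,
# with the per-row weight scale applied only afterwards.
def int8_weight_only_linear_sketch(x, w_int8, w_scale, bias=None):
    # x: [M, K] fp16/bf16; w_int8: [N, K] int8; w_scale: [N] or [N, 1]
    y = F.linear(x, w_int8.to(x.dtype))            # matmul on casted INT8 values, no dequantize
    y = y * w_scale.reshape(1, -1).to(x.dtype)     # deferred per-output-channel scaling
    return y if bias is None else y + bias

x = torch.randn(2, 16, dtype=torch.bfloat16)
w_int8 = torch.randint(-128, 128, (8, 16), dtype=torch.int8)
y = int8_weight_only_linear_sketch(x, w_int8, torch.rand(8, dtype=torch.bfloat16) * 0.01)
```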
2.4 Scale Slicing Implementation
When slicing an Int8Tensor (e.g., for tensor parallelism), the associated scale tensor must be sliced consistently with the underlying quantization granularity.
Initial Approach (V4): External Dependency
Reusing _slice_scale_for_dimension from torchao.float8.inference created an unexpected dependency on the Float8 module.
Final Approach (V5): Self-Contained Implementation
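A minimal sketch of what a self-contained helper needs to do, assuming per-tensor or per-row scales; the PR's actual _slice_scale may handle more cases:

```python
import torch

# Minimal sketch of a self-contained scale-slicing helper (illustrative only).
# Per-tensor scales are shared, so any slice keeps the same scale; per-row scales
# must be sliced only when the sliced dimension is the row dimension.
def _slice_scale_sketch(scale: torch.Tensor, dim: int, start: int, end: int) -> torch.Tensor:
    if scale.numel() == 1:
        return scale                     # per-tensor: one scale covers every slice
    if dim == 0:
        return scale[start:end]          # per-row: scales follow the sliced rows
    return scale                         # slicing along K leaves per-row scales unchanged

scale = torch.rand(256)                  # per-row scales for a [256, 512] weight
assert _slice_scale_sketch(scale, dim=0, start=0, end=128).shape == (128,)
```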
3. Technical Foundations
3.1 Asymmetric Integer Quantization
Asymmetric integer quantization maps high-precision floating-point numbers to discrete integer levels. The quantization process is formulated as:

Qx = clamp(⌈X/s⌋ + z, qmin, qmax)

where:
- X is the floating-point tensor
- Qx is its n-bit quantized counterpart
- ⌈·⌋ denotes rounding to the nearest integer
- s is the scaling factor: s = (Xmax - Xmin) / (qmax - qmin)
- z is the zero point: z = ⌈qmin - Xmin/s⌋
- [qmin, qmax] is the integer range of the n-bit type (e.g., [-128, 127] for INT8)

The dequantized tensor is recovered as:

X̂ = s · (Qx - z)
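A small worked example of the scheme above for INT8 (qmin = -128, qmax = 127):

```python
import torch

# Asymmetric INT8 quantization round-trip for a toy tensor.
x = torch.tensor([-1.5, -0.2, 0.0, 0.7, 2.3])
qmin, qmax = -128, 127

s = (x.max() - x.min()) / (qmax - qmin)          # scaling factor
z = torch.round(qmin - x.min() / s)              # zero point
q = torch.clamp(torch.round(x / s) + z, qmin, qmax).to(torch.int8)

x_hat = s * (q.to(torch.float32) - z)            # dequantized reconstruction
print(torch.max((x - x_hat).abs()))              # error bounded by roughly s/2
```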
3.2 Supported Tensor Operations
PyTorch tensor subclasses can override standard tensor operations via __torch_dispatch__. When a PyTorch operation is called on an Int8Tensor, PyTorch dispatches to our custom implementation instead of the default.
| Operation | Description | Limitations |
|---|---|---|
| aten.linear.default | INT8×INT8 or FP×INT8 matrix multiplication | None |
| aten.slice.Tensor | Tensor slicing with scale adjustment | dim ∈ [0,1,2], step=1 only |
| aten.select.int | Single index selection | dim=0 only |
| aten.index.Tensor | Advanced indexing | None |
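To make the table concrete, the following stripped-down subclass shows the interception mechanism itself; it is illustrative only: it logs every aten op and falls back to the default kernels, whereas Int8Tensor routes the four operations above to custom implementations.

```python
import torch
from torch.utils._pytree import tree_map

# Stripped-down sketch of __torch_dispatch__ interception (not the actual Int8Tensor code).
class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # An Int8Tensor-style subclass would look `func` up in its table of supported ops
        # (aten.linear.default, aten.slice.Tensor, ...) and call a custom kernel here.
        print(f"dispatching {func}")

        # Unwrap back to plain tensors and fall through to the default implementation.
        def unwrap(a):
            return a.as_subclass(torch.Tensor) if isinstance(a, LoggingTensor) else a

        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))

t = torch.randn(2, 3).as_subclass(LoggingTensor)
out = t + 1   # prints the dispatched aten op (aten.add.Tensor)
```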
4. Key Lessons Learned
4.1 Prefer High-Level APIs
Initial versions exposed block_size directly, requiring users to understand that per-row quantization of a [N, K] matrix needs block_size=[1, K]. The resolution was to capture quantization intent in explicit Granularity objects that express what the user wants rather than how it is implemented.
4.2 Reuse Validated Primitives
Manual implementations of dequantization involved scale tensor broadcasting logic that varied by granularity and tensor dimension. Edge cases were easy to miss. TorchAO's dequantize_affine primitive already handles these cases correctly.
4.3 Consider Hardware Constraints Early
The INT8 matmul implementation went through three iterations before settling on int_scaled_matmul. The V2 approach assumed PyTorch would support integer matmul on GPU, which it does not.
4.4 Minimize External Dependencies
Using _slice_scale_for_dimension from torchao.float8.inference created coupling between unrelated modules. The self-contained _slice_scale implementation in V5 is slightly more code but eliminates this dependency.
5. Future Work and Benchmarking
Custom Kernel Development
We plan to develop custom CUDA and Triton kernels optimized for low-latency inference scenarios, specifically targeting:
- Batch size: 1 (single-user inference)
- Sequence length: Under 256 tokens
- Target model: AWQ 4-bit quantized Llama 3.1 8B
- Hardware: Consumer-grade GPUs (RTX 4090, RTX 3090)
5.1 Performance Metrics
Our optimization efforts will focus on two critical inference metrics:
- TTFT (Time To First Token): The latency from request submission to first token generation, critical for perceived responsiveness
- TPOT (Time Per Output Token): The average time to generate subsequent tokens, determining throughput
5.2 Planned Benchmarking Methodology
Through systematic profiling with PyTorch Profiler and NVIDIA Nsight Compute, we will:
- Identify critical bottlenecks in the standard Hugging Face inference pipeline
- Analyze optimization strategies employed by vLLM, particularly:
  - Tiling configurations for matrix operations
  - Shared memory utilization patterns
  - Kernel fusion opportunities
- Benchmark against both vanilla Hugging Face and vLLM baselines
- Validate model quality through perplexity measurements to ensure no degradation
6. Version History
| Version | Key Changes |
|---|---|
| V1 | Initial implementation with block_size, float matmul |
| V2 | Added _shape tracking, int64 matmul attempt |
| V3 | int_scaled_matmul kernel, lazy dequantize for weight-only |
| V4 | Granularity abstraction, dequantize_affine primitive |
| V5 | Constructor fix, self-contained _slice_scale |
7. References
- TorchAO Contributors, "TorchAO: PyTorch-Native Training-to-Serving Model Optimization," GitHub, 2024. Available: https://github.com/pytorch/ao
- J. Lin, J. Tang, H. Tang, S. Yang, W.-M. Chen, W.-C. Wang, G. Xiao, X. Dang, C. Gan, and S. Han, "AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration," in Proc. MLSys, 2024.
- Y. Lin, H. Tang, S. Yang, Z. Zhang, G. Xiao, C. Gan, and S. Han, "QServe: W4A8KV4 Quantization and System Co-design for Efficient LLM Serving," arXiv preprint arXiv:2405.04532, 2024.