PyTorch native INT8 quantization API for TorchAO

Tags: Quantization, INT8, PyTorch, TorchAO, Inference Optimization
Author: Namgyu Youn
Status: Active Development
Type: v0 Batch Research Project
Date: Dec 2025

Abstract

This report presents the design and implementation of Int8Tensor, a quantized tensor subclass for PyTorch's TorchAO library that enables INT8 inference for neural networks. By representing weights and activations in 8-bit integers instead of 16/32-bit floats, INT8 quantization reduces memory footprint by up to 4× and can significantly accelerate inference on hardware with native INT8 support. The implementation supports both dynamic activation quantization (INT8×INT8 matrix multiplication) and weight-only quantization (FP16/BF16×INT8), providing flexibility for different deployment scenarios. Through five major design iterations, we evolved from direct block-size specification to high-level Granularity abstractions, optimized kernel implementations, and self-contained utility functions. This report documents the technical decisions, lessons learned, and future research directions for low-latency inference optimization.

1. Introduction

Int8Tensor is a quantized tensor subclass in TorchAO that enables INT8 inference for neural networks. By representing weights and activations in 8-bit integers instead of 16/32-bit floats, INT8 quantization reduces memory footprint and can accelerate inference on hardware with INT8 support. See Section 3.1 for how asymmetric integer quantization works mathematically.

The implementation has been contributed to PyTorch's TorchAO library through PR #3391, bringing efficient INT8 quantization capabilities to the broader PyTorch ecosystem.

1.1 Quantization Modes

Int8Tensor supports two complementary quantization strategies (a minimal numerical sketch of both modes follows this list):

  • Dynamic activation quantization (INT8×INT8): Both weights and activations are quantized to INT8. Activations are quantized at runtime during inference. This mode maximizes memory savings and computational efficiency on hardware with native INT8 matrix multiplication support.
  • Weight-only quantization (FP16/BF16×INT8): Only weights are pre-quantized to INT8, while activations remain in floating point. This mode trades some memory savings for simpler deployment and broader hardware compatibility.
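
To make the two modes concrete, the sketch below uses plain PyTorch ops and a simplified symmetric per-row scheme. It is illustrative only and is not the Int8Tensor implementation itself, which routes through TorchAO kernels (see Section 2.3).

import torch

x = torch.randn(2, 8)   # activations, kept in floating point initially
w = torch.randn(4, 8)   # weights, quantized ahead of time

# Simplified symmetric per-row weight quantization (one scale per output row)
w_scale = w.abs().amax(dim=1, keepdim=True) / 127
w_q = torch.clamp(torch.round(w / w_scale), -128, 127).to(torch.int8)

# Dynamic activation quantization (INT8×INT8): quantize x at runtime,
# accumulate in integer arithmetic, then apply both scales to the small output.
x_scale = x.abs().amax(dim=1, keepdim=True) / 127
x_q = torch.clamp(torch.round(x / x_scale), -128, 127).to(torch.int8)
y_dyn = (x_q.to(torch.int32) @ w_q.to(torch.int32).T).float() * x_scale * w_scale.T

# Weight-only quantization (FP×INT8): keep x in floating point, cast the
# INT8 weight to the activation dtype, and apply the weight scale afterwards.
y_wo = (x @ w_q.to(x.dtype).T) * w_scale.T

# Both y_dyn and y_wo approximate the full-precision reference x @ w.T.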

1.2 Design Objectives

The implementation was designed with three core requirements:

  1. Seamless PyTorch integration: via the __torch_dispatch__ override, F.linear works transparently with quantized weights without modifying user code.
  2. Flexible quantization granularity: Support for per-tensor and per-row quantization to balance accuracy and efficiency across different use cases.
  3. Compatibility with tensor operations: Enable standard operations like slice and select for model parallelism and dynamic batching scenarios.

2. Design Evolution

This section chronicles the major architectural changes across five design iterations, explaining the problems encountered and the reasoning behind each evolution.

2.1 Granularity API Evolution

Quantization granularity determines how scale factors are shared across tensor elements. Per-tensor quantization uses a single scale for the entire tensor, while per-row quantization uses one scale per output channel, typically providing better accuracy at the cost of slightly increased memory for scale storage.
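
For example, under a simplified symmetric scheme the two granularities differ only in how the scale is reduced over the weight (illustrative sketch, not TorchAO's internal scale computation):

import torch

w = torch.randn(4, 8)                                      # [N, K] weight

scale_per_tensor = w.abs().max() / 127                     # scalar: shared by every element
scale_per_row = w.abs().amax(dim=1, keepdim=True) / 127    # [N, 1]: one scale per output row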

Initial Approach (V1-V3): Direct Block Size Specification

# User must calculate block_size manually
Int8Tensor.from_hp(weight, block_size=[1, weight.shape[1]])                 # per-row
Int8Tensor.from_hp(weight, block_size=[weight.shape[0], weight.shape[1]])   # per-tensor

This required users to understand the relationship between block_size and quantization granularity. For per-row quantization of a [N, K] weight matrix, the block size is [1, K], meaning each row is quantized independently.

Problem: This low-level API exposed implementation details and led to user errors, especially for those unfamiliar with quantization internals.

Final Approach (V4-V5): Abstract Granularity Objects

from torchao.quantization.granularity import PerRow, PerTensor

Int8Tensor.from_hp(weight, granularity=PerRow())
Int8Tensor.from_hp(weight, granularity=PerTensor())

The block_size is computed internally via get_block_size(tensor.shape, granularity), hiding implementation details from users. The abstraction reduces user error and aligns with TorchAO's API conventions.

Design Decision: Storing both granularity and computed block_size avoids repeated computation during slice operations while maintaining API clarity.
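
A simplified sketch of the mapping performed internally is shown below. The helper block_size_for is hypothetical and shown only for illustration; TorchAO's get_block_size also covers other granularity types.

from torchao.quantization.granularity import PerRow, PerTensor

def block_size_for(shape, granularity):
    # Illustrative helper, not TorchAO's get_block_size.
    if isinstance(granularity, PerTensor):
        # One block covering the whole tensor -> a single scale.
        return list(shape)
    if isinstance(granularity, PerRow):
        # Each row along the last dimension is its own block -> one scale per row.
        return [1] * (len(shape) - 1) + [shape[-1]]
    raise NotImplementedError(f"unsupported granularity: {granularity}")

# block_size_for((N, K), PerRow())    -> [1, K]
# block_size_for((N, K), PerTensor()) -> [N, K]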

2.2 Dequantization Implementation

Dequantization converts INT8 data back to floating point by multiplying with scale factors. The scale tensor shape depends on granularity: scalar for per-tensor, [N] or [N, 1] for per-row.

Initial Approach (V1-V3): Manual Scale Broadcasting

def dequantize(self, output_dtype=None):
    qdata_fp = self.qdata.to(output_dtype)
    scale_expanded = _maybe_expand_scale_to_tensor_shape(self.scale, self.qdata.shape)
    return qdata_fp * scale_expanded

This approach manually handled scale tensor broadcasting, introducing edge cases for different granularity and tensor dimensions.

Final Approach (V4-V5): TorchAO Primitive Reuse

def dequantize(self, output_dtype=None):
    block_size = get_block_size(self.qdata.shape, self.granularity)
    return dequantize_affine(
        input=self.qdata,
        block_size=block_size,
        scale=self.scale,
        output_dtype=output_dtype,
    )

Rationale: dequantize_affine is a validated TorchAO primitive that correctly handles all granularity cases. Reusing it reduces maintenance burden and ensures consistency with other quantization methods in the library.

2.3 Linear Operation Optimization

The F.linear operation computes Y = X @ W.T + bias. For quantized inference, we need to handle integer matrix multiplication and scale application efficiently.
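
Writing the quantized operands as X ≈ sx · Qx and W ≈ sw · Qw (symmetric case, with sx and sw broadcast per row for per-row granularity), the scales factor out of the matmul:

Y = X @ W.T ≈ (sx · Qx) @ (sw · Qw).T = (Qx @ Qw.T) · sx · sw

so the accumulation can run entirely in integer arithmetic, and the floating-point scales are applied once to the much smaller [M, N] output. The iterations below differ only in how this integer matmul and rescaling are executed.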

INT8×INT8 Path Evolution

Version | Implementation | Issue
V1 | Float conversion before matmul | Loses benefit of integer arithmetic, increases memory
V2 | int64 matmul attempt | Works on CPU, fails on CUDA ("addmm_cuda" not implemented for 'Int')
V3-V5 | Optimized int_scaled_matmul kernel | ✓ Works on both CPU and CUDA

from torchao.kernel import int_scaled_matmul

y_dot_scaled = int_scaled_matmul(tmp, w_vals_t, x_scales.reshape(-1, 1))
result = y_dot_scaled * w_scales

Key Insight: int_scaled_matmul is a TorchAO kernel that performs the INT8 matmul with fused scale application and works on both CPU and CUDA, avoiding the overhead of manual dtype conversions.

Weight-Only Path Optimization

# V3-V5: Lazy dequantization - only dtype conversion, scale applied after matmul
w_vals_t = weight.qdata.t().to(activation.dtype)
m = torch.mm(activation.reshape(-1, activation.shape[-1]), w_vals_t)
result = m * weight.scale

By deferring scale multiplication until after the matmul, the weight is never fully dequantized: only a dtype cast of the INT8 data is needed before the matmul, and the scale is applied once to the much smaller output, reducing peak memory usage.

2.4 Scale Slicing Implementation

When slicing an Int8Tensor (e.g., for tensor parallelism), the associated scale tensor must be sliced consistently with the underlying quantization granularity.

Initial Approach (V4): External Dependency

from torchao.float8.inference import _slice_scale_for_dimension

sliced_scale = _slice_scale_for_dimension(
    self.scale, self.qdata.shape, dim, start, end, step
)

This created an unexpected dependency on the Float8 module.

Final Approach (V5): Self-Contained Implementation

def _slice_scale(scale, data_shape, dim, start, end, step):
    # Case 1: Per-tensor (scalar scale)
    if scale.numel() <= 1:
        return scale

    # Case 2: Per-row (1D scale)
    if scale.ndim == 1:
        return aten.slice.Tensor(scale, 0, start, end, step) if dim == 0 else scale

    # Case 3: Per-block (2D scale)
    block_size_for_dim = data_shape[dim] // scale.shape[dim]
    scale_start = start // block_size_for_dim
    scale_end = (end + block_size_for_dim - 1) // block_size_for_dim
    return aten.slice.Tensor(scale, dim, scale_start, scale_end, 1)

Rationale: A self-contained implementation removes the Float8 dependency and makes the Int8Tensor module easier to understand in isolation. Explicit case handling also improves code maintainability.
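
A quick usage check of the helper above, assuming the conventional module-level alias aten = torch.ops.aten: for a per-row scale attached to an [8, 16] weight, slicing along dim 0 slices the scale, while slicing along dim 1 leaves it unchanged.

import torch

aten = torch.ops.aten   # assumed alias used by _slice_scale above

scale = torch.rand(8)   # per-row scale for an [8, 16] weight

rows = _slice_scale(scale, data_shape=[8, 16], dim=0, start=0, end=4, step=1)  # shape [4]
cols = _slice_scale(scale, data_shape=[8, 16], dim=1, start=0, end=8, step=1)  # shape [8], unchanged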

3. Technical Foundations

3.1 Asymmetric Integer Quantization

Asymmetric integer quantization maps high-precision floating-point numbers to discrete integer levels. The quantization process is formulated as:

Qx = ⌈X/s + z⌋

where:

  • X is the floating-point tensor
  • Qx is its n-bit quantized counterpart
  • s is the scaling factor: s = (Xmax - Xmin) / (qmax - qmin)
  • z is the zero point: z = ⌈qmin - Xmin/s⌋

The dequantized tensor is recovered as:

X̂ = (Qx - z) · s
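
Worked example (illustrative numbers): suppose X spans [-1.5, 2.5] and is mapped to the signed INT8 range [qmin, qmax] = [-128, 127]. Then

s = (2.5 - (-1.5)) / (127 - (-128)) = 4 / 255 ≈ 0.0157
z = ⌈-128 - (-1.5)/0.0157⌋ = ⌈-128 + 95.6⌋ = -32

so X = 2.5 quantizes to Qx = ⌈2.5/0.0157 + (-32)⌋ = 127 and dequantizes to X̂ = (127 - (-32)) · s ≈ 2.49, the small gap from 2.5 being the quantization error.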

3.2 Supported Tensor Operations

PyTorch tensor subclasses can override standard tensor operations via __torch_dispatch__. When a PyTorch operation is called on an Int8Tensor, PyTorch dispatches to our custom implementation instead of the default.

Operation | Description | Limitations
aten.linear.default | INT8×INT8 or FP×INT8 matrix multiplication | None
aten.slice.Tensor | Tensor slicing with scale adjustment | dim ∈ {0, 1, 2}, step=1 only
aten.select.int | Single index selection | dim=0 only
aten.index.Tensor | Advanced indexing | None
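
To make the dispatch mechanism concrete, here is a minimal, self-contained sketch (not TorchAO code) of a tensor subclass that intercepts aten operations via __torch_dispatch__; Int8Tensor follows the same pattern but returns quantized results instead of logging:

import torch
from torch.utils._pytree import tree_map

class LoggingTensor(torch.Tensor):
    """Minimal subclass that logs every aten op dispatched to it."""

    def __new__(cls, data):
        # Wrap a plain tensor; shape/dtype metadata is taken from `data`.
        return torch.Tensor._make_subclass(cls, data)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Unwrap subclass instances so the underlying aten op runs on plain
        # tensors (calling `func` on subclasses would re-enter this dispatcher).
        unwrap = lambda t: t.as_subclass(torch.Tensor) if isinstance(t, LoggingTensor) else t
        print(f"intercepted: {func}")
        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))

x = torch.randn(2, 3)
w = LoggingTensor(torch.randn(4, 3))
y = torch.nn.functional.linear(x, w)   # prints the aten ops hit by F.linear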

4. Key Lessons Learned

4.1 Prefer High-Level APIs

Initial versions exposed block_size directly, requiring users to understand that per-row quantization of an [N, K] matrix needs block_size=[1, K]. The resolution was to introduce explicit Granularity objects that express quantization intent rather than implementation details.

Lesson: When designing tensor APIs, prefer semantic abstractions over low-level parameters.

4.2 Reuse Validated Primitives

Manual implementations of dequantization involved scale tensor broadcasting logic that varied by granularity and tensor dimension. Edge cases were easy to miss. TorchAO's dequantize_affine primitive already handles these cases correctly.

Lesson: Before implementing custom logic, check if existing primitives can be reused. This reduces bugs and maintenance burden.

4.3 Consider Hardware Constraints Early

The INT8 matmul implementation went through three iterations before settling on int_scaled_matmul. The V2 approach assumed PyTorch would support integer matmul on GPU, which it does not.

Lesson: For quantization work, verify kernel availability on target hardware before committing to an approach.

4.4 Minimize External Dependencies

Using _slice_scale_for_dimension from torchao.float8.inference created coupling between unrelated modules. The self-contained _slice_scale implementation in V5 is slightly more code but eliminates this dependency.

Lesson: For utility functions, consider whether the dependency is worth the coupling cost.

5. Future Work and Benchmarking

Custom Kernel Development

We plan to develop custom CUDA and Triton kernels optimized for low-latency inference scenarios, specifically targeting:

  • Batch size: 1 (single-user inference)
  • Sequence length: Under 256 tokens
  • Target model: AWQ 4-bit quantized Llama 3.1 8B
  • Hardware: Consumer-grade GPUs (RTX 4090, RTX 3090)

5.1 Performance Metrics

Our optimization efforts will focus on two critical inference metrics (a minimal measurement sketch follows this list):

  • TTFT (Time To First Token): The latency from request submission to first token generation, critical for perceived responsiveness
  • TPOT (Time Per Output Token): The average time to generate subsequent tokens, determining throughput
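
A hedged sketch of how these two metrics might be measured; generate_one_token is a hypothetical callable standing in for one decode step of whatever model and runtime is under test, and the sync calls assume a CUDA device:

import time
import torch

def measure_latency(generate_one_token, prompt, num_new_tokens=128):
    # Prefill plus first token -> TTFT
    torch.cuda.synchronize()
    start = time.perf_counter()
    token = generate_one_token(prompt)
    torch.cuda.synchronize()
    ttft = time.perf_counter() - start

    # Remaining decode steps -> average TPOT
    decode_start = time.perf_counter()
    for _ in range(num_new_tokens - 1):
        token = generate_one_token(token)
    torch.cuda.synchronize()
    tpot = (time.perf_counter() - decode_start) / (num_new_tokens - 1)
    return ttft, tpot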

5.2 Planned Benchmarking Methodology

Through systematic profiling with PyTorch Profiler and NVIDIA Nsight Compute, we will:

  1. Identify critical bottlenecks in the standard Hugging Face inference pipeline
  2. Analyze optimization strategies employed by vLLM, particularly:
    • Tiling configurations for matrix operations
    • Shared memory utilization patterns
    • Kernel fusion opportunities
  3. Benchmark against both vanilla Hugging Face and vLLM baselines
  4. Validate model quality through perplexity measurements to ensure no degradation
Status: Comprehensive performance comparisons and benchmark results will be published upon completion of the optimization work. Initial profiling is currently underway.

6. Version History

Version | Key Changes
V1 | Initial implementation with block_size, float matmul
V2 | Added _shape tracking, int64 matmul attempt
V3 | int_scaled_matmul kernel, lazy dequantize for weight-only
V4 | Granularity abstraction, dequantize_affine primitive
V5 | Constructor fix, self-contained _slice_scale

7. References

  1. TorchAO Contributors, "TorchAO: PyTorch-Native Training-to-Serving Model Optimization," GitHub, 2024. Available: https://github.com/pytorch/ao
  2. J. Lin, J. Tang, H. Tang, S. Yang, W.-M. Chen, W.-C. Wang, G. Xiao, X. Dang, C. Gan, and S. Han, "AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration," in Proc. MLSys, 2024.
  3. Y. Lin, H. Tang, S. Yang, Z. Zhang, G. Xiao, C. Gan, and S. Han, "QServe: W4A8KV4 Quantization and System Co-design for Efficient LLM Serving," arXiv preprint arXiv:2405.04532, 2024.
