
[CPU] Introduce Int4TinyGemmCpuTensor to replace Int4CPULayout in AQT #2798

Open · wants to merge 5 commits into main
@@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest

import torch
from torch.testing._internal.common_utils import (
TestCase,
instantiate_parametrized_tests,
parametrize,
run_tests,
)

from torchao.quantization import (
Int4WeightOnlyConfig,
quantize_,
)
from torchao.quantization.utils import compute_error
from torchao.utils import (
torch_version_at_least,
)


def get_config(group_size):
return Int4WeightOnlyConfig(
group_size=group_size,
packing_format="int4_tinygemm_cpu",
version=2,
)


@unittest.skipIf(not torch_version_at_least("2.6.0"), "Need pytorch 2.6+")
class TestInt4TinyGemmCpuTensor(TestCase):
@parametrize(
"sizes",
[
((128,), 256, 128),
((32, 128), 512, 128),
((2, 32, 128), 256, 128),
],
)
@parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
@parametrize("group_size", [32, 64, 128])
def test_linear(self, sizes, dtype, group_size):
device = "cpu"
M, N, K = sizes
input = torch.randn(*M, K, dtype=dtype, device=device)
linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
original = linear(input)
quantize_(linear, get_config(group_size))
quantized = linear(input)
self.assertTrue(compute_error(original, quantized) > 20)

compiled_linear = torch.compile(linear)
quantized_and_compiled = compiled_linear(input)
self.assertTrue(compute_error(original, quantized_and_compiled) > 20)

@parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
def test_module_path(self, dtype):
linear = torch.nn.Linear(128, 256, dtype=dtype)
quantize_(linear, get_config(group_size=128))
self.assertEqual(
str(type(linear.weight)),
"<class 'torchao.quantization.Int4TinyGemmCpuTensor'>",
)

with tempfile.NamedTemporaryFile() as f:
torch.save(linear.state_dict(), f)
f.seek(0)
state_dict = torch.load(f)
self.assertEqual(
str(type(state_dict["weight"])),
"<class 'torchao.quantization.Int4TinyGemmCpuTensor'>",
)


instantiate_parametrized_tests(TestInt4TinyGemmCpuTensor)


if __name__ == "__main__":
run_tests()
2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
@@ -93,6 +93,7 @@
Int4MarlinSparseTensor,
Int4PreshuffledTensor,
Int4Tensor,
Int4TinyGemmCpuTensor,
IntxUnpackedTensor,
)
from .smoothquant import (
@@ -164,6 +165,7 @@
"Int4MarlinSparseTensor",
"IntxUnpackedTensor",
"Float8Tensor",
"Int4TinyGemmCpuTensor",
# smooth quant - subject to change
"get_scale",
"SmoothFakeDynQuantMixin",
7 changes: 7 additions & 0 deletions torchao/quantization/quant_api.py
@@ -75,6 +75,7 @@
Int4MarlinSparseTensor,
Int4PreshuffledTensor,
Int4Tensor,
Int4TinyGemmCpuTensor,
IntxUnpackedTensor,
QuantizeTensorToFloat8Kwargs,
)
@@ -1080,6 +1081,12 @@ def _int4_weight_only_quantize_tensor(weight, config):
block_size,
)
return new_weight
elif packing_format == PackingFormat.INT4_TINYGEMM_CPU:
new_weight = Int4TinyGemmCpuTensor.from_hp(
weight,
block_size,
)
return new_weight
else:
raise ValueError(f"Unsupported packing format: {packing_format}")

5 changes: 5 additions & 0 deletions torchao/quantization/quantize_/common/packing_format.py
@@ -40,3 +40,8 @@ class PackingFormat(str, Enum):
Unpacked means the subbyte quantized data is stored as int8
"""
UNPACKED_TO_INT8 = "unpacked_to_int8"

"""
int4_tinygemm_cpu refers to the groupwise format used by int4 weight-only quantization with the tinygemm kernels on CPU.
"""
INT4_TINYGEMM_CPU = "int4_tinygemm_cpu"
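For context, the new packing format is selected through Int4WeightOnlyConfig, as exercised in the test above. A minimal usage sketch (assumes this PR is applied; the layer shape and dtype here are illustrative):

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# bfloat16 linear layer on CPU; in_features should be divisible by group_size
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cpu")
config = Int4WeightOnlyConfig(
    group_size=128,
    packing_format="int4_tinygemm_cpu",  # maps to PackingFormat.INT4_TINYGEMM_CPU
    version=2,
)
quantize_(linear, config)  # linear.weight becomes an Int4TinyGemmCpuTensor
out = linear(torch.randn(4, 128, dtype=torch.bfloat16))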
4 changes: 4 additions & 0 deletions torchao/quantization/quantize_/workflows/__init__.py
@@ -11,6 +11,9 @@
from .int4.int4_tensor import (
Int4Tensor,
)
from .int4.int4_tinygemm_cpu_tensor import (
Int4TinyGemmCpuTensor,
)
from .intx.intx_unpacked_tensor import (
IntxUnpackedTensor,
)
@@ -21,5 +24,6 @@
"Int4MarlinSparseTensor",
"Float8Tensor",
"QuantizeTensorToFloat8Kwargs",
"Int4TinyGemmCpuTensor",
"IntxUnpackedTensor",
]
194 changes: 194 additions & 0 deletions torchao/quantization/quantize_/workflows/int4/int4_tinygemm_cpu_tensor.py
@@ -0,0 +1,194 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.


from typing import List

import torch

from torchao.quantization.quant_primitives import (
MappingType,
_choose_qparams_affine_tinygemm,
_quantize_affine_tinygemm,
)
from torchao.utils import (
TorchAOBaseTensor,
)

__all__ = [
"Int4TinyGemmCpuTensor",
]

aten = torch.ops.aten


class Int4TinyGemmCpuTensor(TorchAOBaseTensor):
"""
int4 weight-only quantization on CPU with tinygemm (groupwise quantization only)
Tensor Attributes:
qdata: preshuffled and packed int4 weight for tinygemm, always viewed as a 2D (N, K/2) tensor; the last dimension is packed with two int4 values per byte.
preshuffling is specific to CPU kernels, see Note below.
scale_and_zero: (K/group_size, N, 2), dtype is the same as the original Tensor dtype
Non-Tensor Attributes:
block_size: the block size for quantization, representing the granularity; for groupwise quantization it is (1, group_size).
Only group_size = 32/64/128 is supported.
shape: shape of the original Tensor
Note on the data layout for the CPU tinygemm kernel:
We use AVX512 to compute tinygemm on CPU, and can additionally leverage AVX512_VNNI and AMX instructions with torch.compile and max-autotune.
For data locality, we preshuffle the data from the plain layout (N, K/2) to (N/block_n, K, block_n/2), where block_n = 64/32/16.
See https://github.com/pytorch/pytorch/blob/32eee8ed225d9f10fbbcb38c24b8b44c24c0c97c/aten/src/ATen/native/cpu/int4mm_kernel.cpp#L583 for more details.
"""

tensor_data_names = ["qdata", "scale_and_zero"]
tensor_attribute_names = ["block_size", "shape"]

def __new__(
cls,
qdata,
scale_and_zero,
block_size,
shape,
):
kwargs = {}
kwargs["device"] = qdata.device
kwargs["dtype"] = scale_and_zero.dtype
kwargs["requires_grad"] = False
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(
self,
qdata: torch.Tensor,
scale_and_zero: torch.Tensor,
block_size: List[int],
shape: torch.Size,
):
self.qdata = qdata
self.scale_and_zero = scale_and_zero
self.block_size = block_size

def _quantization_type(self):
return f"shape={self.shape}, block_size={self.block_size}, device={self.device}"

@classmethod
def from_hp(
cls,
w: torch.Tensor,
block_size: List[int],
):
assert w.ndim == 2 and w.device.type == "cpu", (
f"Expecting 2D tensor on CPU, but got: {w.shape} on {w.device.type}"
)
assert len(block_size) == w.ndim
assert block_size[0] == 1 and block_size[1] in (32, 64, 128), (
f"Expecting groupwise quantization with group size = 32/64/128, but got block_size: {block_size}"
)
original_shape = w.shape
mapping_type = MappingType.ASYMMETRIC
target_dtype = torch.int32
quant_min = 0
quant_max = 15
eps = 1e-6
scale_dtype = None
zero_point_dtype = w.dtype
scale, zero_point = _choose_qparams_affine_tinygemm(
w,
mapping_type,
block_size,
target_dtype,
quant_min,
quant_max,
eps,
scale_dtype,
zero_point_dtype,
)
int_data = _quantize_affine_tinygemm(
w,
block_size,
scale,
zero_point,
target_dtype,
quant_min,
quant_max,
)
assert int_data.dtype == torch.int32, (
"torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype"
)
packed_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
int_data,
1, # innerKTiles is not needed for CPU
)

scale = scale.reshape(int_data.shape[0], -1)
zero_point = zero_point.reshape(int_data.shape[0], -1)
from torchao.quantization.utils import pack_tinygemm_scales_and_zeros

scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point, scale.dtype)
return Int4TinyGemmCpuTensor(
qdata=packed_weight,
scale_and_zero=scale_and_zero,
block_size=block_size,
shape=original_shape,
)


implements = Int4TinyGemmCpuTensor.implements


@implements([torch.nn.functional.linear, aten.linear.default])
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],
args[1],
args[2] if len(args) > 2 else None,
)
assert input_tensor.device.type == "cpu", (
f"For CPU device only but got: {input_tensor.device}"
)
assert isinstance(weight_tensor, Int4TinyGemmCpuTensor), (
f"Expected weight_tensor to be Int4TinyGemmCpuTensor, got: {type(weight_tensor)}"
)
assert weight_tensor.block_size[0] == 1, (
f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}"
)
assert input_tensor.shape[-1] == weight_tensor.shape[1], (
f"Shapes of input and weight do not match, input:{input_tensor.shape}, weight: {weight_tensor.shape}"
)

act_mat = input_tensor
packed_weight = weight_tensor.qdata
scale_and_zero = weight_tensor.scale_and_zero

orig_act_size = act_mat.size()
orig_dtype = act_mat.dtype

# reshape to 2D
act_mat = act_mat.reshape(-1, act_mat.shape[-1])

# groupwise int4 quantization
groupsize = weight_tensor.block_size[1]
y = torch.ops.aten._weight_int4pack_mm_for_cpu(
act_mat.contiguous(), packed_weight, groupsize, scale_and_zero
)

# remove out_feature padding
assert weight_tensor.ndim == 2
orig_out_features = weight_tensor.shape[-2]
y = y[:, :orig_out_features]
y = y.reshape(*orig_act_size[:-1], orig_out_features)

if bias is not None:
y += bias
return y.to(orig_dtype)


Int4TinyGemmCpuTensor.__module__ = "torchao.quantization"

# Allow a model with Int4TinyGemmCpuTensor weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([Int4TinyGemmCpuTensor])
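
A small sketch that checks the attributes described in the Int4TinyGemmCpuTensor docstring above (assumes this PR is applied; the exact packed layout is a kernel detail, so only the documented shapes are shown):

import torch
from torchao.quantization import Int4TinyGemmCpuTensor

N, K, group_size = 256, 128, 32
w = torch.randn(N, K, dtype=torch.bfloat16)  # high-precision weight on CPU
t = Int4TinyGemmCpuTensor.from_hp(w, block_size=[1, group_size])

print(t.shape)                 # torch.Size([256, 128]): the original logical shape is preserved
print(t.block_size)            # [1, 32]: groupwise quantization along K
print(t.qdata.shape)           # packed int4 weight, viewed as 2D (N, K/2) per the docstring
print(t.scale_and_zero.shape)  # (K/group_size, N, 2) == (4, 256, 2), same dtype as w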