Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 4 additions & 22 deletions test/prototype/test_awq.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
)

from torchao.prototype.awq import AWQConfig, AWQStep
from torchao.quantization import FbgemmConfig, Int4WeightOnlyConfig, quantize_
from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.utils import _is_fbgemm_genai_gpu_available


Expand Down Expand Up @@ -73,13 +73,7 @@ def test_awq_functionality(self):
m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)

# baseline quantization
base_config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, group_size],
preshuffle=False,
)
base_config = Int4WeightOnlyConfig(group_size=group_size, version=2)
m_baseline = copy.deepcopy(m)
quantize_(m_baseline, base_config)

Expand Down Expand Up @@ -129,13 +123,7 @@ def test_awq_loading(self):
calibration_data = dataset[:n_calibration_examples]

# calibrate
base_config = FbgemmConfig(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we delete this public config before people start using it? (in a separate PR)

Copy link
Contributor Author

@jerryzh168 jerryzh168 Aug 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, will do

input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, group_size],
preshuffle=False,
)
base_config = Int4WeightOnlyConfig(group_size=group_size, version=2)
quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
quantize_(m, quant_config)

Expand Down Expand Up @@ -189,13 +177,7 @@ def test_awq_loading_vllm(self):
calibration_data = dataset[:n_calibration_examples]

# calibrate
base_config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, group_size],
preshuffle=False,
)
base_config = Int4WeightOnlyConfig(group_size=group_size, version=2)
quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
quantize_(m, quant_config)

Expand Down
20 changes: 20 additions & 0 deletions test/quantization/quantize_/workflows/int4/test_int4_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)

from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.quantization.quantize_.common import SupportsActivationPreScaling
from torchao.quantization.utils import compute_error
from torchao.testing.utils import TorchAOIntegrationTestCase
from torchao.utils import is_sm_at_least_90, torch_version_at_least
Expand Down Expand Up @@ -213,6 +214,25 @@ def test_cat(self, sizes):
def test_moe_weight_reshape_ops(self):
self._test_moe_weight_reshape_ops(self.config)

def test_activation_prescaling(self):
    """Check that `act_pre_scale` set on a quantized weight is applied to the activation.

    Quantizes a bias-free linear layer with `self.config`, verifies that the
    resulting weight implements `SupportsActivationPreScaling` with a default
    `act_pre_scale` of None, then sets a pre-scale and compares the quantized
    output against the original output multiplied by the same factor.
    """
    dtype = torch.bfloat16
    device = "cuda"
    # named `x` rather than `input` so the builtin is not shadowed
    x = torch.randn(1, 128, dtype=dtype, device=device)
    linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
    # reference output captured before the weight is swapped for a quantized one
    original = linear(x)
    quantize_(linear, self.config)
    qw = linear.weight
    # unittest assertions instead of bare `assert`: they are not stripped
    # under `python -O` and report the offending value on failure
    self.assertIsInstance(
        qw,
        SupportsActivationPreScaling,
        "Expected int4 tensor supports activation prescaling",
    )
    self.assertIsNone(qw.act_pre_scale, "Default `act_pre_scale` is None")
    act_pre_scale = 2
    qw.act_pre_scale = act_pre_scale
    quantized = linear(x)

    # making sure activation pre scaling is successfully applied to the activation
    # NOTE(review): presumably `compute_error` is an SQNR-style score where a
    # larger value means the two tensors are closer -- confirm against its definition
    self.assertGreater(compute_error(original * act_pre_scale, quantized), 20)


instantiate_parametrized_tests(TestInt4Tensor)

Expand Down
9 changes: 7 additions & 2 deletions torchao/prototype/awq/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
import torch

from torchao.core.config import AOBaseConfig
from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata
from torchao.quantization.quant_api import (
_linear_extra_repr,
)
from torchao.quantization.quantize_.common import SupportsActivationPreScaling
from torchao.quantization.transform_module import (
_QUANTIZE_CONFIG_HANDLER,
register_quantize_module_handler,
Expand Down Expand Up @@ -105,7 +105,12 @@ def _awq_transform(
dummy_mod = DummyModule(observed_linear.weight * equalization_scale)
quant_mod = base_config_handler(dummy_mod, config.base_config)
qw = quant_mod.weight
qw = to_weight_tensor_with_linear_activation_scale_metadata(qw, equalization_scale)
assert isinstance(qw, SupportsActivationPreScaling), (
"weight must support activation scaling through implementing `SupportsActivationPreScaling`"
)
# since we want to do `act` * `act_pre_scale` during runtime for speed, we'll save the
# reciprocal of the `equalization_scale`
qw.act_pre_scale = 1.0 / equalization_scale

linear = torch.nn.Linear(
observed_linear.in_features,
Expand Down
27 changes: 6 additions & 21 deletions torchao/prototype/awq/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,17 +224,9 @@ def quantize_and_eval(
group_size = int(quant.split("-")[2])
print(f"running {quant} quantization with group size {group_size}")
# use Int4WeightOnlyConfig (version=2) as the base config for AWQ
from torchao.quantization import FbgemmConfig

# use_hqq = True
# base_config = Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq)
base_config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, group_size],
preshuffle=False,
)
from torchao.quantization import Int4WeightOnlyConfig

base_config = Int4WeightOnlyConfig(group_size=group_size, version=2)
print(f"running {quant} prepare and calibrate")
t0 = time.time()
quant_config = AWQConfig(base_config, step="prepare")
Expand Down Expand Up @@ -267,17 +259,10 @@ def quantize_and_eval(
elif quant.startswith("int4wo"):
group_size = int(quant.split("-")[1])
print(f"running {quant} quantization with group size {group_size}")
# TODO: enable after refactor: https://github.com/pytorch/ao/pull/2474
# TODO: enable after migration: https://github.com/pytorch/ao/issues/2752
# use_hqq = "hqq" in quant
# base_config = Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq)
int4_weight_only_config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, group_size],
preshuffle=False,
)
quantize_(model, int4_weight_only_config)
base_config = Int4WeightOnlyConfig(group_size=group_size, version=2)
quantize_(model, base_config)

if model_save_path is not None:
print(f"Saving model to {model_save_path}")
Expand Down
2 changes: 2 additions & 0 deletions torchao/quantization/quantize_/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .kernel_preference import KernelPreference
from .packing_format import PackingFormat
from .protocol import SupportsActivationPreScaling
from .quantize_tensor_kwargs import (
QuantizeTensorKwargs,
_choose_quant_func_and_quantize_tensor,
Expand All @@ -9,5 +10,6 @@
"QuantizeTensorKwargs",
"KernelPreference",
"PackingFormat",
"SupportsActivationPreScaling",
"_choose_quant_func_and_quantize_tensor",
]
22 changes: 22 additions & 0 deletions torchao/quantization/quantize_/common/protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""Protocols for some functionalities in tensor subclasses"""

from typing import Optional, Protocol, runtime_checkable

import torch


@runtime_checkable
class SupportsActivationPreScaling(Protocol):
    """Protocol for tensor subclasses that carry an activation pre-scale.

    The scale should be multiplied with the activation before the activation
    is quantized, or before the activation is used in a matrix
    multiplication. It is used by algorithms like AWQ.

    Any class that has an `act_pre_scale: Optional[torch.Tensor]` attribute
    implements this protocol. Because the protocol is `runtime_checkable`,
    `isinstance(obj, SupportsActivationPreScaling)` only verifies that the
    attribute is present; it does not validate the attribute's type.
    """

    # None means no pre-scaling is applied
    act_pre_scale: Optional[torch.Tensor]
60 changes: 51 additions & 9 deletions torchao/quantization/quantize_/workflows/int4/int4_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.


from typing import List
from typing import List, Optional

import torch
from torch.utils._python_dispatch import return_and_correct_aliasing
Expand All @@ -30,36 +30,62 @@ class Int4Tensor(TorchAOBaseTensor):
"""
int4 quantization with plain (default) packing format (for all granularities)

Tensor Attributes:
Tensor Data Attributes:
qdata: packed int4 weight, either 2D (N, K/2) or 3D (B, N, K/2), last dimension is packed
scale: (K/group_size, N) for 2D Tensor, (B, K/group_size, N) for 3D Tensor, where B is batch size,
dtype is the same as the original Tensor dtype
zero_point: (K/group_size, N) for 2D Tensor, (B, K/group_size, N) for 3D Tensor, where B is batch size,
dtype is the same as the original Tensor dtype

Non-Tensor Attributes:
Non-Tensor Data Attributes:
block_size: the block size for quantization, representing the granularity, for example groupwise quantization will have block_size (1, group_size)
shape: the shape of the original Tensor

Optional Tensor Data Attributes:
act_pre_scale (Optional[Tensor]): Optional scale for activation Tensor, if present,
we'll multiply activation Tensor with act_pre_scale before applying dynamic
quantization to activation or running quantized mm op
"""

tensor_data_names = ["qdata", "scale", "zero_point"]
tensor_attribute_names = ["block_size", "shape"]
optional_tensor_data_names = ["act_pre_scale"]

def __new__(cls, qdata, scale, zero_point, block_size, shape):
def __new__(
cls,
qdata: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
block_size: List[int],
shape: torch.Size,
act_pre_scale: Optional[torch.Tensor] = None,
):
kwargs = {}
kwargs["device"] = qdata.device
kwargs["dtype"] = scale.dtype
kwargs["requires_grad"] = False
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(self, qdata, scale, zero_point, block_size, shape):
def __init__(
self,
qdata: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
block_size: List[int],
shape: torch.Size,
act_pre_scale: Optional[torch.Tensor] = None,
):
self.qdata = qdata
self.scale = scale
self.zero_point = zero_point
self.block_size = block_size
self.act_pre_scale = act_pre_scale

def _quantization_type(self):
return f"shape={self.shape}, block_size={self.block_size}, device={self.device}"
s = f"shape={self.shape}, block_size={self.block_size}, device={self.device}"
if self.act_pre_scale is not None:
s += f", act_pre_scale.shape={self.act_pre_scale.shape}"
return s

@classmethod
def from_hp(
Expand Down Expand Up @@ -100,6 +126,7 @@ def from_hp(
zero_point=zero_point,
block_size=block_size,
shape=original_shape,
act_pre_scale=None,
)


Expand All @@ -113,12 +140,17 @@ def _(func, types, args, kwargs):
args[1],
args[2] if len(args) > 2 else None,
)
assert isinstance(weight_tensor, Int4Tensor)

assert weight_tensor.qdata.is_contiguous(), "Expected qdata to be contiguous"
assert weight_tensor.scale.is_contiguous(), "Expected scale to be contiguous"
assert weight_tensor.zero_point.is_contiguous(), (
"Expected zero_point to be contiguous"
)

if weight_tensor.act_pre_scale is not None:
input_tensor = input_tensor * weight_tensor.act_pre_scale

orig_act_size = input_tensor.size()
orig_out_features = weight_tensor.shape[-2]

Expand Down Expand Up @@ -207,12 +239,13 @@ def _(func, types, args, kwargs):
func,
args,
kwargs,
self.__class__(
Int4Tensor(
self.qdata,
self.scale,
self.zero_point,
block_size=self.block_size,
shape=self.shape,
act_pre_scale=self.act_pre_scale,
),
)

Expand All @@ -229,8 +262,13 @@ def _(func, types, args, kwargs):
zero_point = aten.slice.Tensor(self.zero_point, sz_dim, start_sz, end_sz, step)
packed_shape0, packed_shape1 = qdata.shape
new_shape = (packed_shape0, packed_shape1 * 2)
new = self.__class__(
qdata, scale, zero_point, block_size=self.block_size, shape=new_shape
new = Int4Tensor(
qdata,
scale,
zero_point,
self.block_size,
new_shape,
act_pre_scale=self.act_pre_scale,
)
return return_and_correct_aliasing(func, args, kwargs, new)

Expand Down Expand Up @@ -307,6 +345,7 @@ def _(func, types, args, kwargs):
cat_zero_point,
tensor_0.block_size,
new_shape,
act_pre_scale=tensor_0.act_pre_scale,
)
return return_and_correct_aliasing(func, args, kwargs, new)

Expand Down Expand Up @@ -351,6 +390,7 @@ def _(func, types, args, kwargs):
zero_point,
block_size,
new_shape,
act_pre_scale=self.act_pre_scale,
)
return return_and_correct_aliasing(func, args, kwargs, new)

Expand Down Expand Up @@ -439,6 +479,7 @@ def _(func, types, args, kwargs):
zero_point,
block_size,
shape,
act_pre_scale=self.act_pre_scale,
)
return return_and_correct_aliasing(func, args, kwargs, new)

Expand Down Expand Up @@ -480,6 +521,7 @@ def _(func, types, args, kwargs):
zero_point,
new_block_size,
new_shape,
act_pre_scale=self.act_pre_scale,
)
return return_and_correct_aliasing(func, args, kwargs, new)

Expand Down
Loading