Update AWQ implementation to not use extra wrapper tensor subclass #2753
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2753
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 50af000 with merge base 6a2d975.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
assert weight_tensor.qdata.is_contiguous(), "Expected qdata to be contiguous"
assert weight_tensor.scale.is_contiguous(), "Expected scale to be contiguous"
assert weight_tensor.zero_point.is_contiguous(), (
    "Expected zero_point to be contiguous"
)

if weight_tensor.activation_scale is not None:
    input_tensor = input_tensor / weight_tensor.activation_scale
I think multiplying would be faster, can we change the scale definition to do that?
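Purely as an illustration of that suggestion, here is a minimal sketch of what storing the scale as a precomputed reciprocal could look like (the class and attribute names are hypothetical, not the PR's actual code):

```python
import torch
from typing import Optional

class ActScaleReciprocalSketch:
    """Hypothetical holder that stores the reciprocal of the activation scale."""

    def __init__(self, activation_scale: Optional[torch.Tensor]):
        # Precompute the reciprocal once at construction time so the
        # per-forward operation is a multiply instead of a divide.
        self.act_scale_reciprocal = (
            torch.reciprocal(activation_scale) if activation_scale is not None else None
        )

    def pre_scale(self, input_tensor: torch.Tensor) -> torch.Tensor:
        if self.act_scale_reciprocal is not None:
            return input_tensor * self.act_scale_reciprocal
        return input_tensor

# usage: same result as dividing by the scale, but with a multiply in the hot path
w = ActScaleReciprocalSketch(torch.rand(64) + 0.5)
x_scaled = w.pre_scale(torch.randn(4, 64))
```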
also, how about input_scale, to standardize with https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.linear.html
actually, maybe act_scale or act_prescale, to standardize with Float8Tensor.act_quant_kwargs
if the activation is quantized, do we need to combine the scales to just do the multiplication once? this may also affect the design
sure, I can use act_scale

> if the activation is quantized, do we need to combine the scales to just do the multiplication once? this may also affect the design

it doesn't work if the activation is already quantized, we expect this to be applied before static or dynamic quantization
tensor_attribute_names = ["block_size", "shape"]

def __new__(cls, qdata, scale, zero_point, block_size, shape):
def __new__(cls, qdata, scale, zero_point, activation_scale, block_size, shape):
I feel like an optional argument like this should be last, after non-optional ones like block_size and shape
makes sense although currently this order is mandated by the order of tensor and non-tensor attribute right now:
Lines 727 to 731 in a1a9632
`tensor_data_names` (List[str]): list of names of all requires tensor_data, order should match
    the `__init__` list of tensor subclass
`optional_tensor_data_names` (List[str]): it's optional to define this field to have the additional boilerplate functions been implemented for you, but this will be need if there are some optional Tensor attributes, when defined, this will be a list of names of Tensors that can be optional
`tensor_attribute_names` (List[str]): list of names of non-Tensor attributes,
    order should match the `__init__` list of tensor subclass, following all the `tensor_data_names` arguments and `optional_tensor_data_names`
currently all tensor attributes come before non-tensor attributes, is it better to have:
required_tensor_attrs, required_attrs, optional_tensor_attrs, optional_attrs for every tensor?
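To make the ordering constraint concrete, here is a small sketch of the convention being referenced (the class name and attribute lists are illustrative, not the exact Int4Tensor code):

```python
import torch
from typing import List, Optional

class Int4TensorSketch:
    # TorchAOBaseTensor-style boilerplate expects constructor arguments in the
    # order: tensor_data_names, then optional_tensor_data_names, then
    # tensor_attribute_names. That is why the optional activation scale ends up
    # before the non-tensor attributes block_size and shape.
    tensor_data_names = ["qdata", "scale", "zero_point"]
    optional_tensor_data_names = ["act_scale"]
    tensor_attribute_names = ["block_size", "shape"]

    def __init__(
        self,
        qdata: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
        act_scale: Optional[torch.Tensor],  # optional tensor data, still before non-tensor attrs
        block_size: List[int],
        shape: torch.Size,
    ):
        self.qdata, self.scale, self.zero_point = qdata, scale, zero_point
        self.act_scale = act_scale
        self.block_size, self.shape = block_size, shape
```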
got it
is this kind of change BC-breaking? if yes, is there an opportunity to make all of these keyword-only arguments to remove them from the BC surface?
if we have a stable release of these, then yes, although right now we don't have these in any official stable release yet
some of the data args are required for the construction to make sense I think, do you mean making all attributes keyword-only args?
@@ -136,13 +130,7 @@ def test_awq_loading(self):
        calibration_data = dataset[:n_calibration_examples]

        # calibrate
        base_config = FbgemmConfig(
should we delete this public config before people start using it? (in a separate PR)
yeah, will do
assert hasattr(self, "activation_scale"), (
    f"tensor {type(self)} does not have attribute `activation_scale`"
)
self.activation_scale = scale
should we also add this as an instance variable, like in L15
activation_scale: Optional[torch.Tensor]
we want to use TorchAOBaseTensor so we have to define this in optional_tensor_data_names
@@ -105,7 +105,10 @@ def _awq_transform(
    dummy_mod = DummyModule(observed_linear.weight * equalization_scale)
    quant_mod = base_config_handler(dummy_mod, config.base_config)
    qw = quant_mod.weight
    qw = to_weight_tensor_with_linear_activation_scale_metadata(qw, equalization_scale)
seems like we should deprecate WeightTensorWithLinearActivationScaleMetadata in favor of the mixin? Can we file an issue for that?
yes, we'll do this along with the migration or when a new tensor is added I think, no need to migrate the old implementations now
Looks like we are switching to more of a protocol style, personally I think that

```python
from typing import Protocol, runtime_checkable

@runtime_checkable
class SupportsActivationScaling(Protocol):
    """A protocol for tensor-like objects that support activation scaling."""

    activation_scale: torch.Tensor | None

    def set_activation_scale(self, scale: torch.Tensor) -> None:
        ...

# int4
class Int4Tensor(TorchAOBaseTensor):
    # ... existing implementation ...

    # Just add the required attribute and method
    activation_scale: torch.Tensor | None = None

    def set_activation_scale(self, scale: torch.Tensor) -> None:
        self.activation_scale = scale

...

def _awq_transform(...):
    # ...
    assert isinstance(qw, SupportsActivationScaling), (
        "weight must support activation scaling by conforming to the SupportsActivationScaling protocol"
    )
```

is nicer since we don't have to do multiple inheritance, I just really don't like multiple inheritance
OK, will try, thanks
318ae78 to 5015a81
@@ -36,30 +36,54 @@ class Int4Tensor(TorchAOBaseTensor):
        dtype is the same as the original Tensor dtype
    zero_point: (K/group_size, N) for 2D Tensor, (B, K/group_size, N) for 3D Tensor, where B is batch size,
        dtype is the same as the original Tensor dtype
    act_scale (Optional[Tensor]): Optional per row scale for activation Tensor, if present,
nit: remove "per row", it should work with any scale where the shapes are broadcastable to the original tensor
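For example (illustrative shapes only), per-row, per-tensor, and per-column scales all broadcast against the activation:

```python
import torch

x = torch.randn(8, 64)              # activation: (M, K)
per_row = torch.rand(8, 1) + 0.5    # per-row scale, broadcasts over K
per_tensor = torch.tensor(0.5)      # scalar scale, broadcasts over everything
per_column = torch.rand(64) + 0.5   # per-column scale, broadcasts over M

for s in (per_row, per_tensor, per_column):
    assert (x / s).shape == x.shape  # every broadcastable scale keeps the activation shape
```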
I think the variable name should make it clear that this scale is not related to the scale which happens inside of activation quantization.
OK sg
@@ -36,30 +36,54 @@ class Int4Tensor(TorchAOBaseTensor):
        dtype is the same as the original Tensor dtype
    zero_point: (K/group_size, N) for 2D Tensor, (B, K/group_size, N) for 3D Tensor, where B is batch size,
        dtype is the same as the original Tensor dtype
    act_scale (Optional[Tensor]): Optional per row scale for activation Tensor, if present,
        we'll multiply activation Tensor with act_scale before applying dynamic
how does this compose with dynamic quant? Is there a faster path where we only scale once, instead of scaling by act_scale, then finding max, then scaling again by the inner scale?
we will first apply act_scale, and then do the dynamic quantization. we don't have the faster path right now; not sure if this is needed or whether this optimization would actually yield a speedup yet
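A rough sketch of that two-step order (the function and variable names here are placeholders, not the actual torchao APIs):

```python
import torch

def dynamic_quant_per_row_sketch(x: torch.Tensor):
    # Placeholder for symmetric int8 per-row dynamic quantization.
    inner_scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 127.0
    q = torch.clamp(torch.round(x / inner_scale), -128, 127).to(torch.int8)
    return q, inner_scale

def forward_activation_path_sketch(x: torch.Tensor, act_scale: torch.Tensor):
    # Current path: apply act_scale first, then dynamically quantize the
    # rescaled activation.
    x = x / act_scale
    return dynamic_quant_per_row_sketch(x)

# A hypothetical fused path would fold act_scale into inner_scale so the
# activation is only rescaled once; as noted above, that optimization is not
# implemented and its benefit has not been measured.
q, s = forward_activation_path_sketch(torch.randn(4, 64), torch.rand(64) + 0.5)
```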
Summary:
We want to remove the extra wrapper tensor subclass `to_weight_tensor_with_linear_activation_scale_metadata`. Although it is composable with different tensor subclasses, it complicates the flow and reduces locality of code (compute logic is spread between the LinearActivationScale wrapper and the real tensor subclasses).

Instead we want to implement activation scaling in the tensor itself. We added an `ActivationScalingMixin` that can be inherited by the tensor subclass, which needs to have the activation scale attribute defined:
* has an `act_scale: Optional[Tensor]` argument defined
* in linear we can get the act_scale and do `input_tensor = input_tensor / self.act_scale` before calling gemm kernels

Note for BC: we'll add that later, after this PR; right now we're just focusing on functionality, and no official checkpoint is released yet for Int4Tensor or Int4Tensor with act_scale.

Test Plan:
python test/prototype/test_awq.py

Reviewers:

Subscribers:

Tasks:

Tags:
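For readers following along, a minimal runnable sketch of the forward path the summary describes (the dataclass and function names are stand-ins; the real path dispatches to an int4 gemm kernel on qdata/scale/zero_point rather than a dense matmul):

```python
import torch
from dataclasses import dataclass
from typing import Optional

@dataclass
class Int4WeightSketch:
    # Stand-in for the real Int4Tensor: dense_weight represents what
    # qdata/scale/zero_point encode, act_scale is the optional AWQ scale.
    dense_weight: torch.Tensor
    act_scale: Optional[torch.Tensor] = None

def linear_forward_sketch(x: torch.Tensor, w: Int4WeightSketch, bias=None) -> torch.Tensor:
    # Per the summary: rescale the activation by act_scale (if present)
    # before calling the gemm.
    if w.act_scale is not None:
        x = x / w.act_scale
    return torch.nn.functional.linear(x, w.dense_weight, bias)

# usage
w = Int4WeightSketch(dense_weight=torch.randn(32, 64), act_scale=torch.rand(64) + 0.5)
y = linear_forward_sketch(torch.randn(4, 64), w)
assert y.shape == (4, 32)
```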