Conversation

jerryzh168 (Contributor)

Summary:
We want to remove the extra wrapper tensor subclass to_weight_tensor_with_linear_activation_scale_metadata. Although it is composable with different tensor subclasses, it complicates the flow and reduces the locality of code (compute logic is spread between the LinearActivationScale wrapper and the real tensor subclasses).

Instead, we implement activation scaling in the tensor itself: we added an ActivationScalingMixin that can be inherited by a tensor subclass that defines an activation_scale attribute. The mixin

  • has a set_activation_scale method defined, which we can call to set the activation scale
  • lets linear get the activation_scale and do input_tensor = input_tensor / activation_scale before calling the gemm kernels (see the sketch after this list)
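
A minimal, self-contained sketch of the approach described above (the exact implementation in this PR may differ; the linear dispatch shown in the trailing comments is only illustrative):

import torch

class ActivationScalingMixin:
    # the inheriting tensor subclass is expected to define an `activation_scale` attribute
    def set_activation_scale(self, scale: torch.Tensor) -> None:
        assert hasattr(self, "activation_scale"), (
            f"tensor {type(self)} does not have attribute `activation_scale`"
        )
        self.activation_scale = scale

# in the linear implementation (illustrative):
#   if weight_tensor.activation_scale is not None:
#       input_tensor = input_tensor / weight_tensor.activation_scale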

Test Plan:
python test/prototype/test_awq.py

Reviewers:

Subscribers:

Tasks:

Tags:


pytorch-bot bot commented Aug 12, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2753

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 50af000 with merge base 6a2d975:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@jerryzh168 jerryzh168 requested a review from andrewor14 August 12, 2025 23:29
@meta-cla meta-cla bot added the CLA Signed label Aug 12, 2025
@jerryzh168 jerryzh168 requested review from vkuzo and drisspg August 12, 2025 23:29
@jerryzh168 jerryzh168 added the topic: not user facing label Aug 12, 2025
assert weight_tensor.qdata.is_contiguous(), "Expected qdata to be contiguous"
assert weight_tensor.scale.is_contiguous(), "Expected scale to be contiguous"
assert weight_tensor.zero_point.is_contiguous(), (
    "Expected zero_point to be contiguous"
)

if weight_tensor.activation_scale is not None:
    input_tensor = input_tensor / weight_tensor.activation_scale
Contributor

I think multiplying would be faster, can we change the scale definition to do that?
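
(A rough sketch of this suggestion, assuming a hypothetical act_pre_scale attribute that stores the reciprocal once at set time so the hot path multiplies instead of divides:)

import torch

class ActivationScalingMixin:
    # hypothetical variant: store the reciprocal ("pre-scale") when the scale is set
    def set_activation_scale(self, scale: torch.Tensor) -> None:
        self.act_pre_scale = torch.reciprocal(scale)

# the linear dispatch would then multiply on the hot path:
#   input_tensor = input_tensor * weight_tensor.act_pre_scale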

Contributor

actually, maybe act_scale or act_prescale to standardize with Float8Tensor.act_quant_kwargs

if the activation is quantized, do we need to combine the scales to just do the multiplication once? this may also affect the design

jerryzh168 (Author)

sure, I can use act_scale

if the activation is quantized, do we need to combine the scales to just do the multiplication once? this may also affect the design

it doesn't work if the activation is already quantized; we expect this scaling to be applied before static or dynamic quantization

tensor_attribute_names = ["block_size", "shape"]

def __new__(cls, qdata, scale, zero_point, block_size, shape):
def __new__(cls, qdata, scale, zero_point, activation_scale, block_size, shape):
Contributor

I feel like an optional argument like this should be last, after non-optional ones like block_size and shape

jerryzh168 (Author), Aug 13, 2025

makes sense, although currently this order is mandated by the ordering of tensor and non-tensor attributes:

ao/torchao/utils.py

Lines 727 to 731 in a1a9632

`tensor_data_names` (List[str]): list of names of all requires tensor_data, order should match
the `__init__` list of tensor subclass
`optional_tensor_data_names` (List[str]): it's optional to define this field to have the additional boilerplate functions been implemented for you, but this will be need if there are some optional Tensor attributes, when defined, this will be a list of names of Tensors that can be optional
`tensor_attribute_names` (List[str]): list of names of non-Tensor attributes,
order should match the `__init__` list of tensor subclass, following all the `tensor_data_names` arguments and `optional_tensor_data_names`

currently all tensor attributes come before non-tensor attributes; is it better to have
required_tensor_attrs, required_attrs, optional_tensor_attrs, optional_attrs for every tensor?
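
(For context, a sketch of the declaration convention under discussion, using attribute names that appear in this PR's diff; the real class body is elided:)

from torchao.utils import TorchAOBaseTensor

class Int4TensorSketch(TorchAOBaseTensor):
    # required tensor data comes first in the constructor signature ...
    tensor_data_names = ["qdata", "scale", "zero_point"]
    # ... then optional tensor data ...
    optional_tensor_data_names = ["act_scale"]
    # ... and non-tensor attributes come last, which is why act_scale currently sits
    # before block_size and shape in __new__/__init__
    tensor_attribute_names = ["block_size", "shape"]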

Contributor

got it

is this kind of change BC-breaking? if yes, is there an opportunity to make all of these keyword-only arguments, to take them out of the BC surface?

jerryzh168 (Author)

if we have a stable release of these, then yes, although right now these are not in any official stable release yet

some of the data args are required for the construction to make sense, I think; do you mean making all attributes keyword-only args?
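
(A sketch of what keyword-only construction could look like; ExampleTensor and its exact signature are hypothetical, not the PR's actual code:)

import torch
from typing import Optional

class ExampleTensor(torch.Tensor):
    # hypothetical: everything after the required data args is keyword-only, so
    # reordering or adding optional attributes later is not BC-breaking for callers
    @staticmethod
    def __new__(cls, qdata, scale, zero_point, *, block_size=None, act_scale: Optional[torch.Tensor] = None):
        return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, dtype=scale.dtype)

    def __init__(self, qdata, scale, zero_point, *, block_size=None, act_scale: Optional[torch.Tensor] = None):
        self.qdata = qdata
        self.scale = scale
        self.zero_point = zero_point
        self.block_size = block_size
        self.act_scale = act_scale

# callers must name the optional attributes:
#   t = ExampleTensor(qdata, scale, zero_point, block_size=128, act_scale=None)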

@@ -136,13 +130,7 @@ def test_awq_loading(self):
calibration_data = dataset[:n_calibration_examples]

# calibrate
base_config = FbgemmConfig(
Contributor

should we delete this public config before people start using it? (in a separate PR)

jerryzh168 (Author), Aug 13, 2025

yeah, will do

assert hasattr(self, "activation_scale"), (
    f"tensor {type(self)} does not have attribute `activation_scale`"
)
self.activation_scale = scale
Contributor

should we also add this as an instance variable, like in L15

activation_scale: Optional[torch.Tensor]

jerryzh168 (Author)

we want to use TorchAOBaseTensor, so we have to define this in optional_tensor_data_names

@@ -105,7 +105,10 @@ def _awq_transform(
dummy_mod = DummyModule(observed_linear.weight * equalization_scale)
quant_mod = base_config_handler(dummy_mod, config.base_config)
qw = quant_mod.weight
qw = to_weight_tensor_with_linear_activation_scale_metadata(qw, equalization_scale)
Contributor

seems like we should deprecate WeightTensorWithLinearActivationScaleMetadata in favor of the mixin? Can we file an issue for that?

jerryzh168 (Author), Aug 14, 2025

yes, we'll do this along with the migration or when a new tensor is added, I think; no need to migrate the old implementations now

drisspg (Contributor) commented Aug 14, 2025

Looks like we are switching to more of a protocol style, personally I think that

import torch
from typing import Protocol, runtime_checkable

@runtime_checkable
class SupportsActivationScaling(Protocol):
    """A protocol for tensor-like objects that support activation scaling."""

    activation_scale: torch.Tensor | None

    def set_activation_scale(self, scale: torch.Tensor) -> None:
        ...


# int4
class Int4Tensor(TorchAOBaseTensor):
    # ... existing implementation ...

    # Just add the required attribute and method
    activation_scale: torch.Tensor | None = None

    def set_activation_scale(self, scale: torch.Tensor) -> None:
        self.activation_scale = scale
...


def _awq_transform(...):
    # ...
    assert isinstance(qw, SupportsActivationScaling), (
        "weight must support activation scaling by conforming to the SupportsActivationScaling protocol"
    )

is nicer since we don't have to do multiple inheritance, I just really don't like multiple inheritance

jerryzh168 (Author)

> (quoting the SupportsActivationScaling protocol suggestion above)

OK, will try thanks

@jerryzh168 jerryzh168 force-pushed the update-awq branch 4 times, most recently from 318ae78 to 5015a81 (August 15, 2025 20:56)
@@ -36,30 +36,54 @@ class Int4Tensor(TorchAOBaseTensor):
dtype is the same as the original Tensor dtype
zero_point: (K/group_size, N) for 2D Tensor, (B, K/group_size, N) for 3D Tensor, where B is batch size,
dtype is the same as the original Tensor dtype
act_scale (Optional[Tensor]): Optional per row scale for activation Tensor, if present,
Contributor

nit: remove "per row", it should work with any scale where the shapes are broadcastable to the original tensor

Contributor

I think the variable name should make it clear that this scale is not related to the scale which happens inside of activation quantization.

jerryzh168 (Author)

OK sg

@@ -36,30 +36,54 @@ class Int4Tensor(TorchAOBaseTensor):
dtype is the same as the original Tensor dtype
zero_point: (K/group_size, N) for 2D Tensor, (B, K/group_size, N) for 3D Tensor, where B is batch size,
dtype is the same as the original Tensor dtype
act_scale (Optional[Tensor]): Optional per row scale for activation Tensor, if present,
we'll multiply activation Tensor with act_scale before applying dynamic
Contributor

how does this compose with dynamic quant? Is there a faster path where we only scale once, instead of scaling by act_scale, then finding max, then scaling again by the inner scale?

jerryzh168 (Author)

we will first apply act_scale, and then do the dynamic quantization. we don't have the faster path right now; not sure whether this optimization is needed or whether it will actually yield a speedup yet
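
(A rough, self-contained sketch of that two-step flow; per-row int8 dynamic quantization here is purely illustrative, not the PR's actual kernel path:)

import torch

def scale_then_dynamic_quant(x: torch.Tensor, act_scale: torch.Tensor):
    # step 1: apply the AWQ activation scale
    x = x / act_scale
    # step 2: dynamic quantization computes its own per-row scale from the rescaled tensor
    row_absmax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6)
    q_scale = row_absmax / 127.0
    x_int8 = torch.clamp(torch.round(x / q_scale), -128, 127).to(torch.int8)
    return x_int8, q_scale

# a fused path would fold act_scale into q_scale so x is only divided once;
# that is the potential optimization discussed above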

Summary:
We want to remove the extra wrapper tensor subclass `to_weight_tensor_with_linear_activation_scale_metadata`.
Although it is composable with different tensor subclasses, it complicates the flow and reduces the locality of code
(compute logic is spread between the LinearActivationScale wrapper and the real tensor subclasses).

Instead we implement activation scaling in the tensor itself: we added an `ActivationScalingMixin` that
can be inherited by a tensor subclass that defines the activation scale attribute. The tensor subclass
* has an `act_scale: Optional[Tensor]` argument defined
* in linear, we can get the act_scale and do `input_tensor = input_tensor / self.act_scale` before calling
the gemm kernels

Note for BC: we'll handle BC later, after this PR; right now we are just focusing on functionality, since no official checkpoint is released yet for Int4Tensor
or Int4Tensor with act_scale.

Test Plan:
python test/prototype/test_awq.py

Reviewers:

Subscribers:

Tasks:

Tags: