@@ -202,6 +202,20 @@ def _quantization_scheme_map_from_config(
         # pydantic model, which is used to verify the structure of the
         # quant_config and also store the details for later use.
 
+        # Cache whether quant_format requires activation quantization
+        activation_quant_supported = is_activation_quantization_format(quant_format)
+        float_quant_type = QuantizationType.FLOAT
+
+        # The quant_config has multiple config_groups, each containing
+        # an input_activations key with details about how the activations are
+        # quantized, a weights key indicating how the weights are quantized,
+        # and a list of targets under the `targets` key, dictating which
+        # layers are impacted by the quantization details. The quantization
+        # details follow the structure defined by the QuantizationArgs
+        # pydantic model, which is used to verify the structure of the
+        # quant_config and also store the details for later use.
+
         config_groups = config.get("config_groups", dict())
         for _, quant_config in config_groups.items():
             targets = quant_config.get("targets")
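The two values cached above hoist loop-invariant work out of the config_groups loop that follows: the format check and the enum lookup depend only on quant_format, not on the loop variable. A minimal, self-contained sketch of the pattern (toy names, not the real vLLM helpers):

# Toy stand-in for the real format check; illustrative names only.
_ACTIVATION_FORMATS = {"naive-quantized", "int-quantized", "float-quantized"}

def is_activation_format(fmt: str) -> bool:
    return fmt in _ACTIVATION_FORMATS

def build_flags_uncached(config_groups: dict, quant_format: str) -> dict:
    # The format check runs once per group, although quant_format
    # never changes inside the loop.
    return {name: is_activation_format(quant_format) for name in config_groups}

def build_flags_cached(config_groups: dict, quant_format: str) -> dict:
    # Hoist the loop-invariant call; the loop body is a plain local read.
    supported = is_activation_format(quant_format)
    return {name: supported for name in config_groups}

assert (build_flags_uncached({"g0": {}, "g1": {}}, "float-quantized")
        == build_flags_cached({"g0": {}, "g1": {}}, "float-quantized"))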
@@ -212,7 +226,7 @@ def _quantization_scheme_map_from_config(
                 )
 
                 target_scheme_map[target]["input_activations"] = None
-                if is_activation_quantization_format(quant_format):
+                if activation_quant_supported:
                     input_activations = quant_config.get("input_activations")
                     # The only case where we have activation quant supported
                     # but no input_activations provided in the config
@@ -221,12 +235,12 @@ def _quantization_scheme_map_from_config(
                     if not input_activations:
                         assert (
                             target_scheme_map[target]["weights"].type
-                            == QuantizationType.FLOAT
+                            == float_quant_type
                         )
                     else:
                         target_scheme_map[target]["input_activations"] = (
                             QuantizationArgs.model_validate(  # noqa: E501
-                                quant_config.get("input_activations")
+                                input_activations
                             )
                         )
         return target_scheme_map
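For reference, a config_groups entry with the shape the comment above describes might look like the following. The values are illustrative only; the authoritative field set comes from compressed-tensors' QuantizationArgs pydantic model.

# Illustrative quant_config (made-up values): one group quantizing the
# weights and input activations of Linear layers, as the comment describes.
example_config = {
    "config_groups": {
        "group_0": {
            "targets": ["Linear"],  # layer types this group applies to
            "weights": {
                "num_bits": 8,
                "type": "float",
                "symmetric": True,
                "strategy": "tensor",
                "dynamic": False,
            },
            "input_activations": {
                "num_bits": 8,
                "type": "float",
                "symmetric": True,
                "strategy": "tensor",
                "dynamic": False,
            },
        }
    }
}

Each group yields one target_scheme_map entry per target, with "weights" always validated and "input_activations" validated only when the compression format supports activation quantization.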
@@ -348,15 +362,25 @@ def _is_fp8_w8a16(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool
     def _is_wNa16_group_channel(
         self, weight_quant: BaseModel, input_quant: BaseModel
     ) -> bool:
-        input_quant_none = input_quant is None
-        is_symmetric = weight_quant.symmetric
-        is_channel_group = (
-            weight_quant.strategy == QuantizationStrategy.CHANNEL.value
-            or weight_quant.strategy == QuantizationStrategy.GROUP.value
-        )
-        is_static = not weight_quant.dynamic
+        # Cache attribute and enum value lookups in local variables to
+        # avoid repeating them across the checks below
+        wq_strategy = weight_quant.strategy
+        channel_value = QuantizationStrategy.CHANNEL.value
+        group_value = QuantizationStrategy.GROUP.value
+
+        # Check the strategy first so a mismatch returns immediately
+        if not ((wq_strategy == channel_value) or (wq_strategy == group_value)):
+            return False
 
-        return is_channel_group and input_quant_none and is_symmetric and is_static
+        # Remaining guard clauses run only when the strategy check passes
+        if input_quant is not None:
+            return False
+        if not weight_quant.symmetric:
+            return False
+        if weight_quant.dynamic:
+            return False
+
+        return True
 
     def _get_scheme_from_parts(
         self, weight_quant: BaseModel, input_quant: BaseModel
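The rewrite above trades one compound boolean for guard clauses, so the cheapest disqualifying check returns first and later attribute reads are skipped entirely. A self-contained sketch of the same shape, using a simplified stand-in type rather than the vLLM BaseModel classes:

from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeWeightQuant:
    strategy: str
    symmetric: bool
    dynamic: bool

def is_wna16_group_channel(weight_quant: FakeWeightQuant,
                           input_quant: Optional[object]) -> bool:
    # Strategy is checked first; on a mismatch the remaining
    # attributes are never read.
    if weight_quant.strategy not in ("channel", "group"):
        return False
    # Each guard clause exits as soon as one condition disqualifies.
    if input_quant is not None:
        return False
    if not weight_quant.symmetric:
        return False
    if weight_quant.dynamic:
        return False
    return True

assert is_wna16_group_channel(FakeWeightQuant("group", True, False), None)
assert not is_wna16_group_channel(FakeWeightQuant("tensor", True, False), None)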
@@ -8,13 +8,14 @@
 from compressed_tensors import CompressionFormat
 from torch.nn import Module
 
+_ACTIVATION_QUANTIZATION_FORMATS = {
+    CompressionFormat.naive_quantized.value,
+    CompressionFormat.int_quantized.value,
+    CompressionFormat.float_quantized.value,
+}
 
 
 def is_activation_quantization_format(format: str) -> bool:
-    _ACTIVATION_QUANTIZATION_FORMATS = [
-        CompressionFormat.naive_quantized.value,
-        CompressionFormat.int_quantized.value,
-        CompressionFormat.float_quantized.value,
-    ]
     return format in _ACTIVATION_QUANTIZATION_FORMATS
 
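Two effects are in play in this hunk: the container is now built once at import time instead of on every call, and a set gives average O(1) membership tests where a list scans linearly. A rough, illustrative way to observe both (hypothetical names, timings vary by machine):

import timeit

_FORMATS_SET = {"naive-quantized", "int-quantized", "float-quantized"}

def check_rebuild_list(fmt: str) -> bool:
    # Old shape: the list literal is rebuilt on every call.
    formats = ["naive-quantized", "int-quantized", "float-quantized"]
    return fmt in formats

def check_module_set(fmt: str) -> bool:
    # New shape: one global load plus a hash lookup.
    return fmt in _FORMATS_SET

if __name__ == "__main__":
    for fn in (check_rebuild_list, check_module_set):
        secs = timeit.timeit(lambda: fn("float-quantized"), number=1_000_000)
        print(f"{fn.__name__}: {secs:.3f}s")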