
[fp8] Only assert when CUDA is available. #2590

Draft · wants to merge 3 commits into main
28 changes: 0 additions & 28 deletions torchao/float8/inference.py
@@ -17,10 +17,6 @@
     PerRow,
     PerTensor,
 )
-from torchao.utils import (
-    is_MI300,
-    is_sm_at_least_89,
-)
 
 Tensor = torch.Tensor
 
@@ -147,27 +143,3 @@ def _normalize_granularity(
             f"Invalid granularity specification: {granularity}, only PerTensor or PerRow are supported."
         )
     return processed_granularity
-
-
-def _check_hardware_support(
-    granularities: Tuple[FP8Granularity, FP8Granularity],
-) -> None:
-    """
-    Validate that the hardware supports the requested granularities.
-
-    Args:
-        granularities: Tuple of (activation_granularity, weight_granularity)
-
-    Raises:
-        AssertionError: If hardware doesn't support the requested granularity
-        ValueError: If invalid granularity type is provided
-    """
-    for _granularity in granularities:
-        if not isinstance(_granularity, (PerTensor, PerRow)):
-            raise ValueError(
-                f"Invalid granularity type: {_granularity}, only PerTensor or PerRow are supported."
-            )
-
-    assert is_sm_at_least_89() or is_MI300(), (
-        "Float8 dynamic quantization requires CUDA compute capability ≥8.9 or MI300+."
-    )
Contributor:

This change will impact other devices, such as CPU or any unknown device. Suggest adding a common utility function to judge fp8 capability and applying it to all fp8-related changes. For now, this function should only pass when the CUDA compute capability is ≥8.9, the device is MI300+, or an XPU device is available.
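
A minimal sketch of the common utility the reviewer suggests, assuming a hypothetical name _is_fp8_supported and the capability rules described in the comment (this helper is not part of the PR):

import torch

from torchao.utils import is_MI300, is_sm_at_least_89


def _is_fp8_supported() -> bool:
    # Hypothetical helper sketching the reviewer's suggestion; not part of this PR.
    # CUDA: require SM 8.9+ (Ada/Hopper) or AMD MI300-class hardware.
    if torch.cuda.is_available():
        return is_sm_at_least_89() or is_MI300()
    # XPU: float8 is supported by simulation, so it always qualifies.
    if torch.xpu.is_available():
        return True
    # CPU and unknown devices are treated as unsupported.
    return False

Centralizing the check this way would let every fp8 config path share one capability rule instead of repeating bare asserts.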

54 changes: 45 additions & 9 deletions torchao/quantization/quant_api.py
@@ -60,7 +60,6 @@
 from torchao.float8.inference import (
     Float8MMConfig,
     FP8Granularity,
-    _check_hardware_support,
     _normalize_granularity,
 )
 from torchao.quantization.linear_activation_weight_observed_tensor import (
@@ -270,6 +269,46 @@ def change_linear_weights_to_int4_woqtensors(
 ########
 
 
+def _check_hardware_support(
+    config: Union[
+        "Float8DynamicActivationFloat8WeightConfig",
+        "Float8StaticActivationFloat8WeightConfig",
+    ],
+) -> None:
+    """
+    Validate that the hardware supports the given float8 quantization configuration.
+
+    Args:
+        config: Must be one of Float8DynamicActivationFloat8WeightConfig or Float8StaticActivationFloat8WeightConfig.
+
+    Raises:
+        AssertionError: If the hardware doesn't support the float8 feature.
+        ValueError: If an invalid granularity type is provided.
+        TypeError: If config is not of the correct type.
+    """
+    if not isinstance(
+        config,
+        (
+            Float8DynamicActivationFloat8WeightConfig,
+            Float8StaticActivationFloat8WeightConfig,
+        ),
+    ):
+        raise TypeError(
+            f"config must be one of Float8DynamicActivationFloat8WeightConfig or Float8StaticActivationFloat8WeightConfig, got {type(config)}"
+        )
+    # XPU by default supports float8 by simulation, thus always passes.
+    if not torch.xpu.is_available():
+        assert is_sm_at_least_89() or is_MI300(), (
+            f"{config.__class__.__name__} requires CUDA compute capability ≥8.9 or MI300+."
+        )
+    if isinstance(config, Float8DynamicActivationFloat8WeightConfig):
+        for _granularity in config.granularity:
+            if not isinstance(_granularity, (PerTensor, PerRow)):
+                raise ValueError(
+                    f"Invalid granularity type: {_granularity}, only PerTensor or PerRow are supported."
+                )
+
+
 def _replace_with_custom_fn_if_matches_filter(
     model,
     replacement_fn,
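
As a usage illustration (not taken from the PR), the new check is driven entirely by the config object; the granularity choice below is an arbitrary example, and the private import is shown only for demonstration:

from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow
from torchao.quantization.quant_api import _check_hardware_support

# PerRow granularity here is an example choice, assumed to be normalized
# to an (activation, weight) tuple by the config.
config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())

# Raises TypeError for non-float8 configs, ValueError for unsupported
# granularities, and AssertionError when no capable accelerator is present.
_check_hardware_support(config)
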
@@ -1633,7 +1672,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
     mm_config = config.mm_config
 
     # Ensure works on device
-    _check_hardware_support(granularity)
+    _check_hardware_support(config)
     activation_granularity, weight_granularity = granularity
 
     if not _fp8_mm_compat(weight):
@@ -1672,9 +1711,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
 def _float8_dynamic_activation_float8_weight_transform(
     module: torch.nn.Module, config: Float8DynamicActivationFloat8WeightConfig
 ):
-    assert is_sm_at_least_89() or is_MI300(), (
-        "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
-    )
+    _check_hardware_support(config)
     if config.set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()
 
@@ -1710,7 +1747,8 @@ class Float8DynamicActivationFloat8SemiSparseWeightConfig(AOBaseConfig):
 def _float8_dynamic_activation_float8_semi_sparse_weight_transform(
     module: torch.nn.Module, config: Float8DynamicActivationFloat8SemiSparseWeightConfig
 ):
-    assert is_sm_at_least_90(), "Float8 quantization is only supported on CUDA>=9.0"
+    if torch.cuda.is_available():
+        assert is_sm_at_least_90(), "Float8 quantization is only supported on CUDA>=9.0"
 
     weight = module.weight
     weight_dtype = config.weight_dtype
@@ -1769,9 +1807,7 @@ def __post_init__(self):
 def _float8_static_activation_float8_weight_transform(
     module: torch.nn.Module, config: Float8StaticActivationFloat8WeightConfig
 ):
-    assert is_sm_at_least_89() or is_MI300(), (
-        "Float8 static activation quantization is only supported on CUDA 8.9 and above"
-    )
+    _check_hardware_support(config)
 
     scale = config.scale
     activation_dtype = config.activation_dtype
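
For context, a hedged end-to-end sketch of how these transforms are reached; the toy model and granularity choice are assumptions, not part of the PR:

import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

# Toy model purely for illustration.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024, dtype=torch.bfloat16))
if torch.cuda.is_available():
    model = model.cuda()

# quantize_ dispatches to _float8_dynamic_activation_float8_weight_transform,
# which after this PR validates hardware via _check_hardware_support(config)
# rather than asserting unconditionally on non-CUDA devices.
quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))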