Commit 6adb8b8

Move float8_opaque_tensor to prototype (#3365)

1 parent b55713a

File tree: 8 files changed, +173 −116 lines changed
Lines changed: 6 additions & 6 deletions

@@ -15,8 +15,10 @@
 )

 from torchao import quantize_
+from torchao.prototype.float8_opaque_tensor import (
+    Float8DynamicActivationFloat8WeightOpaqueTensorConfig,
+)
 from torchao.quantization import (
-    Float8DynamicActivationFloat8WeightConfig,
     PerGroup,
     PerRow,
     PerTensor,
@@ -29,10 +31,8 @@


 def get_config(granularity):
-    return Float8DynamicActivationFloat8WeightConfig(
-        activation_dtype=torch.float8_e4m3fn,
+    return Float8DynamicActivationFloat8WeightOpaqueTensorConfig(
         granularity=granularity,
-        float8_packing_format="opaque",
     )


@@ -133,7 +133,7 @@ def test_module_path(self, dtype):
         quantize_(linear, get_config(PerRow()))
         self.assertEqual(
             str(type(linear.weight)),
-            "<class 'torchao.quantization.Float8OpaqueTensor'>",
+            "<class 'torchao.prototype.float8_opaque_tensor.Float8OpaqueTensor'>",
         )

         with tempfile.NamedTemporaryFile() as f:
@@ -142,7 +142,7 @@ def test_module_path(self, dtype):
             state_dict = torch.load(f)
             self.assertEqual(
                 str(type(state_dict["weight"])),
-                "<class 'torchao.quantization.Float8OpaqueTensor'>",
+                "<class 'torchao.prototype.float8_opaque_tensor.Float8OpaqueTensor'>",
             )

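For reference, a minimal sketch of the migration this test now exercises; the layer shape and dtype are illustrative assumptions, not part of the commit:

    import torch
    from torchao import quantize_
    from torchao.prototype.float8_opaque_tensor import (
        Float8DynamicActivationFloat8WeightOpaqueTensorConfig,
    )
    from torchao.quantization import PerRow

    # Hypothetical layer for illustration; any Linear with a bfloat16 weight should do.
    linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)

    # Before this commit:
    #     Float8DynamicActivationFloat8WeightConfig(
    #         activation_dtype=torch.float8_e4m3fn,
    #         granularity=PerRow(),
    #         float8_packing_format="opaque",
    #     )
    # After this commit:
    quantize_(
        linear,
        Float8DynamicActivationFloat8WeightOpaqueTensorConfig(granularity=PerRow()),
    )
    print(type(linear.weight))  # torchao.prototype.float8_opaque_tensor.Float8OpaqueTensor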
Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+from .float8_opaque_tensor import Float8OpaqueTensor
+from .inference_workflow import Float8DynamicActivationFloat8WeightOpaqueTensorConfig
+
+__all__ = [
+    "Float8OpaqueTensor",
+    "Float8DynamicActivationFloat8WeightOpaqueTensorConfig",
+]
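The new package's public surface matches the __all__ list above; a quick import check, assuming a torchao build that includes this commit:

    from torchao.prototype.float8_opaque_tensor import (
        Float8DynamicActivationFloat8WeightOpaqueTensorConfig,
        Float8OpaqueTensor,
    )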

torchao/quantization/quantize_/workflows/float8/float8_opaque_tensor.py renamed to torchao/prototype/float8_opaque_tensor/float8_opaque_tensor.py

Lines changed: 4 additions & 3 deletions

@@ -21,12 +21,13 @@
 from torchao.quantization.quantize_.common import (
     _choose_quant_func_and_quantize_tensor,
 )
+from torchao.quantization.quantize_.workflows.float8.float8_tensor import (
+    QuantizeTensorToFloat8Kwargs,
+)
 from torchao.utils import (
     TorchAOBaseTensor,
 )

-from .float8_tensor import QuantizeTensorToFloat8Kwargs
-
 __all__ = [
     "Float8OpaqueTensor",
 ]
@@ -267,7 +268,7 @@ def _(func, types, args, kwargs):
     return y


-Float8OpaqueTensor.__module__ = "torchao.quantization"
+Float8OpaqueTensor.__module__ = "torchao.prototype.float8_opaque_tensor"

 # Allow a model with Float8OpaqueTensor weights to be loaded with `weights_only=True`
 torch.serialization.add_safe_globals([Float8OpaqueTensor])
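Because __module__ now points at the prototype package and the class is registered via add_safe_globals, a checkpoint round-trip should resolve the tensor under its new path. A sketch, reusing the quantized `linear` from the earlier example:

    import tempfile

    import torch

    with tempfile.NamedTemporaryFile() as f:
        torch.save(linear.state_dict(), f)
        f.seek(0)
        # add_safe_globals([Float8OpaqueTensor]) is what lets weights_only loading succeed.
        state_dict = torch.load(f, weights_only=True)

    assert "torchao.prototype.float8_opaque_tensor.Float8OpaqueTensor" in str(
        type(state_dict["weight"])
    )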
Lines changed: 117 additions & 0 deletions

@@ -0,0 +1,117 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, List, Optional, Union
+
+import torch
+
+import torchao
+from torchao.core.config import AOBaseConfig
+
+if TYPE_CHECKING:
+    from torchao.quantization.granularity import PerGroup, PerRow, PerTensor
+
+
+# Define FP8Granularity type alias to break circular import dependencies
+FP8Granularity = Union["PerTensor", "PerRow", "PerGroup"]
+
+import types
+from functools import partial
+
+from torchao.quantization.quant_api import _module_extra_repr
+from torchao.quantization.quantize_.workflows import QuantizeTensorToFloat8Kwargs
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
+from torchao.quantization.utils import get_block_size
+
+from .float8_opaque_tensor import Float8OpaqueTensor
+
+
+@dataclass
+class Float8DynamicActivationFloat8WeightOpaqueTensorConfig(AOBaseConfig):
+    """
+    Configuration for applying float8 dynamic symmetric quantization to both activations and weights of linear layers.
+
+    Args:
+        activation_dtype (torch.dtype): The target data type for activation quantization. Only torch.float8_e4m3fn supported.
+        weight_dtype (torch.dtype): The target data type for weight quantization. Only torch.float8_e4m3fn supported.
+        granularity (Optional[Union[FP8Granularity, List[FP8Granularity]]]):
+            The granularity for quantization. Can be either a single granularity (applied to both
+            activations and weights) or a tuple of two granularities (one for activations, one for weights).
+            If None, defaults to PerTensor for both. Currently both quantizations need to be the same type. And
+            only PerTensor/PerRow/PerGroup are supported.
+
+    """
+
+    activation_dtype: torch.dtype = torch.float8_e4m3fn
+    weight_dtype: torch.dtype = torch.float8_e4m3fn
+    granularity: Optional[Union[FP8Granularity, List[FP8Granularity]]] = None
+    set_inductor_config: bool = True
+
+    def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Float8DynamicActivationFloat8WeightConfig"
+        )
+        activation_granularity, weight_granularity = (
+            Float8OpaqueTensor._normalize_and_check_granularity(self.granularity)
+        )
+        self.granularity = [activation_granularity, weight_granularity]
+
+
+def _float8_dynamic_activation_float8_weight_opaque_tensor_quantize(weight, config):
+    activation_dtype = config.activation_dtype
+    granularity = config.granularity
+
+    activation_granularity, weight_granularity = granularity
+
+    act_quant_kwargs = QuantizeTensorToFloat8Kwargs(
+        activation_dtype,
+        activation_granularity,
+    )
+
+    block_size = get_block_size(weight.shape, weight_granularity)
+    quantized_weight = Float8OpaqueTensor.from_hp(
+        weight,
+        block_size=block_size,
+        act_quant_kwargs=act_quant_kwargs,
+    )
+
+    return quantized_weight
+
+
+@register_quantize_module_handler(Float8DynamicActivationFloat8WeightOpaqueTensorConfig)
+def _float8_dynamic_activation_float8_weight_opaque_tensor_transform(
+    module: torch.nn.Module,
+    config: Float8DynamicActivationFloat8WeightOpaqueTensorConfig,
+    *,
+    parameter_name: str = "weight",
+):
+    if config.set_inductor_config:
+        torchao.quantization.utils.recommended_inductor_config_setter()
+
+    assert hasattr(module, parameter_name), (
+        f"applying float8 dynamic activation quant requires module to have parameter {parameter_name} attribute"
+        + f" but {module} does not have one"
+    )
+    quantized_tensor = _float8_dynamic_activation_float8_weight_opaque_tensor_quantize(
+        getattr(module, parameter_name), config
+    )
+    setattr(
+        module,
+        parameter_name,
+        torch.nn.Parameter(quantized_tensor, requires_grad=False),
+    )
+    module.extra_repr = types.MethodType(
+        partial(
+            _module_extra_repr,
+            original_extra_repr=module.extra_repr,
+            parameter_name=parameter_name,
+        ),
+        module,
+    )
+    return module
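One detail worth noting in __post_init__ above: a single granularity is expanded into an [activation_granularity, weight_granularity] pair via Float8OpaqueTensor._normalize_and_check_granularity. A small sketch:

    from torchao.prototype.float8_opaque_tensor import (
        Float8DynamicActivationFloat8WeightOpaqueTensorConfig,
    )
    from torchao.quantization import PerRow

    cfg = Float8DynamicActivationFloat8WeightOpaqueTensorConfig(granularity=PerRow())
    # __post_init__ normalized the single granularity into a two-element list.
    assert len(cfg.granularity) == 2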

torchao/quantization/__init__.py

Lines changed: 0 additions & 2 deletions

@@ -92,7 +92,6 @@
     quantize_affine,
 )
 from .quantize_.workflows import (
-    Float8OpaqueTensor,
     Float8Tensor,
     Int4MarlinSparseTensor,
     Int4OpaqueTensor,
@@ -175,7 +174,6 @@
     "Int4TilePackedTo4dTensor",
     "Float8Tensor",
     "Int4OpaqueTensor",
-    "Float8OpaqueTensor",
     # smooth quant - subject to change
     "get_scale",
     "SmoothFakeDynQuantMixin",

torchao/quantization/quant_api.py

Lines changed: 39 additions & 68 deletions

@@ -74,8 +74,6 @@
     KernelPreference,
 )
 from torchao.quantization.quantize_.workflows import (
-    Float8OpaqueTensor,
-    Float8PackingFormat,
     Float8Tensor,
     Int4ChooseQParamsAlgorithm,
     Int4MarlinSparseTensor,
@@ -1808,23 +1806,14 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     kernel_preference: KernelPreference = KernelPreference.AUTO
     set_inductor_config: bool = True
     version: int = 2
-    float8_packing_format: Float8PackingFormat = Float8PackingFormat.PLAIN

     def __post_init__(self):
         torch._C._log_api_usage_once(
             "torchao.quantization.Float8DynamicActivationFloat8WeightConfig"
         )
-        if (
-            self.version == 2
-            and self.float8_packing_format == Float8PackingFormat.OPAQUE
-        ):
-            activation_granularity, weight_granularity = (
-                Float8OpaqueTensor._normalize_and_check_granularity(self.granularity)
-            )
-        else:
-            activation_granularity, weight_granularity = _normalize_granularity(
-                self.granularity
-            )
+        activation_granularity, weight_granularity = _normalize_granularity(
+            self.granularity
+        )
         self.granularity = [activation_granularity, weight_granularity]

         default_use_fast_accum = True
@@ -1854,48 +1843,43 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
     activation_value_lb = config.activation_value_lb
     activation_value_ub = config.activation_value_ub
     kernel_preference = config.kernel_preference
-    float8_packing_format = config.float8_packing_format

     # Ensure works on device
+    _check_hardware_support(granularity)
     activation_granularity, weight_granularity = granularity

-    if float8_packing_format == Float8PackingFormat.PLAIN:
-        # Note: right now we assume it's weights of conv2d and conv3d purely based
-        # on the dimension of weight, currently there is no conflict with linear 2d
-        # and moe weights 3d
-        # if we need to support conv1d, which also has 3d weight, we may have to
-        # pass around the module as well to distinguish between conv1d and 3d moe weight
-        if weight.dim() in [4, 5]:
-            # weights for conv2d or 3d
-            assert isinstance(activation_granularity, PerTensor) and isinstance(
-                weight_granularity, PerTensor
-            ), (
-                "4D/5D tensor only supports per tensor activation and weight quantization"
-            )
-
-            # conv3d weight dim: (C_out, C_in, K1, K2, K3)
-            # conv2d weight dim: (C_out, C_in, K1, K2)
-            # skip quantization when either C_out or C_in
-            # is not a multiple of 16
-            if weight.shape[0] % 16 != 0 or weight.shape[1] % 16 != 0:
-                return weight
-
-        elif not _fp8_mm_compat(weight):
-            # TODO(future PR): this should really throw an exception instead of silently
-            # not doing what the user asked
+    # Note: right now we assume it's weights of conv2d and conv3d purely based
+    # on the dimension of weight, currently there is no conflict with linear 2d
+    # and moe weights 3d
+    # if we need to support conv1d, which also has 3d weight, we may have to
+    # pass around the module as well to distinguish between conv1d and 3d moe weight
+    if weight.dim() in [4, 5]:
+        # weights for conv2d or 3d
+        assert isinstance(activation_granularity, PerTensor) and isinstance(
+            weight_granularity, PerTensor
+        ), "4D/5D tensor only supports per tensor activation and weight quantization"
+
+        # conv3d weight dim: (C_out, C_in, K1, K2, K3)
+        # conv2d weight dim: (C_out, C_in, K1, K2)
+        # skip quantization when either C_out or C_in
+        # is not a multiple of 16
+        if weight.shape[0] % 16 != 0 or weight.shape[1] % 16 != 0:
             return weight
+    elif not _fp8_mm_compat(weight):
+        # TODO(future PR): this should really throw an exception instead of silently
+        # not doing what the user asked
+        return weight

-        if isinstance(weight_granularity, PerRow):
-            assert weight.dtype == torch.bfloat16, (
-                "PerRow quantization only works for bfloat16 precision input weight"
-            )
+    if isinstance(weight_granularity, PerRow):
+        assert weight.dtype == torch.bfloat16, (
+            "PerRow quantization only works for bfloat16 precision input weight"
+        )

     if config.version == 1:
         warnings.warn(
             "Config Deprecation: version 1 of Float8DynamicActivationFloat8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2649 for more details"
         )

-        _check_hardware_support(granularity)
         block_size = get_block_size(weight.shape[-2:], weight_granularity)
         if weight.dim() == 3:
             block_size = tuple([1] + list(block_size))
@@ -1926,26 +1910,14 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
         kernel_preference=kernel_preference,
     )

-    if float8_packing_format == Float8PackingFormat.PLAIN:
-        quantized_weight = Float8Tensor.from_hp(
-            weight,
-            float8_dtype=weight_dtype,
-            granularity=weight_granularity,
-            mm_config=mm_config,
-            kernel_preference=kernel_preference,
-            act_quant_kwargs=act_quant_kwargs,
-        )
-    elif float8_packing_format == Float8PackingFormat.OPAQUE:
-        block_size = get_block_size(weight.shape, weight_granularity)
-        quantized_weight = Float8OpaqueTensor.from_hp(
-            weight,
-            block_size=block_size,
-            act_quant_kwargs=act_quant_kwargs,
-        )
-    else:
-        raise ValueError(
-            f"Unsupported float8 packing format: {float8_packing_format}"
-        )
+    quantized_weight = Float8Tensor.from_hp(
+        weight,
+        float8_dtype=weight_dtype,
+        granularity=weight_granularity,
+        mm_config=mm_config,
+        kernel_preference=kernel_preference,
+        act_quant_kwargs=act_quant_kwargs,
+    )

     return quantized_weight

@@ -1957,10 +1929,9 @@ def _float8_dynamic_activation_float8_weight_transform(
     *,
     parameter_name: str = "weight",
 ):
-    if config.float8_packing_format == Float8PackingFormat.PLAIN:
-        assert is_sm_at_least_89() or is_MI300(), (
-            "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
-        )
+    assert is_sm_at_least_89() or is_MI300(), (
+        "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
+    )
     if config.set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()

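The net effect on the core config: Float8DynamicActivationFloat8WeightConfig no longer accepts float8_packing_format and always produces plain Float8Tensor weights, while opaque packing lives only behind the prototype config. A before/after sketch; the granularity choice is illustrative:

    from torchao.prototype.float8_opaque_tensor import (
        Float8DynamicActivationFloat8WeightOpaqueTensorConfig,
    )
    from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow

    # Plain packing (unchanged; requires SM 8.9+ or MI300+ per the assert above):
    plain_cfg = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())

    # Opaque packing, previously requested with float8_packing_format="opaque":
    opaque_cfg = Float8DynamicActivationFloat8WeightOpaqueTensorConfig(granularity=PerRow())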

torchao/quantization/quantize_/workflows/__init__.py

Lines changed: 0 additions & 6 deletions

@@ -1,7 +1,3 @@
-from .float8.float8_opaque_tensor import (
-    Float8OpaqueTensor,
-)
-from .float8.float8_packing_format import Float8PackingFormat
 from .float8.float8_tensor import (
     Float8Tensor,
     QuantizeTensorToFloat8Kwargs,
@@ -41,9 +37,7 @@
     "Int4MarlinSparseTensor",
     "Int4PlainInt32Tensor",
     "Int4TilePackedTo4dTensor",
-    "Float8OpaqueTensor",
     "Float8Tensor",
-    "Float8PackingFormat",
     "QuantizeTensorToFloat8Kwargs",
     "Int4OpaqueTensor",
     "Int4ChooseQParamsAlgorithm",
