     KernelPreference,
 )
 from torchao.quantization.quantize_.workflows import (
-    Float8OpaqueTensor,
-    Float8PackingFormat,
     Float8Tensor,
     Int4ChooseQParamsAlgorithm,
     Int4MarlinSparseTensor,
@@ -1808,23 +1806,14 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     kernel_preference: KernelPreference = KernelPreference.AUTO
     set_inductor_config: bool = True
     version: int = 2
-    float8_packing_format: Float8PackingFormat = Float8PackingFormat.PLAIN
 
     def __post_init__(self):
         torch._C._log_api_usage_once(
             "torchao.quantization.Float8DynamicActivationFloat8WeightConfig"
         )
-        if (
-            self.version == 2
-            and self.float8_packing_format == Float8PackingFormat.OPAQUE
-        ):
-            activation_granularity, weight_granularity = (
-                Float8OpaqueTensor._normalize_and_check_granularity(self.granularity)
-            )
-        else:
-            activation_granularity, weight_granularity = _normalize_granularity(
-                self.granularity
-            )
+        activation_granularity, weight_granularity = _normalize_granularity(
+            self.granularity
+        )
         self.granularity = [activation_granularity, weight_granularity]
 
         default_use_fast_accum = True
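
With `float8_packing_format` gone from the dataclass, `__post_init__` always routes through `_normalize_granularity`. A minimal sketch of what construction looks like after this change (usage is illustrative; the import paths are assumed from torchao's public API):

```python
# Minimal sketch, assuming torchao is installed and exposes these names.
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
)

config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
# __post_init__ normalizes a single granularity into an
# [activation_granularity, weight_granularity] pair:
assert len(config.granularity) == 2
# Passing float8_packing_format=... would now raise a TypeError,
# since the field no longer exists on the dataclass.
```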
@@ -1854,48 +1843,43 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
     activation_value_lb = config.activation_value_lb
     activation_value_ub = config.activation_value_ub
     kernel_preference = config.kernel_preference
-    float8_packing_format = config.float8_packing_format
 
     # Ensure works on device
+    _check_hardware_support(granularity)
     activation_granularity, weight_granularity = granularity
 
-    if float8_packing_format == Float8PackingFormat.PLAIN:
-        # Note: right now we assume it's weights of conv2d and conv3d purely based
-        # on the dimension of weight, currently there is no conflict with linear 2d
-        # and moe weights 3d
-        # if we need to support conv1d, which also has 3d weight, we may have to
-        # pass around the module as well to distinguish between conv1d and 3d moe weight
-        if weight.dim() in [4, 5]:
-            # weights for conv2d or 3d
-            assert isinstance(activation_granularity, PerTensor) and isinstance(
-                weight_granularity, PerTensor
-            ), (
-                "4D/5D tensor only supports per tensor activation and weight quantization"
-            )
-
-            # conv3d weight dim: (C_out, C_in, K1, K2, K3)
-            # conv2d weight dim: (C_out, C_in, K1, K2)
-            # skip quantization when either C_out or C_in
-            # is not a multiple of 16
-            if weight.shape[0] % 16 != 0 or weight.shape[1] % 16 != 0:
-                return weight
-
-        elif not _fp8_mm_compat(weight):
-            # TODO(future PR): this should really throw an exception instead of silently
-            # not doing what the user asked
+    # Note: right now we assume it's weights of conv2d and conv3d purely based
+    # on the dimension of weight, currently there is no conflict with linear 2d
+    # and moe weights 3d
+    # if we need to support conv1d, which also has 3d weight, we may have to
+    # pass around the module as well to distinguish between conv1d and 3d moe weight
+    if weight.dim() in [4, 5]:
+        # weights for conv2d or 3d
+        assert isinstance(activation_granularity, PerTensor) and isinstance(
+            weight_granularity, PerTensor
+        ), "4D/5D tensor only supports per tensor activation and weight quantization"
+
+        # conv3d weight dim: (C_out, C_in, K1, K2, K3)
+        # conv2d weight dim: (C_out, C_in, K1, K2)
+        # skip quantization when either C_out or C_in
+        # is not a multiple of 16
+        if weight.shape[0] % 16 != 0 or weight.shape[1] % 16 != 0:
             return weight
+    elif not _fp8_mm_compat(weight):
+        # TODO(future PR): this should really throw an exception instead of silently
+        # not doing what the user asked
+        return weight
 
-        if isinstance(weight_granularity, PerRow):
-            assert weight.dtype == torch.bfloat16, (
-                "PerRow quantization only works for bfloat16 precision input weight"
-            )
+    if isinstance(weight_granularity, PerRow):
+        assert weight.dtype == torch.bfloat16, (
+            "PerRow quantization only works for bfloat16 precision input weight"
+        )
 
     if config.version == 1:
         warnings.warn(
             "Config Deprecation: version 1 of Float8DynamicActivationFloat8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2649 for more details"
         )
 
-    _check_hardware_support(granularity)
     block_size = get_block_size(weight.shape[-2:], weight_granularity)
     if weight.dim() == 3:
         block_size = tuple([1] + list(block_size))
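
Note that `_check_hardware_support(granularity)` moves up to the top of the function, so an unsupported granularity now fails before any shape-based skipping. The multiple-of-16 gate for conv weights can be sanity-checked in isolation; this is a standalone sketch of that check, not torchao code:

```python
# Standalone sketch of the channel-alignment gate above.
# conv2d weight layout: (C_out, C_in, K1, K2); both channel dims
# must be multiples of 16, otherwise quantization is skipped.
import torch

def would_skip(weight: torch.Tensor) -> bool:
    return weight.shape[0] % 16 != 0 or weight.shape[1] % 16 != 0

print(would_skip(torch.empty(32, 64, 3, 3)))  # False: gets quantized
print(would_skip(torch.empty(30, 64, 3, 3)))  # True: returned unquantized
```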
@@ -1926,26 +1910,14 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
         kernel_preference=kernel_preference,
     )
 
-    if float8_packing_format == Float8PackingFormat.PLAIN:
-        quantized_weight = Float8Tensor.from_hp(
-            weight,
-            float8_dtype=weight_dtype,
-            granularity=weight_granularity,
-            mm_config=mm_config,
-            kernel_preference=kernel_preference,
-            act_quant_kwargs=act_quant_kwargs,
-        )
-    elif float8_packing_format == Float8PackingFormat.OPAQUE:
-        block_size = get_block_size(weight.shape, weight_granularity)
-        quantized_weight = Float8OpaqueTensor.from_hp(
-            weight,
-            block_size=block_size,
-            act_quant_kwargs=act_quant_kwargs,
-        )
-    else:
-        raise ValueError(
-            f"Unsupported float8 packing format: {float8_packing_format}"
-        )
+    quantized_weight = Float8Tensor.from_hp(
+        weight,
+        float8_dtype=weight_dtype,
+        granularity=weight_granularity,
+        mm_config=mm_config,
+        kernel_preference=kernel_preference,
+        act_quant_kwargs=act_quant_kwargs,
+    )
 
     return quantized_weight
 
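
With the OPAQUE and error branches removed, `Float8Tensor.from_hp` is the single construction path. A hedged end-to-end sketch, assuming a GPU that passes the hardware checks and bfloat16 weights (as the asserts in this file require):

```python
# Sketch: quantizing a linear module with the simplified config.
# Assumes an SM 8.9+ CUDA GPU (or MI300+) is available.
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

m = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(m, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
# m.weight is now a Float8Tensor built by Float8Tensor.from_hp
print(type(m.weight))
```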
@@ -1957,10 +1929,9 @@ def _float8_dynamic_activation_float8_weight_transform(
     *,
     parameter_name: str = "weight",
 ):
-    if config.float8_packing_format == Float8PackingFormat.PLAIN:
-        assert is_sm_at_least_89() or is_MI300(), (
-            "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
-        )
+    assert is_sm_at_least_89() or is_MI300(), (
+        "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
+    )
     if config.set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()
 
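
The hardware assert now runs unconditionally in the transform. For reference, the CUDA side of `is_sm_at_least_89` amounts to a compute-capability check; a rough plain-torch equivalent is below (the MI300 check is elided, and this is an approximation, not the torchao utility itself):

```python
# Rough equivalent of the CUDA-capability gate, plain torch only.
import torch

def sm_at_least_89() -> bool:
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    return (major, minor) >= (8, 9)  # SM 8.9 == Ada generation

print("float8 dynamic activation supported:", sm_at_least_89())
```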