diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
index bdb8920a399e..c5497d1c8d16 100644
--- a/tests/quantization/bnb/test_4bit.py
+++ b/tests/quantization/bnb/test_4bit.py
@@ -866,15 +866,17 @@ def test_fp4_double_safe(self):
 
 @require_torch_version_greater("2.7.1")
 class Bnb4BitCompileTests(QuantCompileTests):
-    quantization_config = PipelineQuantizationConfig(
-        quant_backend="bitsandbytes_8bit",
-        quant_kwargs={
-            "load_in_4bit": True,
-            "bnb_4bit_quant_type": "nf4",
-            "bnb_4bit_compute_dtype": torch.bfloat16,
-        },
-        components_to_quantize=["transformer", "text_encoder_2"],
-    )
+    @property
+    def quantization_config(self):
+        return PipelineQuantizationConfig(
+            quant_backend="bitsandbytes_8bit",
+            quant_kwargs={
+                "load_in_4bit": True,
+                "bnb_4bit_quant_type": "nf4",
+                "bnb_4bit_compute_dtype": torch.bfloat16,
+            },
+            components_to_quantize=["transformer", "text_encoder_2"],
+        )
 
     def test_torch_compile(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
@@ -883,5 +885,7 @@ def test_torch_compile(self):
     def test_torch_compile_with_cpu_offload(self):
         super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
 
-    def test_torch_compile_with_group_offload(self):
-        super()._test_torch_compile_with_group_offload(quantization_config=self.quantization_config)
+    def test_torch_compile_with_group_offload_leaf(self):
+        super()._test_torch_compile_with_group_offload_leaf(
+            quantization_config=self.quantization_config, use_stream=True
+        )
diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index d048b0b7db46..383cdd6849ea 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -831,11 +831,13 @@ def test_serialization_sharded(self):
 
 @require_torch_version_greater_equal("2.6.0")
 class Bnb8BitCompileTests(QuantCompileTests):
-    quantization_config = PipelineQuantizationConfig(
-        quant_backend="bitsandbytes_8bit",
-        quant_kwargs={"load_in_8bit": True},
-        components_to_quantize=["transformer", "text_encoder_2"],
-    )
+    @property
+    def quantization_config(self):
+        return PipelineQuantizationConfig(
+            quant_backend="bitsandbytes_8bit",
+            quant_kwargs={"load_in_8bit": True},
+            components_to_quantize=["transformer", "text_encoder_2"],
+        )
 
     def test_torch_compile(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
@@ -847,7 +849,7 @@ def test_torch_compile_with_cpu_offload(self):
         )
 
     @pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.")
-    def test_torch_compile_with_group_offload(self):
-        super()._test_torch_compile_with_group_offload(
-            quantization_config=self.quantization_config, torch_dtype=torch.float16
+    def test_torch_compile_with_group_offload_leaf(self):
+        super()._test_torch_compile_with_group_offload_leaf(
+            quantization_config=self.quantization_config, torch_dtype=torch.float16, use_stream=True
         )
diff --git a/tests/quantization/test_torch_compile_utils.py b/tests/quantization/test_torch_compile_utils.py
index ba870ba733b9..99bb8980ef9f 100644
--- a/tests/quantization/test_torch_compile_utils.py
+++ b/tests/quantization/test_torch_compile_utils.py
@@ -24,7 +24,11 @@
 @require_torch_gpu
 @slow
 class QuantCompileTests(unittest.TestCase):
-    quantization_config = None
+    @property
+    def quantization_config(self):
+        raise NotImplementedError(
+            "This property should be implemented in the subclass to return the appropriate quantization config."
+        )
 
     def setUp(self):
         super().setUp()
@@ -64,7 +68,9 @@ def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype=
         # small resolutions to ensure speedy execution.
         pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
 
-    def _test_torch_compile_with_group_offload(self, quantization_config, torch_dtype=torch.bfloat16):
+    def _test_torch_compile_with_group_offload_leaf(
+        self, quantization_config, torch_dtype=torch.bfloat16, *, use_stream: bool = False
+    ):
         torch._dynamo.config.cache_size_limit = 10000
 
         pipe = self._init_pipeline(quantization_config, torch_dtype)
@@ -72,8 +78,7 @@ def _test_torch_compile_with_group_offload(self, quantization_config, torch_dtyp
             "onload_device": torch.device("cuda"),
             "offload_device": torch.device("cpu"),
             "offload_type": "leaf_level",
-            "use_stream": True,
-            "non_blocking": True,
+            "use_stream": use_stream,
         }
         pipe.transformer.enable_group_offload(**group_offload_kwargs)
         pipe.transformer.compile()
diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py
index 0741c7f87c78..c4cfc8eb87fb 100644
--- a/tests/quantization/torchao/test_torchao.py
+++ b/tests/quantization/torchao/test_torchao.py
@@ -19,6 +19,7 @@
 from typing import List
 
 import numpy as np
+from parameterized import parameterized
 from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
 
 from diffusers import (
@@ -29,6 +30,7 @@
     TorchAoConfig,
 )
 from diffusers.models.attention_processor import Attention
+from diffusers.quantizers import PipelineQuantizationConfig
 from diffusers.utils.testing_utils import (
     backend_empty_cache,
     backend_synchronize,
@@ -44,6 +46,8 @@
     torch_device,
 )
 
+from ..test_torch_compile_utils import QuantCompileTests
+
 enable_full_determinism()
 
 
@@ -625,6 +629,53 @@ def test_int_a16w8_cpu(self):
         self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
 
 
+@require_torchao_version_greater_or_equal("0.7.0")
+class TorchAoCompileTest(QuantCompileTests):
+    @property
+    def quantization_config(self):
+        return PipelineQuantizationConfig(
+            quant_mapping={
+                "transformer": TorchAoConfig(quant_type="int8_weight_only"),
+            },
+        )
+
+    def test_torch_compile(self):
+        super()._test_torch_compile(quantization_config=self.quantization_config)
+
+    @unittest.skip(
+        "Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work "
+        "when compiling."
+    )
+    def test_torch_compile_with_cpu_offload(self):
+        # RuntimeError: _apply(): Couldn't swap Linear.weight
+        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
+
+    @unittest.skip(
+        """
+        For `use_stream=False`:
+        - Changing the device of an AQT tensor with `param.data = param.data.to(device)`, as done in the group offloading implementation,
+          is unsupported in TorchAO. When compiling, a FakeTensor device mismatch causes failure.
+        For `use_stream=True`:
+        - Using a non-default stream requires the ability to pin tensors. AQT does not seem to support this yet in TorchAO.
+        """
+    )
+    @parameterized.expand([False, True])
+    def test_torch_compile_with_group_offload_leaf(self):
+        # For use_stream=False:
+        # If we run group offloading without compilation, we will see:
+        # RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match.
+        # When running with compilation, the error ends up being different:
+        # Dynamo failed to run FX node with fake tensors: call_function (*(FakeTensor(..., device='cuda:0', size=(s0, 256), dtype=torch.bfloat16), AffineQuantizedTensor(tensor_impl=PlainAQTTensorImpl(data=FakeTensor(..., size=(1536, 256), dtype=torch.int8)... , scale=FakeTensor(..., size=(1536,), dtype=torch.bfloat16)... , zero_point=FakeTensor(..., size=(1536,), dtype=torch.int64)... , _layout=PlainLayout()), block_size=(1, 256), shape=torch.Size([1536, 256]), device=cpu, dtype=torch.bfloat16, requires_grad=False), Parameter(FakeTensor(..., device='cuda:0', size=(1536,), dtype=torch.bfloat16,
+        # requires_grad=True))), **{}): got RuntimeError('Unhandled FakeTensor Device Propagation for aten.mm.default, found two different devices cuda:0, cpu')
+        # Looks like something that will have to be looked into upstream.
+        # For linear layers, weight.tensor_impl shows cuda, but:
+        # weight.tensor_impl.{data,scale,zero_point}.device will be cpu
+
+        # For use_stream=True:
+        # NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=, types=(,), arg_types=(,), kwarg_types={}
+        super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)
+
+
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
 @require_torch_accelerator