enable torchao test cases on XPU and switch to device agnostic APIs for test cases #11654
Changes from all commits: 2e11423, 1967c62, 5ed1810, 007eb7f, 71c328d, 08e8038, d7fda4b, 549eecf, c1e5792, db3cd69, 9da1ad0
File 1: testing utilities (the require_torch_cuda_compatibility decorator)

@@ -300,9 +300,7 @@ def require_torch_gpu(test_case):
 def require_torch_cuda_compatibility(expected_compute_capability):
     def decorator(test_case):
-        if not torch.cuda.is_available():
-            return unittest.skip(test_case)
-        else:
+        if torch.cuda.is_available():
             current_compute_capability = get_torch_cuda_device_capability()
             return unittest.skipUnless(
                 float(current_compute_capability) == float(expected_compute_capability),

Review comment: only check compute capability when the device is CUDA; on a non-CUDA device, just let the test run. Non-CUDA devices that need a compatibility check should do so in their own utilities.
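For readers following along, here is a rough sketch of how the decorator reads after this change. The skip message, the explicit application to test_case, and the final pass-through return are assumptions filled in from the visible context, not a copy of the actual file:

import unittest
import torch

def require_torch_cuda_compatibility(expected_compute_capability):
    def decorator(test_case):
        # Only gate on compute capability when CUDA is actually present.
        if torch.cuda.is_available():
            # get_torch_cuda_device_capability is the helper from the same module
            # referenced in the hunk above.
            current_compute_capability = get_torch_cuda_device_capability()
            return unittest.skipUnless(
                float(current_compute_capability) == float(expected_compute_capability),
                f"Test requires compute capability {expected_compute_capability}.",  # assumed message
            )(test_case)
        # Non-CUDA accelerators (e.g. XPU) pass through instead of being skipped.
        return test_case
    return decorator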
File 2: torchao quantization tests

@@ -30,13 +30,15 @@
 )
 from diffusers.models.attention_processor import Attention
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
+    backend_synchronize,
     enable_full_determinism,
     is_torch_available,
     is_torchao_available,
     nightly,
     numpy_cosine_similarity_distance,
     require_torch,
-    require_torch_gpu,
+    require_torch_accelerator,
     require_torchao_version_greater_or_equal,
     slow,
     torch_device,
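The newly imported backend_empty_cache and backend_synchronize helpers are what make the tearDown and cleanup code further down device agnostic. As a rough illustration only (not the actual diffusers implementation; the function names here are placeholders), such helpers typically dispatch on the device string:

import torch

def _empty_cache_sketch(device: str) -> None:
    # torch.cuda, torch.xpu, and torch.mps all expose empty_cache() in recent
    # PyTorch releases; CPU has no cache to clear, so it falls through silently.
    if device.startswith("cuda"):
        torch.cuda.empty_cache()
    elif device.startswith("xpu"):
        torch.xpu.empty_cache()
    elif device.startswith("mps"):
        torch.mps.empty_cache()

def _synchronize_sketch(device: str) -> None:
    # Same dispatch pattern for synchronization barriers.
    if device.startswith("cuda"):
        torch.cuda.synchronize()
    elif device.startswith("xpu"):
        torch.xpu.synchronize()
    elif device.startswith("mps"):
        torch.mps.synchronize()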
@@ -61,7 +63,7 @@

 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoConfigTest(unittest.TestCase):
     def test_to_dict(self):
@@ -79,7 +81,7 @@ def test_post_init_check(self):
         Test kwargs validations in TorchAoConfig
         """
         _ = TorchAoConfig("int4_weight_only")
-        with self.assertRaisesRegex(ValueError, "is not supported yet"):
+        with self.assertRaisesRegex(ValueError, "is not supported"):
             _ = TorchAoConfig("uint8")

         with self.assertRaisesRegex(ValueError, "does not support the following keyword arguments"):

Review comment: XPU's error message does not contain "yet", so match only "is not supported" to cover both CUDA and XPU.
Reply: Makes sense. IIRC this test is failing in our CUDA CI too since the error message no longer says "yet", so this should be okay.
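A quick illustration of why the relaxed pattern is enough (the messages below are illustrative stand-ins, not the exact strings raised by either backend): assertRaisesRegex performs a regex search on the exception text, so the shorter pattern matches both wordings.

import re

# Both the CUDA-style wording (with "yet") and the XPU-style wording (without it)
# contain the relaxed pattern, so the assertion passes on either backend.
for message in ("uint8 quantization is not supported yet", "uint8 quantization is not supported"):
    assert re.search("is not supported", message) is not None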
@@ -119,12 +121,12 @@ def test_repr(self):

 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoTest(unittest.TestCase):
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_dummy_components(
         self, quantization_config: TorchAoConfig, model_id: str = "hf-internal-testing/tiny-flux-pipe"
@@ -269,6 +271,7 @@ def test_int4wo_quant_bfloat16_conversion(self):
             subfolder="transformer",
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
+            device_map=f"{torch_device}:0",
         )

         weight = quantized_model.transformer_blocks[0].ff.net[2].weight

Review comment: if the model is loaded on CPU, we hit "NotImplementedError: Could not run 'aten::_convert_weight_to_int4pack' with arguments from the 'CPU' backend." This happens on both XPU and CUDA, so load the model directly onto the accelerator here.
Reply: Makes sense, we recently got this error on our CI too.
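For context, this is roughly what the loading call looks like with the new device_map argument. Treat it as a sketch: the model id and the int4 quantization settings are borrowed from elsewhere in this diff, not copied from this exact test.

import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig
from diffusers.utils.testing_utils import torch_device

# Placing the quantized model on the accelerator at load time avoids running the
# int4 weight-packing kernel on CPU, where it is not implemented.
quantized_model = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-flux-pipe",
    subfolder="transformer",
    quantization_config=TorchAoConfig("int4_weight_only", group_size=64),
    torch_dtype=torch.bfloat16,
    device_map=f"{torch_device}:0",
)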
@@ -338,7 +341,7 @@ def test_device_map(self):

         output = quantized_model(**inputs)[0]
         output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
-        self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
+        self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 2e-3)

         with tempfile.TemporaryDirectory() as offload_folder:
             quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
@@ -359,7 +362,7 @@ def test_device_map(self):

         output = quantized_model(**inputs)[0]
         output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
-        self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
+        self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 2e-3)

     def test_modules_to_not_convert(self):
         quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
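The tolerance in these two checks is loosened from 1e-3 to 2e-3 so the same expected slices can be reused across backends with slightly different numerics. Assuming numpy_cosine_similarity_distance is essentially one minus the cosine similarity of the flattened slices (a sketch, not the exact helper), a 2e-3 bound still requires a cosine similarity above 0.998, so real regressions are still caught.

import numpy as np

def cosine_distance_sketch(a: np.ndarray, b: np.ndarray) -> float:
    # 1 - cosine similarity: 0.0 means identical direction; small values mean the
    # quantized output only drifted marginally from the reference slice.
    return 1.0 - float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))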
@@ -518,14 +521,14 @@ def test_sequential_cpu_offload(self):

 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoSerializationTest(unittest.TestCase):
     model_name = "hf-internal-testing/tiny-flux-pipe"

     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_dummy_model(self, quant_method, quant_method_kwargs, device=None):
         quantization_config = TorchAoConfig(quant_method, **quant_method_kwargs)
@@ -593,17 +596,17 @@ def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs,
         )
         self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)

-    def test_int_a8w8_cuda(self):
+    def test_int_a8w8_accelerator(self):
         quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
         expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
-        device = "cuda"
+        device = torch_device
         self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
         self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)

-    def test_int_a16w8_cuda(self):
+    def test_int_a16w8_accelerator(self):
         quant_method, quant_method_kwargs = "int8_weight_only", {}
         expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
-        device = "cuda"
+        device = torch_device
         self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
         self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
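The renamed tests rely on torch_device resolving to whatever accelerator the runner exposes. As a rough sketch of the idea only (the real diffusers helper handles more backends and environment overrides), the resolution looks something like this:

import torch

# Pick the first available accelerator; fall back to CPU. This is only meant to
# show why `device = torch_device` works unchanged on CUDA and XPU runners.
if torch.cuda.is_available():
    torch_device_sketch = "cuda"
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    torch_device_sketch = "xpu"
else:
    torch_device_sketch = "cpu"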
@@ -624,14 +627,14 @@ def test_int_a16w8_cpu(self):

 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @require_torchao_version_greater_or_equal("0.7.0")
 @slow
 @nightly
 class SlowTorchAoTests(unittest.TestCase):
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_dummy_components(self, quantization_config: TorchAoConfig):
         # This is just for convenience, so that we can modify it at one place for custom environments and locally testing
@@ -713,8 +716,8 @@ def test_quantization(self):
             quantization_config = TorchAoConfig(quant_type=quantization_name, modules_to_not_convert=["x_embedder"])
             self._test_quant_type(quantization_config, expected_slice)
             gc.collect()
-            torch.cuda.empty_cache()
-            torch.cuda.synchronize()
+            backend_empty_cache(torch_device)
+            backend_synchronize(torch_device)

     def test_serialization_int8wo(self):
         quantization_config = TorchAoConfig("int8wo")
@@ -733,8 +736,8 @@ def test_serialization_int8wo(self):
         pipe.remove_all_hooks()
         del pipe.transformer
         gc.collect()
-        torch.cuda.empty_cache()
-        torch.cuda.synchronize()
+        backend_empty_cache(torch_device)
+        backend_synchronize(torch_device)
         transformer = FluxTransformer2DModel.from_pretrained(
             tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False
         )
@@ -783,14 +786,14 @@ def test_memory_footprint_int8wo(self):

 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @require_torchao_version_greater_or_equal("0.7.0")
 @slow
 @nightly
 class SlowTorchAoPreserializedModelTests(unittest.TestCase):
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_dummy_inputs(self, device: torch.device, seed: int = 0):
         if str(device).startswith("mps"):
Review comment: only check device capability when the device is CUDA; non-CUDA devices should be checked in separate utilities. With the original implementation, tests on non-CUDA devices (like XPU) were being skipped here.
Reply: We should probably still raise an error if torchao is being used with mps or other devices, otherwise it leads to an obscure error somewhere deep in the code that common users will not understand.
Reply: @a-r-r-o-w, I enhanced this utility per your comments, please review again. Thanks.
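Following the suggestion above, a guard along these lines could fail fast for backends torchao does not cover. This is a hypothetical sketch: the function name and the supported-device list are assumptions, not the utility that was actually added in the PR.

def check_torchao_device_sketch(device: str) -> None:
    # Fail loudly for backends that torchao does not support, instead of letting
    # users hit an obscure low-level kernel error later on.
    supported_prefixes = ("cuda", "xpu", "cpu")
    if not device.startswith(supported_prefixes):
        raise ValueError(
            f"torchao quantization is not supported on device '{device}'. "
            f"Supported device types are: {', '.join(supported_prefixes)}."
        )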