
Commit 33e636c

enable torchao test cases on XPU and switch to device agnostic APIs for test cases (#11654)
* enable torchao cases on XPU
* device agnostic APIs
* more
* fix style
* enable test_torch_compile_recompilation_and_graph_break on XPU
* resolve comments

Signed-off-by: Matrix YAO <[email protected]>
Signed-off-by: YAO Matrix <[email protected]>
1 parent e27142a · commit 33e636c

30 files changed (+109 / −91 lines)

src/diffusers/quantizers/quantization_config.py

Lines changed: 12 additions & 10 deletions
@@ -493,7 +493,7 @@ def __init__(self, quant_type: str, modules_to_not_convert: Optional[List[str]]
         TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
         if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
             is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
-            if is_floating_quant_type and not self._is_cuda_capability_atleast_8_9():
+            if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
                 raise ValueError(
                     f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
                     f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
@@ -645,7 +645,7 @@ def generate_fpx_quantization_types(bits: int):
             QUANTIZATION_TYPES.update(INT8_QUANTIZATION_TYPES)
             QUANTIZATION_TYPES.update(UINTX_QUANTIZATION_DTYPES)

-            if cls._is_cuda_capability_atleast_8_9():
+            if cls._is_xpu_or_cuda_capability_atleast_8_9():
                 QUANTIZATION_TYPES.update(FLOATX_QUANTIZATION_TYPES)

             return QUANTIZATION_TYPES
@@ -655,14 +655,16 @@ def generate_fpx_quantization_types(bits: int):
         )

     @staticmethod
-    def _is_cuda_capability_atleast_8_9() -> bool:
-        if not torch.cuda.is_available():
-            raise RuntimeError("TorchAO requires a CUDA compatible GPU and installation of PyTorch.")
-
-        major, minor = torch.cuda.get_device_capability()
-        if major == 8:
-            return minor >= 9
-        return major >= 9
+    def _is_xpu_or_cuda_capability_atleast_8_9() -> bool:
+        if torch.cuda.is_available():
+            major, minor = torch.cuda.get_device_capability()
+            if major == 8:
+                return minor >= 9
+            return major >= 9
+        elif torch.xpu.is_available():
+            return True
+        else:
+            raise RuntimeError("TorchAO requires a CUDA compatible GPU or Intel XPU and installation of PyTorch.")

     def get_apply_tensor_subclass(self):
         TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
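
For orientation, a minimal standalone sketch of the behaviour the new `_is_xpu_or_cuda_capability_atleast_8_9` check encodes: CUDA devices must report compute capability 8.9 or newer for the float8-style quant types, while Intel XPU devices are accepted unconditionally. The helper name and the `hasattr` guard below are illustrative, not part of this commit:

import torch

def supports_float8_torchao() -> bool:
    # Illustrative mirror of the new check: gate float8-style quantization on
    # CUDA compute capability >= 8.9, accept any available Intel XPU, and
    # error out when neither accelerator is present.
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        return (major, minor) >= (8, 9)
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return True
    raise RuntimeError("Requires a CUDA GPU or an Intel XPU.")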

src/diffusers/utils/testing_utils.py

Lines changed: 1 addition & 3 deletions
@@ -300,9 +300,7 @@ def require_torch_gpu(test_case):

 def require_torch_cuda_compatibility(expected_compute_capability):
     def decorator(test_case):
-        if not torch.cuda.is_available():
-            return unittest.skip(test_case)
-        else:
+        if torch.cuda.is_available():
             current_compute_capability = get_torch_cuda_device_capability()
             return unittest.skipUnless(
                 float(current_compute_capability) == float(expected_compute_capability),
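
A hedged usage sketch of the reworked decorator (the test class, method name, and the 8.0 argument are illustrative; the non-CUDA fallthrough of the decorator lies outside the hunk shown above):

import unittest

from diffusers.utils.testing_utils import require_torch_cuda_compatibility

class ExampleTests(unittest.TestCase):
    # Illustrative only: the capability gate is applied when CUDA is present;
    # after this change a machine without CUDA is no longer unconditionally
    # skipped by the decorator itself.
    @require_torch_cuda_compatibility(8.0)
    def test_needs_at_least_ampere(self):
        ...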

tests/models/autoencoders/test_models_consistency_decoder_vae.py

Lines changed: 3 additions & 2 deletions
@@ -21,6 +21,7 @@

 from diffusers import ConsistencyDecoderVAE, StableDiffusionPipeline
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     load_image,
     slow,
@@ -162,13 +163,13 @@ def setUp(self):
         # clean up the VRAM before each test
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         # clean up the VRAM after each test
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     @torch.no_grad()
     def test_encode_decode(self):
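
The same device-agnostic cleanup pattern recurs in most of the test files below. As a rough standalone skeleton (class name illustrative; `backend_empty_cache` is expected to dispatch to the matching `empty_cache` call for whatever device `torch_device` resolves to, e.g. CUDA or XPU):

import gc
import unittest

from diffusers.utils.testing_utils import backend_empty_cache, torch_device

class ExampleIntegrationTests(unittest.TestCase):
    # Illustrative skeleton: free accelerator memory before and after each
    # test without hard-coding torch.cuda.
    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)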

tests/models/unets/test_models_unet_2d.py

Lines changed: 2 additions & 1 deletion
@@ -22,6 +22,7 @@
 from diffusers import UNet2DModel
 from diffusers.utils import logging
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     floats_tensor,
     require_torch_accelerator,
@@ -229,7 +230,7 @@ def test_from_pretrained_accelerate_wont_change_results(self):

         # two models don't need to stay in the device at the same time
         del model_accelerate
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
         gc.collect()

         model_normal_load, _ = UNet2DModel.from_pretrained(

tests/models/unets/test_models_unet_2d_condition.py

Lines changed: 2 additions & 3 deletions
@@ -46,7 +46,6 @@
     require_peft_backend,
     require_torch_accelerator,
     require_torch_accelerator_with_fp16,
-    require_torch_gpu,
     skip_mps,
     slow,
     torch_all_close,
@@ -978,13 +977,13 @@ def test_ip_adapter_plus(self):
         assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4)
         assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)

-    @require_torch_gpu
     @parameterized.expand(
         [
             ("hf-internal-testing/unet2d-sharded-dummy", None),
             ("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
         ]
     )
+    @require_torch_accelerator
     def test_load_sharded_checkpoint_from_hub(self, repo_id, variant):
         _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         loaded_model = self.model_class.from_pretrained(repo_id, variant=variant)
@@ -994,13 +993,13 @@ def test_load_sharded_checkpoint_from_hub(self, repo_id, variant):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)

-    @require_torch_gpu
     @parameterized.expand(
         [
             ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
             ("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
         ]
     )
+    @require_torch_accelerator
     def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant):
         _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         loaded_model = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant)

tests/pipelines/allegro/test_allegro.py

Lines changed: 3 additions & 2 deletions
@@ -24,6 +24,7 @@

 from diffusers import AllegroPipeline, AllegroTransformer3DModel, AutoencoderKLAllegro, DDIMScheduler
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     numpy_cosine_similarity_distance,
     require_hf_hub_version_greater,
@@ -341,12 +342,12 @@ class AllegroPipelineIntegrationTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_allegro(self):
         generator = torch.Generator("cpu").manual_seed(0)

tests/pipelines/audioldm/test_audioldm.py

Lines changed: 5 additions & 5 deletions
@@ -37,7 +37,7 @@
     UNet2DConditionModel,
 )
 from diffusers.utils import is_xformers_available
-from diffusers.utils.testing_utils import enable_full_determinism, nightly, torch_device
+from diffusers.utils.testing_utils import backend_empty_cache, enable_full_determinism, nightly, torch_device

 from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
 from ..test_pipelines_common import PipelineTesterMixin
@@ -378,12 +378,12 @@ class AudioLDMPipelineSlowTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
         generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -423,12 +423,12 @@ class AudioLDMPipelineNightlyTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
         generator = torch.Generator(device=generator_device).manual_seed(seed)

tests/pipelines/audioldm2/test_audioldm2.py

Lines changed: 9 additions & 3 deletions
@@ -45,7 +45,13 @@
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from diffusers.utils.testing_utils import enable_full_determinism, is_torch_version, nightly, torch_device
+from diffusers.utils.testing_utils import (
+    backend_empty_cache,
+    enable_full_determinism,
+    is_torch_version,
+    nightly,
+    torch_device,
+)

 from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
 from ..test_pipelines_common import PipelineTesterMixin
@@ -540,12 +546,12 @@ class AudioLDM2PipelineSlowTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
         generator = torch.Generator(device=generator_device).manual_seed(seed)

tests/pipelines/cogvideo/test_cogvideox.py

Lines changed: 3 additions & 2 deletions
@@ -22,6 +22,7 @@

 from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3DModel, DDIMScheduler
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     numpy_cosine_similarity_distance,
     require_torch_accelerator,
@@ -334,12 +335,12 @@ class CogVideoXPipelineIntegrationTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_cogvideox(self):
         generator = torch.Generator("cpu").manual_seed(0)

tests/pipelines/cogview3/test_cogview3plus.py

Lines changed: 3 additions & 2 deletions
@@ -22,6 +22,7 @@

 from diffusers import AutoencoderKL, CogVideoXDDIMScheduler, CogView3PlusPipeline, CogView3PlusTransformer2DModel
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     numpy_cosine_similarity_distance,
     require_torch_accelerator,
@@ -244,12 +245,12 @@ class CogView3PlusPipelineIntegrationTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_cogview3plus(self):
         generator = torch.Generator("cpu").manual_seed(0)

tests/pipelines/controlnet/test_controlnet_img2img.py

Lines changed: 3 additions & 2 deletions
@@ -36,6 +36,7 @@
 from diffusers.utils import load_image
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     floats_tensor,
     load_numpy,
@@ -412,12 +413,12 @@ class ControlNetImg2ImgPipelineSlowTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_canny(self):
         controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")

tests/pipelines/controlnet/test_controlnet_inpaint.py

Lines changed: 3 additions & 2 deletions
@@ -36,6 +36,7 @@
 from diffusers.utils import load_image
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     floats_tensor,
     load_numpy,
@@ -464,12 +465,12 @@ class ControlNetInpaintPipelineSlowTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_canny(self):
         controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")

tests/pipelines/controlnet_sd3/test_controlnet_sd3.py

Lines changed: 1 addition & 1 deletion
@@ -221,7 +221,7 @@ def test_xformers_attention_forwardGenerator_pass(self):

 @slow
 @require_big_accelerator
-@pytest.mark.big_gpu_with_torch_cuda
+@pytest.mark.big_accelerator
 class StableDiffusion3ControlNetPipelineSlowTests(unittest.TestCase):
     pipeline_class = StableDiffusion3ControlNetPipeline

tests/pipelines/deepfloyd_if/test_if.py

Lines changed: 2 additions & 1 deletion
@@ -25,6 +25,7 @@
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
     backend_empty_cache,
+    backend_max_memory_allocated,
     backend_reset_max_memory_allocated,
     backend_reset_peak_memory_stats,
     load_numpy,
@@ -135,7 +136,7 @@ def test_if_text_to_image(self):

         image = output.images[0]

-        mem_bytes = torch.cuda.max_memory_allocated()
+        mem_bytes = backend_max_memory_allocated(torch_device)
         assert mem_bytes < 12 * 10**9

         expected_image = load_numpy(
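
A hedged sketch of the device-agnostic memory-measurement flow these DeepFloyd IF tests now follow (the reset helpers come from the import block above; the pipeline call is elided, and the 12 GB bound is the one asserted in the test):

from diffusers.utils.testing_utils import (
    backend_max_memory_allocated,
    backend_reset_max_memory_allocated,
    backend_reset_peak_memory_stats,
    torch_device,
)

# Illustrative flow: clear the accelerator's peak-memory counters, run the
# pipeline, then read back the peak allocation without calling
# torch.cuda.max_memory_allocated() directly.
backend_reset_max_memory_allocated(torch_device)
backend_reset_peak_memory_stats(torch_device)
# ... run the pipeline under test ...
mem_bytes = backend_max_memory_allocated(torch_device)
assert mem_bytes < 12 * 10**9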

tests/pipelines/deepfloyd_if/test_if_img2img.py

Lines changed: 2 additions & 1 deletion
@@ -24,6 +24,7 @@
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
     backend_empty_cache,
+    backend_max_memory_allocated,
     backend_reset_max_memory_allocated,
     backend_reset_peak_memory_stats,
     floats_tensor,
@@ -151,7 +152,7 @@ def test_if_img2img(self):
         )
         image = output.images[0]

-        mem_bytes = torch.cuda.max_memory_allocated()
+        mem_bytes = backend_max_memory_allocated(torch_device)
         assert mem_bytes < 12 * 10**9

         expected_image = load_numpy(

tests/pipelines/flux/test_pipeline_flux.py

Lines changed: 2 additions & 2 deletions
@@ -224,7 +224,7 @@ def test_flux_true_cfg(self):

 @nightly
 @require_big_accelerator
-@pytest.mark.big_gpu_with_torch_cuda
+@pytest.mark.big_accelerator
 class FluxPipelineSlowTests(unittest.TestCase):
     pipeline_class = FluxPipeline
     repo_id = "black-forest-labs/FLUX.1-schnell"
@@ -312,7 +312,7 @@ def test_flux_inference(self):

 @slow
 @require_big_accelerator
-@pytest.mark.big_gpu_with_torch_cuda
+@pytest.mark.big_accelerator
 class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
     pipeline_class = FluxPipeline
     repo_id = "black-forest-labs/FLUX.1-dev"

tests/pipelines/flux/test_pipeline_flux_redux.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@

 @slow
 @require_big_accelerator
-@pytest.mark.big_gpu_with_torch_cuda
+@pytest.mark.big_accelerator
 class FluxReduxSlowTests(unittest.TestCase):
     pipeline_class = FluxPriorReduxPipeline
     repo_id = "black-forest-labs/FLUX.1-Redux-dev"
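
The marker rename from `big_gpu_with_torch_cuda` to `big_accelerator` (seen in the SD3 ControlNet and Flux suites above) keeps the heavy suites selectable without implying CUDA. A hedged sketch of the resulting decoration style; the class is illustrative, and the import location of `require_big_accelerator` is assumed to be `diffusers.utils.testing_utils`:

import unittest

import pytest

from diffusers.utils.testing_utils import require_big_accelerator, slow

@slow
@require_big_accelerator
@pytest.mark.big_accelerator  # e.g. selectable with `pytest -m big_accelerator`
class ExampleBigAcceleratorTests(unittest.TestCase):
    ...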
