
Commit b6f7933

sayakpaul and DN6 authored

[tests] tests for compilation + quantization (bnb) (#11672)

* start adding compilation tests for quantization.
* fixes
* make common utility.
* modularize.
* add group offloading+compile
* xfail
* update
* Update tests/quantization/test_torch_compile_utils.py
  (Co-authored-by: Dhruv Nair <[email protected]>)
* fixes

Co-authored-by: Dhruv Nair <[email protected]>

1 parent 33e636c · commit b6f7933

File tree

4 files changed: +153 −0 lines changed

- src/diffusers/utils/testing_utils.py
- tests/quantization/bnb/test_4bit.py
- tests/quantization/bnb/test_mixed_int8.py
- tests/quantization/test_torch_compile_utils.py

src/diffusers/utils/testing_utils.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -291,6 +291,18 @@ def decorator(test_case):
     return decorator


+def require_torch_version_greater(torch_version):
+    """Decorator marking a test that requires a torch version greater than the given version."""
+
+    def decorator(test_case):
+        correct_torch_version = is_torch_available() and is_torch_version(">", torch_version)
+        return unittest.skipUnless(
+            correct_torch_version, f"test requires torch with the version greater than {torch_version}"
+        )(test_case)
+
+    return decorator
+
+
 def require_torch_gpu(test_case):
     """Decorator marking a test that requires CUDA and PyTorch."""
     return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
```
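For context, a minimal sketch of how the new decorator composes with a test class; the class name and test body below are hypothetical, not part of this commit:

```python
import unittest

from diffusers.utils.testing_utils import require_torch_version_greater


@require_torch_version_greater("2.7.1")
class ExampleCompileTests(unittest.TestCase):  # hypothetical test class
    def test_something(self):
        # Skipped with an explanatory message unless the installed torch
        # version is strictly greater than 2.7.1.
        self.assertTrue(True)
```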

tests/quantization/bnb/test_4bit.py

Lines changed: 27 additions & 0 deletions
```diff
@@ -30,6 +30,7 @@
     FluxTransformer2DModel,
     SD3Transformer2DModel,
 )
+from diffusers.quantizers import PipelineQuantizationConfig
 from diffusers.utils import is_accelerate_version, logging
 from diffusers.utils.testing_utils import (
     CaptureLogger,
@@ -44,11 +45,14 @@
     require_peft_backend,
     require_torch,
     require_torch_accelerator,
+    require_torch_version_greater,
     require_transformers_version_greater,
     slow,
     torch_device,
 )

+from ..test_torch_compile_utils import QuantCompileTests
+

 def get_some_linear_layer(model):
     if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]:
@@ -855,3 +859,26 @@ def test_fp4_double_unsafe(self):

     def test_fp4_double_safe(self):
         self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=True)
+
+
+@require_torch_version_greater("2.7.1")
+class Bnb4BitCompileTests(QuantCompileTests):
+    quantization_config = PipelineQuantizationConfig(
+        quant_backend="bitsandbytes_4bit",
+        quant_kwargs={
+            "load_in_4bit": True,
+            "bnb_4bit_quant_type": "nf4",
+            "bnb_4bit_compute_dtype": torch.bfloat16,
+        },
+        components_to_quantize=["transformer", "text_encoder_2"],
+    )
+
+    def test_torch_compile(self):
+        torch._dynamo.config.capture_dynamic_output_shape_ops = True
+        super()._test_torch_compile(quantization_config=self.quantization_config)
+
+    def test_torch_compile_with_cpu_offload(self):
+        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
+
+    def test_torch_compile_with_group_offload(self):
+        super()._test_torch_compile_with_group_offload(quantization_config=self.quantization_config)
```
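Outside the unittest harness, the same configuration can be exercised directly. The snippet below is a hedged sketch mirroring `_init_pipeline` and `_test_torch_compile` from the shared base class; it assumes a CUDA GPU, `bitsandbytes`, and the `stabilityai/stable-diffusion-3-medium-diffusers` checkpoint:

```python
import torch

from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig

# Same config the test class uses: 4-bit NF4 bnb quantization of the
# transformer and the second text encoder.
quantization_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16,
    },
    components_to_quantize=["transformer", "text_encoder_2"],
)
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
).to("cuda")

# The test sets this dynamo flag before compiling with fullgraph=True.
torch._dynamo.config.capture_dynamic_output_shape_ops = True
pipe.transformer.compile(fullgraph=True)

# Small resolution and few steps, matching the test's speedy settings.
image = pipe(
    "a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256
).images[0]
```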

tests/quantization/bnb/test_mixed_int8.py

Lines changed: 27 additions & 0 deletions
```diff
@@ -46,11 +46,14 @@
     require_peft_version_greater,
     require_torch,
     require_torch_accelerator,
+    require_torch_version_greater_equal,
     require_transformers_version_greater,
     slow,
     torch_device,
 )

+from ..test_torch_compile_utils import QuantCompileTests
+

 def get_some_linear_layer(model):
     if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]:
@@ -821,3 +824,27 @@ def test_serialization_sharded(self):
         out_0 = self.model_0(**inputs)[0]
         out_1 = model_1(**inputs)[0]
         self.assertTrue(torch.equal(out_0, out_1))
+
+
+@require_torch_version_greater_equal("2.6.0")
+class Bnb8BitCompileTests(QuantCompileTests):
+    quantization_config = PipelineQuantizationConfig(
+        quant_backend="bitsandbytes_8bit",
+        quant_kwargs={"load_in_8bit": True},
+        components_to_quantize=["transformer", "text_encoder_2"],
+    )
+
+    def test_torch_compile(self):
+        torch._dynamo.config.capture_dynamic_output_shape_ops = True
+        super()._test_torch_compile(quantization_config=self.quantization_config, torch_dtype=torch.float16)
+
+    def test_torch_compile_with_cpu_offload(self):
+        super()._test_torch_compile_with_cpu_offload(
+            quantization_config=self.quantization_config, torch_dtype=torch.float16
+        )
+
+    @pytest.mark.xfail(reason="Fails due to an Accelerate offloading issue with conflicting hooks.")
+    def test_torch_compile_with_group_offload(self):
+        super()._test_torch_compile_with_group_offload(
+            quantization_config=self.quantization_config, torch_dtype=torch.float16
+        )
```
tests/quantization/test_torch_compile_utils.py

Lines changed: 87 additions & 0 deletions (new file)

```python
# coding=utf-8
# Copyright 2024 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest

import torch

from diffusers import DiffusionPipeline
from diffusers.utils.testing_utils import backend_empty_cache, require_torch_gpu, slow, torch_device


@require_torch_gpu
@slow
class QuantCompileTests(unittest.TestCase):
    quantization_config = None

    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)
        torch.compiler.reset()

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)
        torch.compiler.reset()

    def _init_pipeline(self, quantization_config, torch_dtype):
        pipe = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-3-medium-diffusers",
            quantization_config=quantization_config,
            torch_dtype=torch_dtype,
        )
        return pipe

    def _test_torch_compile(self, quantization_config, torch_dtype=torch.bfloat16):
        pipe = self._init_pipeline(quantization_config, torch_dtype).to("cuda")
        # important to ensure fullgraph=True compilation works.
        pipe.transformer.compile(fullgraph=True)

        for _ in range(2):
            # small resolutions to ensure speedy execution.
            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)

    def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype=torch.bfloat16):
        pipe = self._init_pipeline(quantization_config, torch_dtype)
        pipe.enable_model_cpu_offload()
        pipe.transformer.compile()

        for _ in range(2):
            # small resolutions to ensure speedy execution.
            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)

    def _test_torch_compile_with_group_offload(self, quantization_config, torch_dtype=torch.bfloat16):
        torch._dynamo.config.cache_size_limit = 10000

        pipe = self._init_pipeline(quantization_config, torch_dtype)
        group_offload_kwargs = {
            "onload_device": torch.device("cuda"),
            "offload_device": torch.device("cpu"),
            "offload_type": "leaf_level",
            "use_stream": True,
            "non_blocking": True,
        }
        pipe.transformer.enable_group_offload(**group_offload_kwargs)
        pipe.transformer.compile()
        # Group offloading manages only the transformer; move any other
        # CPU-resident modules to the GPU so the pipeline can run.
        for name, component in pipe.components.items():
            if name != "transformer" and isinstance(component, torch.nn.Module):
                if torch.device(component.device).type == "cpu":
                    component.to("cuda")

        for _ in range(2):
            # small resolutions to ensure speedy execution.
            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
```
