Skip to content

Commit 4bac6ac

Browse files
authored
Merge branch 'main' into hidream-torch-compile
2 parents f9662ed + f4fa3be commit 4bac6ac

File tree

3 files changed

+42
-4
lines changed

3 files changed

+42
-4
lines changed

src/diffusers/pipelines/flux/pipeline_flux_img2img.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,39 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
607607

608608
return latents
609609

610+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing
611+
def enable_vae_slicing(self):
612+
r"""
613+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
614+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
615+
"""
616+
self.vae.enable_slicing()
617+
618+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing
619+
def disable_vae_slicing(self):
620+
r"""
621+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
622+
computing decoding in one step.
623+
"""
624+
self.vae.disable_slicing()
625+
626+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling
627+
def enable_vae_tiling(self):
628+
r"""
629+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
630+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
631+
processing larger images.
632+
"""
633+
self.vae.enable_tiling()
634+
635+
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling
636+
def disable_vae_tiling(self):
637+
r"""
638+
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
639+
computing decoding in one step.
640+
"""
641+
self.vae.disable_tiling()
642+
610643
def prepare_latents(
611644
self,
612645
image,

tests/models/unets/test_models_unet_2d_condition.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,12 @@
5353
torch_device,
5454
)
5555

56-
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, UNetTesterMixin
56+
from ..test_modeling_common import (
57+
LoraHotSwappingForModelTesterMixin,
58+
ModelTesterMixin,
59+
TorchCompileTesterMixin,
60+
UNetTesterMixin,
61+
)
5762

5863

5964
if is_peft_available():
@@ -351,7 +356,7 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
351356

352357

353358
class UNet2DConditionModelTests(
354-
ModelTesterMixin, LoraHotSwappingForModelTesterMixin, UNetTesterMixin, unittest.TestCase
359+
ModelTesterMixin, TorchCompileTesterMixin, LoraHotSwappingForModelTesterMixin, UNetTesterMixin, unittest.TestCase
355360
):
356361
model_class = UNet2DConditionModel
357362
main_input_name = "sample"

tests/pipelines/ltx/test_ltx_image2video.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def get_dummy_inputs(self, device, seed=0):
109109
else:
110110
generator = torch.Generator(device=device).manual_seed(seed)
111111

112-
image = torch.randn((1, 3, 32, 32), generator=generator, device=device)
112+
image = torch.rand((1, 3, 32, 32), generator=generator, device=device)
113113

114114
inputs = {
115115
"image": image,
@@ -142,7 +142,7 @@ def test_inference(self):
142142

143143
self.assertEqual(generated_video.shape, (9, 3, 32, 32))
144144
expected_video = torch.randn(9, 3, 32, 32)
145-
max_diff = np.abs(generated_video - expected_video).max()
145+
max_diff = torch.amax(torch.abs(generated_video - expected_video))
146146
self.assertLessEqual(max_diff, 1e10)
147147

148148
def test_callback_inputs(self):

0 commit comments

Comments (0)