diff --git a/docs/source/en/optimization/fp16.md b/docs/source/en/optimization/fp16.md index 2e12bfadcf5c..45a2282ba10e 100644 --- a/docs/source/en/optimization/fp16.md +++ b/docs/source/en/optimization/fp16.md @@ -152,9 +152,39 @@ Compilation is slow the first time, but once compiled, it is significantly faste ### Regional compilation -[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) reduces the cold start compilation time by only compiling a specific repeated region (or block) of the model instead of the entire model. The compiler reuses the cached and compiled code for the other blocks. -[Accelerate](https://huggingface.co/docs/accelerate/index) provides the [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78) method for automatically compiling the repeated blocks of a `nn.Module` sequentially. The rest of the model is compiled separately. +[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by compiling **only the small, frequently-repeated block(s)** of a model, typically a Transformer layer, enabling reuse of compiled artifacts for every subsequent occurrence. +For many diffusion architectures this delivers the *same* runtime speed-ups as full-graph compilation yet cuts compile time by **8–10 ×**. 
+ +To make this effortless, [`ModelMixin`] exposes the [`ModelMixin.compile_repeated_blocks`] API, a helper that wraps `torch.compile` around any sub-modules you designate as repeatable: + +```py +# pip install -U diffusers +import torch +from diffusers import StableDiffusionXLPipeline + +pipe = StableDiffusionXLPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + torch_dtype=torch.float16, +).to("cuda") + +# Compile only the repeated Transformer layers inside the UNet +pipe.unet.compile_repeated_blocks(fullgraph=True) +``` + +To enable regional compilation for a new model, add a `_repeated_blocks` attribute to your model class containing the class names (as strings) of the blocks you want compiled: + + +```py +class MyUNet(ModelMixin): + _repeated_blocks = ("Transformer2DModel",) # ← compiled by default +``` + +For more examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705). + +**Relation to Accelerate `compile_regions`** There is also a separate API in [Accelerate](https://huggingface.co/docs/accelerate/index): [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78). It takes a fully automatic approach: it walks the module, picks candidate blocks, then compiles the remaining graph separately. That hands-off experience is handy for quick experiments, but it also leaves fewer knobs when you want to fine-tune which blocks are compiled or adjust compilation flags. + + ```py # pip install -U accelerate @@ -167,6 +197,8 @@ pipeline = StableDiffusionXLPipeline.from_pretrained( ).to("cuda") pipeline.unet = compile_regions(pipeline.unet, mode="reduce-overhead", fullgraph=True) ``` +`compile_repeated_blocks`, by contrast, is intentionally explicit. You list the repeated blocks once (via `_repeated_blocks`) and the helper compiles exactly those, nothing more. 
In practice this small dose of control hits a sweet spot for diffusion models: predictable behavior, easy reasoning about cache reuse, and still a one-liner for users. + ### Graph breaks @@ -241,4 +273,4 @@ An input is projected into three subspaces, represented by the projection matric ```py pipeline.fuse_qkv_projections() -``` \ No newline at end of file +``` diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 5fa04fb2606f..8e1ec5f55889 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -266,6 +266,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): _keep_in_fp32_modules = None _skip_layerwise_casting_patterns = None _supports_group_offloading = True + _repeated_blocks = [] def __init__(self): super().__init__() @@ -1404,6 +1405,39 @@ def float(self, *args): else: return super().float(*args) + def compile_repeated_blocks(self, *args, **kwargs): + """ + Compiles *only* the frequently repeated sub-modules of a model (e.g. the Transformer layers) instead of + compiling the entire model. This technique, often called **regional compilation** (see the PyTorch recipe + https://docs.pytorch.org/tutorials/recipes/regional_compilation.html), can reduce end-to-end compile time + substantially, while preserving the runtime speed-ups you would expect from a full `torch.compile`. + + The set of sub-modules to compile is read from the **`_repeated_blocks`** attribute of the + model definition. Define this attribute on your model subclass as a list/tuple of class names (strings). Every + module whose class name matches will be compiled. + + Each matching sub-module is compiled by calling `submodule.compile(*args, **kwargs)`, so any positional or + keyword arguments you pass are forwarded verbatim. Raises `ValueError` if `_repeated_blocks` is empty or if no + sub-module of the model matches one of the listed class names. 
+ """ + repeated_blocks = getattr(self, "_repeated_blocks", None) + + if not repeated_blocks: + raise ValueError( + "`_repeated_blocks` attribute is empty. " + f"Set `_repeated_blocks` for the class `{self.__class__.__name__}` to benefit from faster compilation. " + ) + has_compiled_region = False + for submod in self.modules(): + if submod.__class__.__name__ in repeated_blocks: + submod.compile(*args, **kwargs) + has_compiled_region = True + + if not has_compiled_region: + raise ValueError( + f"Regional compilation failed because {repeated_blocks} classes are not found in the model. " + ) + @classmethod def _load_pretrained_model( cls, diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py index d11f6c2a5e25..0f6dd677ac5c 100644 --- a/src/diffusers/models/transformers/transformer_chroma.py +++ b/src/diffusers/models/transformers/transformer_chroma.py @@ -407,6 +407,7 @@ class ChromaTransformer2DModel( _supports_gradient_checkpointing = True _no_split_modules = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"] + _repeated_blocks = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"] _skip_layerwise_casting_patterns = ["pos_embed", "norm"] @register_to_config diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index ab579a0eb531..3af1de2ad0be 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -227,6 +227,7 @@ class FluxTransformer2DModel( _supports_gradient_checkpointing = True _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] _skip_layerwise_casting_patterns = ["pos_embed", "norm"] + _repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py 
b/src/diffusers/models/transformers/transformer_hunyuan_video.py index c48c586a28de..6944a6c536b5 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -870,6 +870,12 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, "HunyuanVideoPatchEmbed", "HunyuanVideoTokenRefiner", ] + _repeated_blocks = [ + "HunyuanVideoTransformerBlock", + "HunyuanVideoSingleTransformerBlock", + "HunyuanVideoPatchEmbed", + "HunyuanVideoTokenRefiner", + ] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py index 38b7b6af50f9..2d06124282d1 100644 --- a/src/diffusers/models/transformers/transformer_ltx.py +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -328,6 +328,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["norm"] + _repeated_blocks = ["LTXVideoTransformerBlock"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index baa0ede4184e..0ae7f2c00d92 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -345,6 +345,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi _no_split_modules = ["WanTransformerBlock"] _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] _keys_to_ignore_on_load_unexpected = ["norm_added_q"] + _repeated_blocks = ["WanTransformerBlock"] @register_to_config def __init__( diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 0cf5133c5405..0f789d3961fc 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ 
b/src/diffusers/models/unets/unet_2d_condition.py @@ -167,6 +167,7 @@ class conditioning with `class_embed_type` equal to `None`. _supports_gradient_checkpointing = True _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"] _skip_layerwise_casting_patterns = ["norm"] + _repeated_blocks = ["BasicTransformerBlock"] @register_to_config def __init__( diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 3a401c46fb5e..7e1e1483f7e0 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1935,6 +1935,27 @@ def test_torch_compile_recompilation_and_graph_break(self): _ = model(**inputs_dict) _ = model(**inputs_dict) + def test_torch_compile_repeated_blocks(self): + if not self.model_class._repeated_blocks: # `ModelMixin` defaults this to `[]`, so test truthiness + pytest.skip("Skipping test as the model class doesn't have `_repeated_blocks` set.") + + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict).to(torch_device) + model.compile_repeated_blocks(fullgraph=True) + + recompile_limit = 1 + if self.model_class.__name__ == "UNet2DConditionModel": + recompile_limit = 2 + + with ( + torch._inductor.utils.fresh_inductor_cache(), + torch._dynamo.config.patch(recompile_limit=recompile_limit), + torch.no_grad(), + ): + _ = model(**inputs_dict) + _ = model(**inputs_dict) + def test_compile_with_group_offloading(self): torch._dynamo.config.cache_size_limit = 10000