diff --git a/docs/source/en/optimization/fp16.md b/docs/source/en/optimization/fp16.md index 734f63e68d23..edbb14fae3d4 100644 --- a/docs/source/en/optimization/fp16.md +++ b/docs/source/en/optimization/fp16.md @@ -174,39 +174,36 @@ Feel free to open an issue if dynamic compilation doesn't work as expected for a ### Regional compilation +[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by only compiling the *small and frequently-repeated block(s)* of a model - typically a transformer layer - and enables reusing compiled artifacts for every subsequent occurrence. +For many diffusion architectures, this delivers the same runtime speedups as full-graph compilation and reduces compile time by 8–10x. -[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by compiling **only the small, frequently-repeated block(s)** of a model, typically a Transformer layer, enabling reuse of compiled artifacts for every subsequent occurrence. -For many diffusion architectures this delivers the *same* runtime speed-ups as full-graph compilation yet cuts compile time by **8–10 ×**. - -To make this effortless, [`ModelMixin`] exposes [`ModelMixin.compile_repeated_blocks`] API, a helper that wraps `torch.compile` around any sub-modules you designate as repeatable: +Use the [`~ModelMixin.compile_repeated_blocks`] method, a helper that wraps `torch.compile`, on any component such as the transformer model as shown below. ```py # pip install -U diffusers import torch from diffusers import StableDiffusionXLPipeline -pipe = StableDiffusionXLPipeline.from_pretrained( +pipeline = StableDiffusionXLPipeline.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, ).to("cuda") -# Compile only the repeated Transformer layers inside the UNet -pipe.unet.compile_repeated_blocks(fullgraph=True) +# compile only the repeated transformer layers inside the UNet +pipeline.unet.compile_repeated_blocks(fullgraph=True) ``` -To enable a new model with regional compilation, add a `_repeated_blocks` attribute to your model class containing the class names (as strings) of the blocks you want compiled: - +To enable regional compilation for a new model, add a `_repeated_blocks` attribute to a model class containing the class names (as strings) of the blocks you want to compile. ```py class MyUNet(ModelMixin): _repeated_blocks = ("Transformer2DModel",) # ← compiled by default ``` -For more examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705). - -**Relation to Accelerate compile_regions** There is also a separate API in [accelerate](https://huggingface.co/docs/accelerate/index) - [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78). It takes a fully automatic approach: it walks the module, picks candidate blocks, then compiles the remaining graph separately. That hands-off experience is handy for quick experiments, but it also leaves fewer knobs when you want to fine-tune which blocks are compiled or adjust compilation flags. - +> [!TIP] +> For more regional compilation examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705). +There is also a [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78) method in [Accelerate](https://huggingface.co/docs/accelerate/index) that automatically selects candidate blocks in a model to compile. The remaining graph is compiled separately. This is useful for quick experiments because there aren't as many options for you to set which blocks to compile or adjust compilation flags. ```py # pip install -U accelerate @@ -219,8 +216,8 @@ pipeline = StableDiffusionXLPipeline.from_pretrained( ).to("cuda") pipeline.unet = compile_regions(pipeline.unet, mode="reduce-overhead", fullgraph=True) ``` -`compile_repeated_blocks`, by contrast, is intentionally explicit. You list the repeated blocks once (via `_repeated_blocks`) and the helper compiles exactly those, nothing more. In practice this small dose of control hits a sweet spot for diffusion models: predictable behavior, easy reasoning about cache reuse, and still a one-liner for users. +[`~ModelMixin.compile_repeated_blocks`] is intentionally explicit. List the blocks to repeat in `_repeated_blocks` and the helper only compiles those blocks. It offers predictable behavior and easy reasoning about cache reuse in one line of code. ### Graph breaks @@ -296,3 +293,9 @@ An input is projected into three subspaces, represented by the projection matric ```py pipeline.fuse_qkv_projections() ``` + +## Resources + +- Read the [Presenting Flux Fast: Making Flux go brrr on H100s](https://pytorch.org/blog/presenting-flux-fast-making-flux-go-brrr-on-h100s/) blog post to learn more about how you can combine all of these optimizations with [TorchInductor](https://docs.pytorch.org/docs/stable/torch.compiler.html) and [AOTInductor](https://docs.pytorch.org/docs/stable/torch.compiler_aot_inductor.html) for a ~2.5x speedup using recipes from [flux-fast](https://github.com/huggingface/flux-fast). + + These recipes support AMD hardware and [Flux.1 Kontext Dev](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev). \ No newline at end of file