Skip to content

[rfc][compile] compile method for DiffusionPipeline #11705

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
_keep_in_fp32_modules = None
_skip_layerwise_casting_patterns = None
_supports_group_offloading = True
_regions_for_compile = []

def __init__(self):
super().__init__()
Expand Down Expand Up @@ -1402,6 +1403,44 @@ def float(self, *args):
else:
return super().float(*args)

@wraps(torch.nn.Module.compile)
def compile(self, use_regional_compile: bool = True, *args, **kwargs):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we go with Option 3 - we should turn this to False to keep the existing behavior of pipe.tranformer.compile same as before.

""" """
if use_regional_compile:
regions_for_compile = getattr(self, "_regions_for_compile", None)

if not regions_for_compile:
logger.warning(
"_regions_for_compile attribute is empty. Using _no_split_modules to find compile regions."
)

regions_for_compile = getattr(self, "_no_split_modules", None)

if not regions_for_compile:
logger.warning(
"Both _regions_for_compile and _no_split_modules attribute are empty. "
"Set _regions_for_compile for the model to benefit from regional compilation. "
"Falling back to full model compilation, which could have high first iteration "
"latency."
)
super().compile(*args, **kwargs)

has_compiled_region = False
for submod in self.modules():
if submod.__class__.__name__ in regions_for_compile:
has_compiled_region = True
submod.compile(*args, **kwargs)

if not has_compiled_region:
raise ValueError(
f"Regional compilation failed because {regions_for_compile} classes are not found in the model. "
"Either set them correctly, or set `use_regional_compile` to False while calling copmile, e.g. "
"pipe.transformer.compile(use_regional_compile=False) to fallback to full model compilation, "
"which could have high iteration latency."
)
else:
super().compile(*args, **kwargs)

@classmethod
def _load_pretrained_model(
cls,
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/models/transformers/transformer_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ class FluxTransformer2DModel(
_supports_gradient_checkpointing = True
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
_regions_for_compile = _no_split_modules

@register_to_config
def __init__(
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/models/transformers/transformer_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
_no_split_modules = ["WanTransformerBlock"]
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
_regions_for_compile = _no_split_modules

@register_to_config
def __init__(
Expand Down