New file: invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_flux_mlp_embedder.py
@@ -0,0 +1,42 @@
import torch

from invokeai.backend.flux.modules.layers import MLPEmbedder
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
CustomModuleMixin,
)


class CustomFluxMLPEmbedder(MLPEmbedder, CustomModuleMixin):
def _autocast_forward_with_patches(self, x: torch.Tensor) -> torch.Tensor:
# Example patch logic (placeholder): for patches that expose a `lora_up` attribute, add the
# up-projection weights, scaled by the patch weight, to in_layer and out_layer in place. Note that
# this ignores the LoRA down projection and mutates the stored weights on every forward call, so it
# is a stand-in for real sidecar patching rather than a complete implementation.
for patch, patch_weight in self._patches_and_weights:
if hasattr(patch, "lora_up"):
if hasattr(self.in_layer, "weight"):
self.in_layer.weight.data += patch.lora_up.weight.data * patch_weight
if hasattr(self.out_layer, "weight"):
self.out_layer.weight.data += patch.lora_up.weight.data * patch_weight
# Move weights to input device
device = x.device
if hasattr(self.in_layer, "weight"):
self.in_layer.weight.data = cast_to_device(self.in_layer.weight, device)
if hasattr(self.out_layer, "weight"):
self.out_layer.weight.data = cast_to_device(self.out_layer.weight, device)
return super().forward(x)

def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor:
# Move weights to input device
device = x.device
if hasattr(self.in_layer, "weight"):
self.in_layer.weight.data = cast_to_device(self.in_layer.weight, device)
if hasattr(self.out_layer, "weight"):
self.out_layer.weight.data = cast_to_device(self.out_layer.weight, device)
return super().forward(x)

def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.get_num_patches() > 0:
return self._autocast_forward_with_patches(x)
elif self._device_autocasting_enabled:
return self._autocast_forward(x)
else:
return super().forward(x)
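The class above keeps FLUX's MLPEmbedder semantics while layering patch and device-autocast handling on top via CustomModuleMixin. The standalone sketch below restates the same three-way forward dispatch with toy classes; ToyMLPEmbedder, ToyCustomMixin, and the functional autocast path are illustrative names, not InvokeAI APIs. Unlike the diff above, it casts weights per call rather than reassigning .data, which is one way to leave the module's stored parameters untouched.

import torch
from torch import nn


class ToyMLPEmbedder(nn.Module):
    """Mirrors the shape of FLUX's MLPEmbedder: Linear -> SiLU -> Linear."""

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))


class ToyCustomMixin:
    _patches_and_weights: tuple = ()
    _device_autocasting_enabled: bool = False

    def get_num_patches(self) -> int:
        return len(self._patches_and_weights)


class ToyCustomMLPEmbedder(ToyMLPEmbedder, ToyCustomMixin):
    def _forward_with_patches(self, x: torch.Tensor) -> torch.Tensor:
        # Placeholder for patch handling; fall through to the parent forward.
        return super().forward(x)

    def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor:
        # Cast weights to the input's device per call, without touching the stored parameters.
        w_in = self.in_layer.weight.to(x.device)
        b_in = self.in_layer.bias.to(x.device)
        w_out = self.out_layer.weight.to(x.device)
        b_out = self.out_layer.bias.to(x.device)
        h = nn.functional.silu(nn.functional.linear(x, w_in, b_in))
        return nn.functional.linear(h, w_out, b_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.get_num_patches() > 0:
            return self._forward_with_patches(x)
        if self._device_autocasting_enabled:
            return self._autocast_forward(x)
        return super().forward(x)


embedder = ToyCustomMLPEmbedder(in_dim=256, hidden_dim=3072)
print(embedder(torch.randn(1, 256)).shape)  # torch.Size([1, 3072])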
invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py
@@ -2,7 +2,7 @@

import torch

from invokeai.backend.flux.modules.layers import RMSNorm
from invokeai.backend.flux.modules.layers import MLPEmbedder, RMSNorm
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_conv1d import (
CustomConv1d,
)
@@ -12,6 +12,9 @@
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_embedding import (
CustomEmbedding,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_flux_mlp_embedder import (
CustomFluxMLPEmbedder,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_flux_rms_norm import (
CustomFluxRMSNorm,
)
@@ -32,6 +35,7 @@
torch.nn.GroupNorm: CustomGroupNorm,
torch.nn.Embedding: CustomEmbedding,
RMSNorm: CustomFluxRMSNorm,
MLPEmbedder: CustomFluxMLPEmbedder,
}

try:
@@ -66,6 +70,10 @@
# TODO(ryand): In the future, we may want to do a shallow copy of the __dict__.
custom_layer.__dict__ = module_to_wrap.__dict__

# Explicitly re-register parameters to ensure named_parameters() works correctly.
for name, param in module_to_wrap.named_parameters(recurse=False):
custom_layer.register_parameter(name, param)

# Initialize the CustomModuleMixin fields.
CustomModuleMixin.__init__(custom_layer) # type: ignore
return custom_layer
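The re-registration loop above is a defensive step: after the custom layer takes over the wrapped module's __dict__, explicitly re-registering the parameters guarantees that named_parameters() reports them. A rough, self-contained illustration of that wrapping style follows; wrap() and ToyCustomLinear are made-up names, not InvokeAI's wrap_custom_layer.

import torch
from torch import nn


class ToyCustomLinear(nn.Linear):
    """Stand-in for a custom wrapper class; behaves exactly like nn.Linear."""


def wrap(module: nn.Module, custom_type: type) -> nn.Module:
    # Create an instance without running __init__, then take over the wrapped module's state.
    custom = custom_type.__new__(custom_type)
    custom.__dict__ = module.__dict__
    # Defensive step mirroring the diff: the __dict__ takeover normally carries _parameters
    # along, but explicit re-registration guarantees named_parameters() reports every parameter.
    for name, param in module.named_parameters(recurse=False):
        custom.register_parameter(name, param)
    return custom


linear = nn.Linear(4, 8)
wrapped = wrap(linear, ToyCustomLinear)
assert isinstance(wrapped, ToyCustomLinear)
assert dict(wrapped.named_parameters(recurse=False)).keys() == {"weight", "bias"}
print(wrapped(torch.randn(2, 4)).shape)  # torch.Size([2, 8])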
@@ -91,6 +99,8 @@
# TODO(ryand): In the future, we should manage this flag on a per-module basis.
custom_layer.set_device_autocasting_enabled(device_autocasting_enabled)
setattr(module, name, custom_layer)
# Recursively apply to the newly wrapped custom layer's children
apply_custom_layers_to_model(custom_layer, device_autocasting_enabled)
else:
# Recursively apply to submodules
apply_custom_layers_to_model(submodule, device_autocasting_enabled)
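The new recursion into the freshly wrapped layer matters for container-like modules such as MLPEmbedder, whose child Linear layers should still receive their own custom wrappers. A simplified, self-contained sketch of that traversal follows; the wrapping table, toy_wrap, and apply_toy_custom_layers are illustrative, and the class-swap trick merely stands in for InvokeAI's wrapping mechanism.

from torch import nn

# Hypothetical mapping from base module types to custom wrapper types.
TOY_WRAPPERS: dict[type, type] = {}


def toy_wrap(module: nn.Module, custom_type: type) -> nn.Module:
    # Simplest possible "wrap" for illustration: swap the instance's class in place.
    module.__class__ = custom_type
    return module


def apply_toy_custom_layers(module: nn.Module) -> None:
    for name, submodule in module.named_children():
        custom_type = TOY_WRAPPERS.get(type(submodule))
        if custom_type is not None:
            wrapped = toy_wrap(submodule, custom_type)
            setattr(module, name, wrapped)
            # Keep descending: the wrapped layer may itself contain wrappable children
            # (e.g. the Linear layers inside an MLP-style embedder).
            apply_toy_custom_layers(wrapped)
        else:
            apply_toy_custom_layers(submodule)


class WrappedLinear(nn.Linear):
    pass


TOY_WRAPPERS[nn.Linear] = WrappedLinear

model = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 2)))
apply_toy_custom_layers(model)
assert all(type(m) is WrappedLinear for m in model.modules() if isinstance(m, nn.Linear))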
@@ -102,4 +112,4 @@
if override_type is not None:
setattr(module, name, unwrap_custom_layer(submodule, override_type))
else:
remove_custom_layers_from_model(submodule)
remove_custom_layers_from_model(submodule)

Check failure (GitHub Actions / python-checks), Ruff W292: invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py:115:55 No newline at end of file
29 changes: 21 additions & 8 deletions invokeai/backend/patches/layer_patcher.py
@@ -1,16 +1,20 @@
import re
from contextlib import contextmanager
from typing import Dict, Iterable, Optional, Tuple

import torch

from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.patches.pad_with_zeros import pad_with_zeros
from invokeai.backend.util import InvokeAILogger
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
from invokeai.backend.flux.modules.layers import MLPEmbedder
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_flux_mlp_embedder import (
CustomFluxMLPEmbedder,
)

Check failure (GitHub Actions / python-checks), Ruff I001: invokeai/backend/patches/layer_patcher.py:1:1 Import block is un-sorted or un-formatted (annotation on line 17)


class LayerPatcher:
@@ -28,8 +32,7 @@
suppress_warning_layers: Optional[re.Pattern] = None,
):
"""Apply 'smart' model patching that chooses whether to use direct patching or a sidecar wrapper for each
module.
"""
module."""

# original_weights are stored for unpatching layers that are directly patched.
original_weights = OriginalWeightsStorage(cached_weights)
@@ -77,8 +80,7 @@
suppress_warning_layers: Optional[re.Pattern] = None,
):
"""Apply a single LoRA patch to a model using the 'smart' patching strategy that chooses whether to use direct
patching or a sidecar wrapper for each module.
"""
patching or a sidecar wrapper for each module."""
if patch_weight == 0:
return

@@ -124,11 +126,17 @@
use_sidecar_patching = False
elif force_sidecar_patching:
use_sidecar_patching = True
# elif not hasattr(module, "get_num_patches"):
# continue
elif module.get_num_patches() > 0:
use_sidecar_patching = True
elif LayerPatcher._is_any_part_of_layer_on_cpu(module):
use_sidecar_patching = True

# Force sidecar patching for MLPEmbedder and CustomFluxMLPEmbedder
if isinstance(module, (MLPEmbedder, CustomFluxMLPEmbedder)):
use_sidecar_patching = True

if use_sidecar_patching:
LayerPatcher._apply_model_layer_wrapper_patch(
module_to_patch=module,
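Restated outside the diff, the decision order in the hunk above is: honor the force flags first, then fall back to sidecar patching when the module already carries patches or has weights offloaded to CPU, and finally force sidecar patching for the configured module types regardless. The helper below is an illustrative restatement, not code from the PR; its name and signature are assumptions.

from torch import nn


def should_use_sidecar(
    module: nn.Module,
    force_direct_patching: bool,
    force_sidecar_patching: bool,
    num_existing_patches: int,
    any_weight_on_cpu: bool,
    force_sidecar_types: tuple[type, ...] = (),
) -> bool:
    use_sidecar_patching = False
    if force_direct_patching:
        use_sidecar_patching = False
    elif force_sidecar_patching:
        use_sidecar_patching = True
    elif num_existing_patches > 0:
        use_sidecar_patching = True
    elif any_weight_on_cpu:
        use_sidecar_patching = True

    # Certain module types are always routed to sidecar patching, as in the diff above.
    if isinstance(module, force_sidecar_types):
        use_sidecar_patching = True
    return use_sidecar_patching


# Example: a partially CPU-offloaded layer ends up with sidecar patching.
print(should_use_sidecar(nn.Linear(4, 4), False, False, 0, any_weight_on_cpu=True))  # True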
@@ -174,9 +182,12 @@

# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
for param_name, param_weight in patch.get_parameters(
dict(module_to_patch.named_parameters(recurse=False)), weight=patch_weight
).items():
# Build orig_parameters directly from the module's weight and bias attributes instead of
# relying on named_parameters(recurse=False), which may not report parameters correctly on
# the wrapped custom modules.
orig_parameters = {"weight": module_to_patch.weight}
if module_to_patch.bias is not None:
orig_parameters["bias"] = module_to_patch.bias

for param_name, param_weight in patch.get_parameters(orig_parameters, weight=patch_weight).items():
param_key = module_to_patch_key + "." + param_name
module_param = module_to_patch.get_parameter(param_name)
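For context, the sketch below shows roughly what the direct-patching loop does with the manually assembled orig_parameters: ask the patch for per-parameter deltas at the given weight, stash the originals so the patch can later be undone, and add the deltas in place. ToyPatch and apply_direct_patch are made-up names; the real patch objects follow the BaseLayerPatch.get_parameters contract used above.

import torch
from torch import nn


class ToyPatch:
    """Produces per-parameter deltas, mimicking the get_parameters() contract used above."""

    def __init__(self, deltas: dict[str, torch.Tensor]):
        self.deltas = deltas

    def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]:
        return {name: delta * weight for name, delta in self.deltas.items() if name in orig_parameters}


def apply_direct_patch(
    module: nn.Linear,
    patch: ToyPatch,
    patch_weight: float,
    originals: dict[str, torch.Tensor],
) -> None:
    # Assemble orig_parameters from the module's own attributes, as in the diff above.
    orig_parameters = {"weight": module.weight}
    if module.bias is not None:
        orig_parameters["bias"] = module.bias

    for param_name, param_delta in patch.get_parameters(orig_parameters, weight=patch_weight).items():
        param = orig_parameters[param_name]
        # Save the original once so the patch can be undone later.
        originals.setdefault(param_name, param.detach().clone())
        param.data += param_delta.to(dtype=param.dtype, device=param.device)


linear = nn.Linear(4, 4)
saved: dict[str, torch.Tensor] = {}
apply_direct_patch(linear, ToyPatch({"weight": torch.ones(4, 4)}), patch_weight=0.5, originals=saved)
assert torch.allclose(linear.weight, saved["weight"] + 0.5)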

@@ -250,7 +261,9 @@

@staticmethod
def _get_submodule(
model: torch.nn.Module, layer_key: str, layer_key_is_flattened: bool
model: torch.nn.Module,
layer_key: str,
layer_key_is_flattened: bool,
) -> tuple[str, torch.nn.Module]:
"""Get the submodule corresponding to the given layer key.

invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py
@@ -7,6 +7,7 @@
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.util.logging import InvokeAILogger


def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Tensor]) -> bool:
@@ -239,10 +240,19 @@ def add_qkv_lora_layer_if_present(
)

# Final layer.
# Hyper FLUX LoRA support: map the diffusers norm_out.linear weights onto
# final_layer.adaLN_modulation.1 rather than onto a norm_out.linear key, which does not exist
# in the BFL-style transformer.
add_lora_layer_if_present("norm_out.linear", "final_layer.adaLN_modulation.1")

add_lora_layer_if_present("proj_out", "final_layer.linear")

# Warn about any keys that were not processed.
assert len(grouped_state_dict) == 0
if len(grouped_state_dict) > 0:
logger = InvokeAILogger.get_logger()
logger.warning(
f"The following unexpected LoRA layers were not loaded: {list(grouped_state_dict.keys())}."
" This is not necessarily a problem, but the LoRA may not be fully applied."
)

layers_with_prefix = {f"{FLUX_LORA_TRANSFORMER_PREFIX}{k}": v for k, v in layers.items()}
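The assert-to-warning change above makes unconsumed diffusers LoRA keys non-fatal. A minimal sketch of that pattern follows, using the standard logging module and a made-up subset of key mappings; the function name is illustrative.

import logging

logger = logging.getLogger(__name__)


def consume_known_layers(grouped_state_dict: dict[str, dict]) -> dict[str, dict]:
    layers: dict[str, dict] = {}
    # Illustrative subset of known diffusers-key mappings.
    for src_key in ("norm_out.linear", "proj_out"):
        if src_key in grouped_state_dict:
            layers[src_key] = grouped_state_dict.pop(src_key)

    # Tolerate leftover keys with a warning instead of an assert, as in the diff above.
    if grouped_state_dict:
        logger.warning(
            "The following unexpected LoRA layers were not loaded: %s. "
            "This is not necessarily a problem, but the LoRA may not be fully applied.",
            list(grouped_state_dict.keys()),
        )
    return layers


print(consume_known_layers({"proj_out": {}, "some.unknown.layer": {}}))  # {'proj_out': {}}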
