diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_flux_mlp_embedder.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_flux_mlp_embedder.py
new file mode 100644
index 00000000000..99bddb40af4
--- /dev/null
+++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_flux_mlp_embedder.py
@@ -0,0 +1,42 @@
+import torch
+
+from invokeai.backend.flux.modules.layers import MLPEmbedder
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import (
+    CustomModuleMixin,
+)
+
+
+class CustomFluxMLPEmbedder(MLPEmbedder, CustomModuleMixin):
+    def _autocast_forward_with_patches(self, x: torch.Tensor) -> torch.Tensor:
+        # Example patch logic: apply LoRA weights to in_layer and out_layer
+        for patch, patch_weight in self._patches_and_weights:
+            if hasattr(patch, "lora_up"):
+                if hasattr(self.in_layer, "weight"):
+                    self.in_layer.weight.data += patch.lora_up.weight.data * patch_weight
+                if hasattr(self.out_layer, "weight"):
+                    self.out_layer.weight.data += patch.lora_up.weight.data * patch_weight
+        # Move weights to input device
+        device = x.device
+        if hasattr(self.in_layer, "weight"):
+            self.in_layer.weight.data = cast_to_device(self.in_layer.weight, device)
+        if hasattr(self.out_layer, "weight"):
+            self.out_layer.weight.data = cast_to_device(self.out_layer.weight, device)
+        return super().forward(x)
+
+    def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Move weights to input device
+        device = x.device
+        if hasattr(self.in_layer, "weight"):
+            self.in_layer.weight.data = cast_to_device(self.in_layer.weight, device)
+        if hasattr(self.out_layer, "weight"):
+            self.out_layer.weight.data = cast_to_device(self.out_layer.weight, device)
+        return super().forward(x)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.get_num_patches() > 0:
+            return self._autocast_forward_with_patches(x)
+        elif self._device_autocasting_enabled:
+            return self._autocast_forward(x)
+        else:
+            return super().forward(x)
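For orientation, a minimal sketch of how the new wrapper is expected to dispatch once an MLPEmbedder has been wrapped. The dimensions, the direct call to wrap_custom_layer, and the assumption that the FLUX MLPEmbedder constructor takes (in_dim, hidden_dim) are illustrative only; in practice the wrapping happens inside apply_custom_layers_to_model (next file).

import torch

from invokeai.backend.flux.modules.layers import MLPEmbedder
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_flux_mlp_embedder import (
    CustomFluxMLPEmbedder,
)
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
    wrap_custom_layer,
)

# Assumed dimensions for illustration; positional args are (in_dim, hidden_dim).
embedder = MLPEmbedder(256, 3072)
wrapped = wrap_custom_layer(embedder, CustomFluxMLPEmbedder)
wrapped.set_device_autocasting_enabled(True)

# With no patches registered, get_num_patches() == 0, so forward() routes through
# _autocast_forward(), which casts in_layer/out_layer weights to x.device before
# delegating to MLPEmbedder.forward().
x = torch.randn(1, 256)
out = wrapped(x)
print(out.shape)  # torch.Size([1, 3072])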
diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py
index 0e271eaec5a..2ea90a54999 100644
--- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py
+++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py
@@ -2,7 +2,7 @@
 
 import torch
 
-from invokeai.backend.flux.modules.layers import RMSNorm
+from invokeai.backend.flux.modules.layers import MLPEmbedder, RMSNorm
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_conv1d import (
     CustomConv1d,
 )
@@ -12,6 +12,9 @@
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_embedding import (
     CustomEmbedding,
 )
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_flux_mlp_embedder import (
+    CustomFluxMLPEmbedder,
+)
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_flux_rms_norm import (
     CustomFluxRMSNorm,
 )
@@ -32,6 +35,7 @@
     torch.nn.GroupNorm: CustomGroupNorm,
     torch.nn.Embedding: CustomEmbedding,
     RMSNorm: CustomFluxRMSNorm,
+    MLPEmbedder: CustomFluxMLPEmbedder,
 }
 
 try:
@@ -66,6 +70,10 @@ def wrap_custom_layer(module_to_wrap: torch.nn.Module, custom_layer_type: type[T
     # TODO(ryand): In the future, we may want to do a shallow copy of the __dict__.
     custom_layer.__dict__ = module_to_wrap.__dict__
 
+    # Explicitly re-register parameters to ensure named_parameters() works correctly.
+    for name, param in module_to_wrap.named_parameters(recurse=False):
+        custom_layer.register_parameter(name, param)
+
     # Initialize the CustomModuleMixin fields.
     CustomModuleMixin.__init__(custom_layer)  # type: ignore
     return custom_layer
@@ -91,6 +99,8 @@ def apply_custom_layers_to_model(module: torch.nn.Module, device_autocasting_ena
             # TODO(ryand): In the future, we should manage this flag on a per-module basis.
             custom_layer.set_device_autocasting_enabled(device_autocasting_enabled)
             setattr(module, name, custom_layer)
+            # Recursively apply to the newly wrapped custom layer's children
+            apply_custom_layers_to_model(custom_layer, device_autocasting_enabled)
         else:
             # Recursively apply to submodules
             apply_custom_layers_to_model(submodule, device_autocasting_enabled)
@@ -102,4 +112,4 @@ def remove_custom_layers_from_model(module: torch.nn.Module):
         if override_type is not None:
             setattr(module, name, unwrap_custom_layer(submodule, override_type))
         else:
-            remove_custom_layers_from_model(submodule)
+            remove_custom_layers_from_model(submodule)
\ No newline at end of file
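A rough sketch of what the two torch_module_autocast changes buy us: MLPEmbedder instances now get wrapped via the new mapping entry, and the new recursive call converts the wrapped embedder's own children as well. TinyModel and its dimensions are illustrative stand-ins, not InvokeAI code.

import torch

from invokeai.backend.flux.modules.layers import MLPEmbedder
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
    apply_custom_layers_to_model,
)


class TinyModel(torch.nn.Module):
    """Illustrative stand-in for a model that contains an MLPEmbedder."""

    def __init__(self):
        super().__init__()
        self.time_in = MLPEmbedder(256, 3072)
        self.proj = torch.nn.Linear(3072, 8)


model = TinyModel()
apply_custom_layers_to_model(model, device_autocasting_enabled=True)

# The embedder is replaced thanks to the new AUTOCAST_MODULE_TYPE_MAPPING entry.
print(type(model.time_in).__name__)  # CustomFluxMLPEmbedder

# The recursive call added above also descends into the wrapped embedder, so its
# in_layer/out_layer Linears are converted to the corresponding custom Linear type
# (assuming torch.nn.Linear is registered in the mapping, as the other entries suggest).
print(type(model.time_in.in_layer).__name__)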
diff --git a/invokeai/backend/patches/layer_patcher.py b/invokeai/backend/patches/layer_patcher.py
index 08a028c88fa..2206a983a71 100644
--- a/invokeai/backend/patches/layer_patcher.py
+++ b/invokeai/backend/patches/layer_patcher.py
@@ -11,6 +11,10 @@
 from invokeai.backend.util import InvokeAILogger
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
+from invokeai.backend.flux.modules.layers import MLPEmbedder
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_flux_mlp_embedder import (
+    CustomFluxMLPEmbedder,
+)
 
 
 class LayerPatcher:
@@ -28,8 +32,7 @@ def apply_smart_model_patches(
         suppress_warning_layers: Optional[re.Pattern] = None,
     ):
         """Apply 'smart' model patching that chooses whether to use direct patching or a sidecar wrapper for each
-        module.
-        """
+        module."""
         # original_weights are stored for unpatching layers that are directly patched.
         original_weights = OriginalWeightsStorage(cached_weights)
 
@@ -77,8 +80,7 @@ def apply_smart_model_patch(
         suppress_warning_layers: Optional[re.Pattern] = None,
     ):
         """Apply a single LoRA patch to a model using the 'smart' patching strategy that chooses whether to use direct
-        patching or a sidecar wrapper for each module.
-        """
+        patching or a sidecar wrapper for each module."""
         if patch_weight == 0:
             return
 
@@ -124,11 +126,17 @@ def apply_smart_model_patch(
                 use_sidecar_patching = False
             elif force_sidecar_patching:
                 use_sidecar_patching = True
+            # elif not hasattr(module, "get_num_patches"):
+            #     continue
            elif module.get_num_patches() > 0:
                 use_sidecar_patching = True
             elif LayerPatcher._is_any_part_of_layer_on_cpu(module):
                 use_sidecar_patching = True
 
+            # Force sidecar patching for MLPEmbedder and CustomFluxMLPEmbedder
+            if isinstance(module, (MLPEmbedder, CustomFluxMLPEmbedder)):
+                use_sidecar_patching = True
+
             if use_sidecar_patching:
                 LayerPatcher._apply_model_layer_wrapper_patch(
                     module_to_patch=module,
@@ -174,9 +182,12 @@ def _apply_model_layer_patch(
 
         # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
         # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
-        for param_name, param_weight in patch.get_parameters(
-            dict(module_to_patch.named_parameters(recurse=False)), weight=patch_weight
-        ).items():
+        # Manually create orig_parameters to bypass named_parameters() issue.
+        orig_parameters = {"weight": module_to_patch.weight}
+        if module_to_patch.bias is not None:
+            orig_parameters["bias"] = module_to_patch.bias
+
+        for param_name, param_weight in patch.get_parameters(orig_parameters, weight=patch_weight).items():
             param_key = module_to_patch_key + "." + param_name
             module_param = module_to_patch.get_parameter(param_name)
 
@@ -250,7 +261,9 @@ def _set_submodule(parent_module: torch.nn.Module, module_name: str, submodule:
 
     @staticmethod
     def _get_submodule(
-        model: torch.nn.Module, layer_key: str, layer_key_is_flattened: bool
+        model: torch.nn.Module,
+        layer_key: str,
+        layer_key_is_flattened: bool,
     ) -> tuple[str, torch.nn.Module]:
         """Get the submodule corresponding to the given layer key.
 
diff --git a/invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py
index 188d118cc4d..4fe64181184 100644
--- a/invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py
+++ b/invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py
@@ -7,6 +7,7 @@
 from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
 from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
+from invokeai.backend.util.logging import InvokeAILogger
 
 
 def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Tensor]) -> bool:
@@ -239,10 +240,19 @@ def add_qkv_lora_layer_if_present(
     )
 
     # Final layer.
+    # Hyper FLUX LoRA support: patch norm_out.linear if present
+    # add_lora_layer_if_present("norm_out.linear", "norm_out.linear")
+    add_lora_layer_if_present("norm_out.linear", "final_layer.adaLN_modulation.1")
+
     add_lora_layer_if_present("proj_out", "final_layer.linear")
 
     # Assert that all keys were processed.
-    assert len(grouped_state_dict) == 0
+    if len(grouped_state_dict) > 0:
+        logger = InvokeAILogger.get_logger()
+        logger.warning(
+            f"The following unexpected LoRA layers were not loaded: {list(grouped_state_dict.keys())}."
+            " This is not necessarily a problem, but the LoRA may not be fully applied."
+        )
 
     layers_with_prefix = {f"{FLUX_LORA_TRANSFORMER_PREFIX}{k}": v for k, v in layers.items()}
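On the layer_patcher.py change in _apply_model_layer_patch: the hand-built orig_parameters dict is intended to match what named_parameters(recurse=False) returned for ordinary layers, while bypassing the named_parameters() issue the diff comment mentions for wrapped modules. A standalone sanity check of that equivalence for a plain torch.nn.Linear, illustrative only:

import torch

linear = torch.nn.Linear(4, 8, bias=True)

# What the old code passed to patch.get_parameters(...).
via_named_parameters = dict(linear.named_parameters(recurse=False))

# What the new code builds explicitly.
orig_parameters = {"weight": linear.weight}
if linear.bias is not None:
    orig_parameters["bias"] = linear.bias

assert set(via_named_parameters) == set(orig_parameters) == {"weight", "bias"}
assert via_named_parameters["weight"] is orig_parameters["weight"]
assert via_named_parameters["bias"] is orig_parameters["bias"]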
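A minimal sketch of the behavioral change at the end of flux_diffusers_lora_conversion_utils.py: keys left over after conversion now produce a logged warning instead of tripping an assert, so a LoRA with unrecognized layers still loads. The grouped_state_dict below is a fabricated stand-in, not a real LoRA.

import torch

from invokeai.backend.util.logging import InvokeAILogger

# Fabricated leftover entry, standing in for a LoRA layer group the converter does not recognize.
grouped_state_dict = {"some_unhandled.layer": {"lora_A.weight": torch.zeros(4, 4)}}

# Old behavior: `assert len(grouped_state_dict) == 0` aborted loading with an AssertionError.
# New behavior: the leftovers are logged and conversion continues with the layers that did map.
if len(grouped_state_dict) > 0:
    logger = InvokeAILogger.get_logger()
    logger.warning(
        f"The following unexpected LoRA layers were not loaded: {list(grouped_state_dict.keys())}."
        " This is not necessarily a problem, but the LoRA may not be fully applied."
    )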