From eaaac2e49de5a32c1db6227d6bee5fde1f1a6fa9 Mon Sep 17 00:00:00 2001 From: newgrit1004 Date: Fri, 12 Jan 2024 22:37:07 +0900 Subject: [PATCH] Convert old LoRA format to the new format Signed-off-by: newgrit1004 --- demo/Diffusion/models.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/demo/Diffusion/models.py b/demo/Diffusion/models.py index ffeb6cfda..763e96d0a 100644 --- a/demo/Diffusion/models.py +++ b/demo/Diffusion/models.py @@ -232,9 +232,23 @@ def get_dicts(self, } else: - # Otherwise, we're dealing with the old format. - warn_message = "You have saved the LoRA weights using the old format. To convert LoRA weights to the new format, first load them in a dictionary and then create a new dictionary as follows: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`." - print(warn_message) + # Otherwise, convert old LoRA format to the new format. + self.state_dict[path] = {f'unet.{module_name}': params for module_name, params in self.state_dict[path].items()} + keys = list(self.state_dict[path].keys()) + if all(key.startswith(('unet', 'text_encoder')) for key in keys): + keys = [k for k in keys if k.startswith(prefix)] + if keys: + print(f"Processing {prefix} LoRA: {path}") + state_dict[path] = {k.replace(f"{prefix}.", ""): v for k, v in self.state_dict[path].items() if k in keys} + + if path in self.network_alphas: + if self.network_alphas[path]: + alpha_keys = [k for k in self.network_alphas[path].keys() if k.startswith(prefix)] + network_alphas[path] = { + k.replace(f"{prefix}.", ""): v for k, v in self.network_alphas[path].items() if k in alpha_keys + } + else: + network_alphas[path] = None return state_dict, network_alphas