Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions app/user_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,10 +363,17 @@ async def post_userdata(request):
if not overwrite and os.path.exists(path):
return web.Response(status=409, text="File already exists")

body = await request.read()
try:
body = await request.read()

with open(path, "wb") as f:
f.write(body)
with open(path, "wb") as f:
f.write(body)
except OSError as e:
logging.warning(f"Error saving file '{path}': {e}")
return web.Response(
status=400,
reason="Invalid filename. Please avoid special characters like :\\/*?\"<>|"
)

user_path = self.get_request_user_filepath(request, None)
if full_info:
Expand Down
7 changes: 7 additions & 0 deletions comfy/ldm/qwen_image/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ def forward(
)

patches_replace = transformer_options.get("patches_replace", {})
patches = transformer_options.get("patches", {})
blocks_replace = patches_replace.get("dit", {})

for i, block in enumerate(self.transformer_blocks):
Expand All @@ -436,6 +437,12 @@ def block_wrap(args):
image_rotary_emb=image_rotary_emb,
)

if "double_block" in patches:
for p in patches["double_block"]:
out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i})
hidden_states = out["img"]
encoder_hidden_states = out["txt"]

hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)

Expand Down
8 changes: 7 additions & 1 deletion comfy/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,13 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
else:
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())

models = set(models)
models_temp = set()
for m in models:
models_temp.add(m)
for mm in m.model_patches_models():
models_temp.add(mm)

models = models_temp

models_to_load = []

Expand Down
27 changes: 27 additions & 0 deletions comfy/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,9 @@ def set_model_emb_patch(self, patch):
def set_model_forward_timestep_embed_patch(self, patch):
    """Register ``patch`` via ``set_model_patch`` under the name "forward_timestep_embed_patch"."""
    self.set_model_patch(patch, "forward_timestep_embed_patch")

def set_model_double_block_patch(self, patch):
    """Register ``patch`` via ``set_model_patch`` under the name "double_block"."""
    self.set_model_patch(patch, "double_block")

def add_object_patch(self, name, obj):
    """Store ``obj`` in ``self.object_patches`` under ``name``, replacing any existing entry."""
    self.object_patches[name] = obj

Expand Down Expand Up @@ -486,6 +489,30 @@ def model_patches_to(self, device):
if hasattr(wrap_func, "to"):
self.model_options["model_function_wrapper"] = wrap_func.to(device)

def model_patches_models(self):
    """Gather the models reported by every patch object attached to this patcher.

    Three places are inspected, in this order:

    1. each list stored under ``transformer_options["patches"]``,
    2. each mapping stored under ``transformer_options["patches_replace"]``,
    3. the optional ``model_options["model_function_wrapper"]``.

    Any object found there that has a ``models`` attribute has the result of
    calling it appended to the returned list; objects without it are ignored.

    Returns:
        A flat list of everything the patch objects returned from ``models()``.
    """
    transformer_options = self.model_options["transformer_options"]
    collected = []

    def harvest(candidate):
        # Only patch objects that opt in by exposing ``models`` contribute.
        if hasattr(candidate, "models"):
            collected.extend(candidate.models())

    # "patches": patch name -> list of patch objects.
    for patch_list in transformer_options.get("patches", {}).values():
        for patch in patch_list:
            harvest(patch)

    # "patches_replace": patch name -> mapping of key -> patch object.
    for replacements in transformer_options.get("patches_replace", {}).values():
        for patch in replacements.values():
            harvest(patch)

    # A missing wrapper yields None, which has no ``models`` and is skipped.
    harvest(self.model_options.get("model_function_wrapper"))

    return collected

def model_dtype(self):
if hasattr(self.model, "get_dtype"):
return self.model.get_dtype()
Expand Down
4 changes: 4 additions & 0 deletions comfy_api/latest/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,10 @@ class SEGS(ComfyTypeIO):
class AnyType(ComfyTypeIO):
Type = Any

# IO type declared under the "MODEL_PATCH" io_type string; the Python-side
# payload type is left unconstrained (Any).
@comfytype(io_type="MODEL_PATCH")
class MODEL_PATCH(ComfyTypeIO):
    Type = Any

@comfytype(io_type="COMFY_MULTITYPED_V3")
class MultiType:
Type = Any
Expand Down
138 changes: 138 additions & 0 deletions comfy_extras/nodes_model_patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import torch

import comfy.latent_formats
import comfy.ldm.common_dit
import comfy.model_management
import comfy.model_patcher
import comfy.ops
import comfy.utils
import folder_paths


class BlockWiseControlBlock(torch.nn.Module):
# [linear, gelu, linear]
def __init__(self, dim: int = 3072, device=None, dtype=None, operations=None):
super().__init__()
self.x_rms = operations.RMSNorm(dim, eps=1e-6)
self.y_rms = operations.RMSNorm(dim, eps=1e-6)
self.input_proj = operations.Linear(dim, dim)
self.act = torch.nn.GELU()
self.output_proj = operations.Linear(dim, dim)

def forward(self, x, y):
x, y = self.x_rms(x), self.y_rms(y)
x = self.input_proj(x + y)
x = self.act(x)
x = self.output_proj(x)
return x


class QwenImageBlockWiseControlNet(torch.nn.Module):
    """Block-wise control network: one input projection plus one control block per layer.

    Args:
        num_layers: Number of ``BlockWiseControlBlock`` instances to create.
        in_dim: Base width of the patchified latent tokens.
        additional_in_dim: Extra input width added on top of ``in_dim``.
        dim: Hidden width produced by ``img_in`` and used by every block.
        device, dtype, operations: Forwarded to the layer constructors.
    """

    def __init__(
        self,
        num_layers: int = 60,
        in_dim: int = 64,
        additional_in_dim: int = 0,
        dim: int = 3072,
        device=None, dtype=None, operations=None
    ):
        super().__init__()
        self.img_in = operations.Linear(in_dim + additional_in_dim, dim, device=device, dtype=dtype)
        self.controlnet_blocks = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.controlnet_blocks.append(
                BlockWiseControlBlock(dim, device=device, dtype=dtype, operations=operations)
            )

    def process_input_latent_image(self, latent_image):
        """Turn a latent image into projected 2x2-patch tokens.

        The latent is passed through ``Wan21().process_in``, padded to a
        multiple of the patch size, split into 2x2 patches, flattened to
        ``(batch, patches, channels * 4)`` and projected by ``img_in``.
        """
        patch_size = 2
        scaled = comfy.latent_formats.Wan21().process_in(latent_image)
        padded = comfy.ldm.common_dit.pad_to_patch_size(scaled, (1, patch_size, patch_size))

        shape = padded.shape
        batch, channels = shape[0], shape[1]
        rows = shape[-2] // patch_size
        cols = shape[-1] // patch_size

        # (B, C, rows, 2, cols, 2) -> (B, rows, cols, C, 2, 2) -> (B, rows*cols, C*4)
        tokens = padded.view(batch, channels, rows, patch_size, cols, patch_size)
        tokens = tokens.permute(0, 2, 4, 1, 3, 5)
        tokens = tokens.reshape(batch, rows * cols, channels * 4)
        return self.img_in(tokens)

    def control_block(self, img, controlnet_conditioning, block_id):
        """Run the control block at index ``block_id`` on ``img`` and the conditioning."""
        block = self.controlnet_blocks[block_id]
        return block(img, controlnet_conditioning)


class ModelPatchLoader:
    """Node that loads a checkpoint from the "model_patches" folder.

    The checkpoint is loaded into a ``QwenImageBlockWiseControlNet`` and
    returned wrapped in a ``comfy.model_patcher.ModelPatcher``.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Offer every file registered under "model_patches" as the ``name`` choice."""
        return {"required": { "name": (folder_paths.get_filename_list("model_patches"), ),
                              }}
    RETURN_TYPES = ("MODEL_PATCH",)
    FUNCTION = "load_model_patch"
    EXPERIMENTAL = True

    CATEGORY = "advanced/loaders"

    def load_model_patch(self, name):
        """Load the patch file ``name`` and return it as a one-element tuple.

        Args:
            name: File name inside the "model_patches" folder.

        Returns:
            ``(model_patcher,)`` where ``model_patcher`` wraps the loaded
            network with the torch device as load device and the unet
            offload device as offload device.

        Raises:
            Whatever ``folder_paths.get_full_path_or_raise`` raises when the
            file cannot be resolved, and whatever ``load_state_dict`` raises
            when the checkpoint does not match the network.
        """
        model_patch_path = folder_paths.get_full_path_or_raise("model_patches", name)
        sd = comfy.utils.load_torch_file(model_patch_path, safe_load=True)
        dtype = comfy.utils.weight_dtype(sd)
        offload_device = comfy.model_management.unet_offload_device()
        # TODO: this node will work with more types of model patches
        model = QwenImageBlockWiseControlNet(device=offload_device, dtype=dtype, operations=comfy.ops.manual_cast)
        model.load_state_dict(sd)
        # NOTE: comfy.model_patcher must be imported explicitly by this module;
        # attribute access on the package only works for already-imported submodules.
        model = comfy.model_patcher.ModelPatcher(
            model,
            load_device=comfy.model_management.get_torch_device(),
            offload_device=offload_device,
        )
        return (model,)


class DiffSynthCnetPatch:
    """Callable "double_block" patch that adds a control signal to ``img``.

    The control image is VAE-encoded and turned into tokens once up front;
    on every call the control block for the current ``block_index`` is run
    and its output, scaled by ``strength``, is added to ``img``.

    Args:
        model_patch: Object whose ``.model`` provides
            ``process_input_latent_image`` and ``control_block``.
        vae: Object providing ``encode`` and ``spacial_compression_encode``.
        image: Control image passed to ``vae.encode``
            (presumably channels-last — TODO confirm against callers).
        strength: Multiplier applied to the control block output.
    """

    def __init__(self, model_patch, vae, image, strength):
        self.encoded_image = model_patch.model.process_input_latent_image(vae.encode(image))
        self.model_patch = model_patch
        self.vae = vae
        self.image = image
        self.strength = strength

    def __call__(self, kwargs):
        """Apply the control block to ``kwargs["img"]`` and return ``kwargs``.

        Reads ``x``, ``img`` and ``block_index`` from ``kwargs``; only
        ``kwargs["img"]`` is overwritten.
        """
        x = kwargs.get("x")
        img = kwargs.get("img")
        block_index = kwargs.get("block_index")
        if self.encoded_image is None or self.encoded_image.shape[1:] != img.shape[1:]:
            # The cached tokens do not match the current token layout:
            # rescale the control image to the size implied by ``x`` and
            # re-encode it.
            spacial_compression = self.vae.spacial_compression_encode()
            image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
            # Remember what is currently loaded and ask for it to be loaded
            # again after encoding (presumably because encoding may evict
            # models — TODO confirm).
            loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
            self.encoded_image = self.model_patch.model.process_input_latent_image(self.vae.encode(image_scaled.movedim(1, -1)))
            comfy.model_management.load_models_gpu(loaded_models)

        img = img + (self.model_patch.model.control_block(img, self.encoded_image.to(img.dtype), block_index) * self.strength)
        kwargs['img'] = img
        return kwargs

    def to(self, device_or_dtype):
        """Move the cached tokens when given a ``torch.device``; return ``self``.

        Dtypes and any other argument are ignored. A missing cache
        (``encoded_image is None``, a state ``__call__`` already handles) is
        left untouched instead of raising ``AttributeError``.
        """
        if isinstance(device_or_dtype, torch.device) and self.encoded_image is not None:
            self.encoded_image = self.encoded_image.to(device_or_dtype)
        return self

    def models(self):
        """Return the model patch this object depends on, as a one-element list."""
        return [self.model_patch]

class QwenImageDiffsynthControlnet:
    """Node that attaches a ``DiffSynthCnetPatch`` to a clone of the input model."""

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: model, model patch, VAE, image and strength."""
        strength_options = {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}
        required = {
            "model": ("MODEL",),
            "model_patch": ("MODEL_PATCH",),
            "vae": ("VAE",),
            "image": ("IMAGE",),
            "strength": ("FLOAT", strength_options),
        }
        return {"required": required}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "diffsynth_controlnet"
    EXPERIMENTAL = True

    CATEGORY = "advanced/loaders/qwen"

    def diffsynth_controlnet(self, model, model_patch, vae, image, strength):
        """Return a one-tuple holding a clone of ``model`` with the control patch registered.

        Only the first three entries of the image's last axis are passed on
        to the patch.
        """
        rgb = image[:, :, :, :3]
        patched = model.clone()
        control_patch = DiffSynthCnetPatch(model_patch, vae, rgb, strength)
        patched.set_model_double_block_patch(control_patch)
        return (patched,)


# Maps node identifier strings to the node classes defined in this module
# (presumably read by the node loader when this file is imported — verify
# against the loader).
NODE_CLASS_MAPPINGS = {
    "ModelPatchLoader": ModelPatchLoader,
    "QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
}
2 changes: 2 additions & 0 deletions folder_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@

folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})

folder_names_and_paths["model_patches"] = ([os.path.join(models_dir, "model_patches")], supported_pt_extensions)

output_directory = os.path.join(base_path, "output")
temp_directory = os.path.join(base_path, "temp")
input_directory = os.path.join(base_path, "input")
Expand Down
Empty file.
1 change: 1 addition & 0 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2322,6 +2322,7 @@ async def init_builtin_extra_nodes():
"nodes_tcfg.py",
"nodes_context_windows.py",
"nodes_qwen.py",
"nodes_model_patch.py"
]

import_failed = []
Expand Down
Loading