Merged
12 changes: 11 additions & 1 deletion comfy/clip_model.py
@@ -61,15 +61,25 @@ def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate
def forward(self, x, mask=None, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)

all_intermediate = None
if intermediate_output is not None:
if intermediate_output < 0:
if intermediate_output == "all":
all_intermediate = []
intermediate_output = None
elif intermediate_output < 0:
intermediate_output = len(self.layers) + intermediate_output

intermediate = None
for i, l in enumerate(self.layers):
x = l(x, mask, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
if all_intermediate is not None:
all_intermediate.append(x.unsqueeze(1).clone())

if all_intermediate is not None:
intermediate = torch.cat(all_intermediate, dim=1)

return x, intermediate

class CLIPEmbeddings(torch.nn.Module):
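
Note: with the change above, intermediate_output="all" makes CLIPEncoder.forward collect every layer's output (each unsqueezed to [bs, 1, tokens, dim]) and concatenate them along dim 1. A minimal usage sketch, with illustrative shapes and an assumed encoder instance:

# hedged sketch, not part of the diff; shapes are assumptions for illustration
x = torch.randn(1, 729, 1152)                          # [batch, tokens, dim]
x_out, intermediate = encoder(x, mask=None, intermediate_output="all")
# intermediate: [batch, num_layers, tokens, dim]; intermediate[:, -2] matches
# what intermediate_output=-2 returned before this change
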
18 changes: 15 additions & 3 deletions comfy/clip_vision.py
@@ -50,7 +50,13 @@ def __init__(self, json_config):
self.image_size = config.get("image_size", 224)
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
model_type = config.get("model_type", "clip_vision_model")
model_class = IMAGE_ENCODERS.get(model_type)
if model_type == "siglip_vision_model":
self.return_all_hidden_states = True
else:
self.return_all_hidden_states = False

self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
@@ -68,12 +74,18 @@ def get_sd(self):
def encode_image(self, image, crop=True):
comfy.model_management.load_model_gpu(self.patcher)
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)

outputs = Output()
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
if self.return_all_hidden_states:
all_hs = out[1].to(comfy.model_management.intermediate_device())
outputs["penultimate_hidden_states"] = all_hs[:, -2]
outputs["all_hidden_states"] = all_hs
else:
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())

outputs["mm_projected"] = out[3]
return outputs

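
Note: for SigLIP vision models the Output now carries the full per-layer stack in addition to the usual fields. A sketch of how the new fields relate (image layout [batch, H, W, C] is an assumption):

# hedged sketch, not part of the diff
out = clip_vision.encode_image(image)
hs = out["all_hidden_states"]                  # [batch, num_layers, tokens, dim]
assert torch.equal(out["penultimate_hidden_states"], hs[:, -2])
# individual layers can be picked by index, e.g. hs[:, -11] or hs[:, -20]
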
13 changes: 11 additions & 2 deletions comfy/ldm/flux/model.py
@@ -106,6 +106,7 @@ def forward_orig(
if y is None:
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)

patches = transformer_options.get("patches", {})
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -117,9 +118,17 @@
if guidance is not None:
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))

vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
txt = self.txt_in(txt)

if "post_input" in patches:
for p in patches["post_input"]:
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
img = out["img"]
txt = out["txt"]
img_ids = out["img_ids"]
txt_ids = out["txt_ids"]

if img_ids is not None:
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
@@ -239,7 +248,7 @@ def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None
index += 1
h_offset = 0
w_offset = 0
elif ref_latents_method == "uso":
elif ref_latents_method == "uxo":
index = 0
h_offset = h_len * patch_size + h
w_offset = w_len * patch_size + w
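
Note: a "post_input" patch is invoked once per forward pass, right after img/txt embedding; it receives a dict with "img", "txt", "img_ids" and "txt_ids" and must return the (possibly modified) dict. A minimal sketch mirroring the loop above (the class name is illustrative):

# hedged sketch, not part of the diff: prepend extra tokens to the text stream
class PrependTokensPatch:
    def __init__(self, extra_txt):
        self.extra_txt = extra_txt                     # [batch, n_extra, txt_dim]

    def __call__(self, kwargs):
        txt, txt_ids = kwargs["txt"], kwargs["txt_ids"]
        kwargs["txt"] = torch.cat([self.extra_txt.to(txt.dtype), txt], dim=1)
        zeros = torch.zeros(txt_ids.shape[0], self.extra_txt.shape[1], 3,
                            dtype=txt_ids.dtype, device=txt_ids.device)
        kwargs["txt_ids"] = torch.cat([zeros, txt_ids], dim=1)
        return kwargs
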
3 changes: 3 additions & 0 deletions comfy/model_patcher.py
@@ -433,6 +433,9 @@ def set_model_forward_timestep_embed_patch(self, patch):
def set_model_double_block_patch(self, patch):
self.set_model_patch(patch, "double_block")

def set_model_post_input_patch(self, patch):
self.set_model_patch(patch, "post_input")

def add_object_patch(self, name, obj):
self.object_patches[name] = obj

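
Note: the patch object is attached through the new setter on a cloned model; a usage sketch (PrependTokensPatch is the hypothetical patch sketched earlier):

# hedged sketch, not part of the diff
patched = model.clone()
patched.set_model_post_input_patch(PrependTokensPatch(extra_txt))
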
4 changes: 3 additions & 1 deletion comfy_extras/nodes_flux.py
@@ -105,7 +105,7 @@ class FluxKontextMultiReferenceLatentMethod:
def INPUT_TYPES(s):
return {"required": {
"conditioning": ("CONDITIONING", ),
"reference_latents_method": (("offset", "index", "uso"), ),
"reference_latents_method": (("offset", "index", "uxo/uno"), ),
}}

RETURN_TYPES = ("CONDITIONING",)
@@ -115,6 +115,8 @@ def INPUT_TYPES(s):
CATEGORY = "advanced/conditioning/flux"

def append(self, conditioning, reference_latents_method):
if "uxo" in reference_latents_method or "uso" in reference_latents_method:
reference_latents_method = "uxo"
c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method})
return (c, )

186 changes: 183 additions & 3 deletions comfy_extras/nodes_model_patch.py
@@ -1,4 +1,5 @@
import torch
from torch import nn
import folder_paths
import comfy.utils
import comfy.ops
@@ -58,6 +59,136 @@ def control_block(self, img, controlnet_conditioning, block_id):
return self.controlnet_blocks[block_id](img, controlnet_conditioning)


class SigLIPMultiFeatProjModel(torch.nn.Module):
"""
SigLIP Multi-Feature Projection Model for processing style features from different layers
and projecting them into a unified hidden space.

Args:
siglip_token_nums (int): Number of SigLIP tokens, default 729
style_token_nums (int): Number of style tokens, default 64
siglip_token_dims (int): Dimension of SigLIP tokens, default 1152
hidden_size (int): Hidden layer size, default 3072
context_layer_norm (bool): Whether to use context layer normalization, default True
"""

def __init__(
self,
siglip_token_nums: int = 729,
style_token_nums: int = 64,
siglip_token_dims: int = 1152,
hidden_size: int = 3072,
context_layer_norm: bool = True,
device=None, dtype=None, operations=None
):
super().__init__()

# High-level feature processing (layer -2)
self.high_embedding_linear = nn.Sequential(
operations.Linear(siglip_token_nums, style_token_nums),
nn.SiLU()
)
self.high_layer_norm = (
operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
)
self.high_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True)

# Mid-level feature processing (layer -11)
self.mid_embedding_linear = nn.Sequential(
operations.Linear(siglip_token_nums, style_token_nums),
nn.SiLU()
)
self.mid_layer_norm = (
operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
)
self.mid_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True)

# Low-level feature processing (layer -20)
self.low_embedding_linear = nn.Sequential(
operations.Linear(siglip_token_nums, style_token_nums),
nn.SiLU()
)
self.low_layer_norm = (
operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
)
self.low_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True)

def forward(self, siglip_outputs):
"""
Forward pass function

Args:
siglip_outputs: Stacked SigLIP hidden states, indexed as [0] = low (layer -20), [1] = mid (layer -11), [2] = high (layer -2)

Returns:
torch.Tensor: Concatenated multi-layer features with shape [bs, 3*style_token_nums, hidden_size]
"""
dtype = next(self.high_embedding_linear.parameters()).dtype

# Process high-level features (layer -2)
high_embedding = self._process_layer_features(
siglip_outputs[2],
self.high_embedding_linear,
self.high_layer_norm,
self.high_projection,
dtype
)

# Process mid-level features (layer -11)
mid_embedding = self._process_layer_features(
siglip_outputs[1],
self.mid_embedding_linear,
self.mid_layer_norm,
self.mid_projection,
dtype
)

# Process low-level features (layer -20)
low_embedding = self._process_layer_features(
siglip_outputs[0],
self.low_embedding_linear,
self.low_layer_norm,
self.low_projection,
dtype
)

# Concatenate features from all layers
return torch.cat((high_embedding, mid_embedding, low_embedding), dim=1)

def _process_layer_features(
self,
hidden_states: torch.Tensor,
embedding_linear: nn.Module,
layer_norm: nn.Module,
projection: nn.Module,
dtype: torch.dtype
) -> torch.Tensor:
"""
Helper function to process features from a single layer

Args:
hidden_states: Input hidden states [bs, seq_len, dim]
embedding_linear: Embedding linear layer
layer_norm: Layer normalization
projection: Projection layer
dtype: Target data type

Returns:
torch.Tensor: Processed features [bs, style_token_nums, hidden_size]
"""
# Transform dimensions: [bs, seq_len, dim] -> [bs, dim, seq_len] -> [bs, dim, style_token_nums] -> [bs, style_token_nums, dim]
embedding = embedding_linear(
hidden_states.to(dtype).transpose(1, 2)
).transpose(1, 2)

# Apply layer normalization
embedding = layer_norm(embedding)

# Project to target hidden space
embedding = projection(embedding)

return embedding
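
Note: each level first maps the token axis 729 -> 64 with the embedding linear (applied on the transposed tensor), then normalizes and projects 1152 -> 3072, and the three levels are concatenated. A shape sketch under the default arguments, using torch.nn as the ops namespace purely for illustration:

# hedged sketch, not part of the diff
proj = SigLIPMultiFeatProjModel(operations=nn)
feats = torch.zeros(3, 1, 729, 1152)       # [level, bs, tokens, dim]; level 0/1/2 = low/mid/high
print(proj(feats).shape)                   # torch.Size([1, 192, 3072]) == [bs, 3*style_token_nums, hidden_size]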

class ModelPatchLoader:
@classmethod
def INPUT_TYPES(s):
@@ -73,9 +204,14 @@ def load_model_patch(self, name):
model_patch_path = folder_paths.get_full_path_or_raise("model_patches", name)
sd = comfy.utils.load_torch_file(model_patch_path, safe_load=True)
dtype = comfy.utils.weight_dtype(sd)
# TODO: this node will work with more types of model patches
additional_in_dim = sd["img_in.weight"].shape[1] - 64
model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)

if 'controlnet_blocks.0.y_rms.weight' in sd:
additional_in_dim = sd["img_in.weight"].shape[1] - 64
model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
elif 'feature_embedder.mid_layer_norm.bias' in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True)
model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)

model.load_state_dict(sd)
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
return (model,)
@@ -157,7 +293,51 @@ def diffsynth_controlnet(self, model, model_patch, vae, image, strength, mask=No
return (model_patched,)


class UsoStyleProjectorPatch:
def __init__(self, model_patch, encoded_image):
self.model_patch = model_patch
self.encoded_image = encoded_image

def __call__(self, kwargs):
txt_ids = kwargs.get("txt_ids")
txt = kwargs.get("txt")
siglip_embedding = self.model_patch.model(self.encoded_image.to(txt.dtype)).to(txt.dtype)
txt = torch.cat([siglip_embedding, txt], dim=1)
kwargs['txt'] = txt
kwargs['txt_ids'] = torch.cat([torch.zeros(siglip_embedding.shape[0], siglip_embedding.shape[1], 3, dtype=txt_ids.dtype, device=txt_ids.device), txt_ids], dim=1)
return kwargs

def to(self, device_or_dtype):
if isinstance(device_or_dtype, torch.device):
self.encoded_image = self.encoded_image.to(device_or_dtype)
return self

def models(self):
return [self.model_patch]


class USOStyleReference:
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL",),
"model_patch": ("MODEL_PATCH",),
"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "apply_patch"
EXPERIMENTAL = True

CATEGORY = "advanced/model_patches/flux"

def apply_patch(self, model, model_patch, clip_vision_output):
encoded_image = torch.stack((clip_vision_output.all_hidden_states[:, -20], clip_vision_output.all_hidden_states[:, -11], clip_vision_output.penultimate_hidden_states))
model_patched = model.clone()
model_patched.set_model_post_input_patch(UsoStyleProjectorPatch(model_patch, encoded_image))
return (model_patched,)
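
Note: the stacked encoded_image lines the clip-vision layers up with the projector's three paths. A sketch of the layout (hs indices refer to the layer axis of "all_hidden_states"):

# hedged sketch, not part of the diff
hs = clip_vision_output.all_hidden_states                   # [bs, num_layers, tokens, dim]
encoded = torch.stack((hs[:, -20], hs[:, -11], hs[:, -2]))  # [3, bs, tokens, dim]
# index 0/1/2 feed the low/mid/high paths of SigLIPMultiFeatProjModel.forward;
# at sampling time UsoStyleProjectorPatch prepends the projected [bs, 192, 3072]
# tokens to txt, with matching zero txt_ids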


NODE_CLASS_MAPPINGS = {
"ModelPatchLoader": ModelPatchLoader,
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
"USOStyleReference": USOStyleReference,
}