Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion comfy/clip_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __init__(self, config_dict, dtype, device, operations):
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)

def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32, embeds_info=[]):
if embeds is not None:
x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device)
else:
Expand Down
8 changes: 8 additions & 0 deletions comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1325,6 +1325,7 @@ def extra_conds_shapes(self, **kwargs):
class QwenImage(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.qwen_image.model.QwenImageTransformer2DModel)
self.memory_usage_factor_conds = ("ref_latents",)

def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
Expand All @@ -1342,3 +1343,10 @@ def extra_conds(self, **kwargs):
if ref_latents_method is not None:
out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
return out

def extra_conds_shapes(self, **kwargs):
out = {}
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
return out
11 changes: 7 additions & 4 deletions comfy/sd1_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,17 +204,19 @@ def process_tokens(self, tokens, device):
tokens_embed = self.transformer.get_input_embeddings()(tokens_embed, out_dtype=torch.float32)
index = 0
pad_extra = 0
embeds_info = []
for o in other_embeds:
emb = o[1]
if torch.is_tensor(emb):
emb = {"type": "embedding", "data": emb}

extra = None
emb_type = emb.get("type", None)
if emb_type == "embedding":
emb = emb.get("data", None)
else:
if hasattr(self.transformer, "preprocess_embed"):
emb = self.transformer.preprocess_embed(emb, device=device)
emb, extra = self.transformer.preprocess_embed(emb, device=device)
else:
emb = None

Expand All @@ -229,6 +231,7 @@ def process_tokens(self, tokens, device):
tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1)
attention_mask = attention_mask[:ind] + [1] * emb_shape + attention_mask[ind:]
index += emb_shape - 1
embeds_info.append({"type": emb_type, "index": ind, "size": emb_shape, "extra": extra})
else:
index += -1
pad_extra += emb_shape
Expand All @@ -243,11 +246,11 @@ def process_tokens(self, tokens, device):
attention_masks.append(attention_mask)
num_tokens.append(sum(attention_mask))

return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens
return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info

def forward(self, tokens):
device = self.transformer.get_input_embeddings().weight.device
embeds, attention_mask, num_tokens = self.process_tokens(tokens, device)
embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device)

attention_mask_model = None
if self.enable_attention_masks:
Expand All @@ -258,7 +261,7 @@ def forward(self, tokens):
else:
intermediate_output = self.layer_idx

outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32, embeds_info=embeds_info)

if self.layer == "last":
z = outputs[0].float()
Expand Down
2 changes: 1 addition & 1 deletion comfy/text_encoders/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def __init__(self, config_dict, dtype, device, operations):
self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations)
self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations)

def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
x = self.embeddings(input_tokens, embeds=embeds, dtype=dtype)
mask = None
if attention_mask is not None:
Expand Down
43 changes: 38 additions & 5 deletions comfy/text_encoders/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
import torch.nn as nn
from dataclasses import dataclass
from typing import Optional, Any
import math

from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.model_management
import comfy.ldm.common_dit

import comfy.model_management
from . import qwen_vl

@dataclass
class Llama2Config:
Expand Down Expand Up @@ -100,12 +102,10 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


def precompute_freqs_cis(head_dim, seq_len, theta, device=None):
def precompute_freqs_cis(head_dim, position_ids, theta, device=None):
theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
inv_freq = 1.0 / (theta ** (theta_numerator / head_dim))

position_ids = torch.arange(0, seq_len, device=device).unsqueeze(0)

inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
Expand Down Expand Up @@ -277,7 +277,7 @@ def __init__(self, config, device=None, dtype=None, ops=None):
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)

def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]):
if embeds is not None:
x = embeds
else:
Expand All @@ -286,8 +286,11 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed
if self.normalize_in:
x *= self.config.hidden_size ** 0.5

if position_ids is None:
position_ids = torch.arange(0, x.shape[1], device=x.device).unsqueeze(0)

freqs_cis = precompute_freqs_cis(self.config.head_dim,
x.shape[1],
position_ids,
self.config.rope_theta,
device=x.device)

Expand Down Expand Up @@ -372,8 +375,38 @@ def __init__(self, config_dict, dtype, device, operations):
self.num_layers = config.num_hidden_layers

self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.visual = qwen_vl.Qwen2VLVisionTransformer(hidden_size=1280, output_hidden_size=config.hidden_size, device=device, dtype=dtype, ops=operations)
self.dtype = dtype

def preprocess_embed(self, embed, device):
if embed["type"] == "image":
image, grid = qwen_vl.process_qwen2vl_images(embed["data"])
return self.visual(image.to(device, dtype=torch.float32), grid), grid
return None, None

def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
grid = None
for e in embeds_info:
if e.get("type") == "image":
grid = e.get("extra", None)
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
start = e.get("index")
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
end = e.get("size") + start
len_max = int(grid.max()) // 2
start_next = len_max + start
position_ids[:, end:] = torch.arange(start_next, start_next + (embeds.shape[1] - end), device=embeds.device)
position_ids[0, start:end] = start
max_d = int(grid[0][1]) // 2
position_ids[1, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
max_d = int(grid[0][2]) // 2
position_ids[2, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]

if grid is None:
position_ids = None

return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids)

class Gemma2_2B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
Expand Down
20 changes: 17 additions & 3 deletions comfy/text_encoders/qwen_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,27 @@ class QwenImageTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_7b", tokenizer=Qwen25_7BVLITokenizer)
self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image \\(color, shape, size, texture, objects, background\\), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"

def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,**kwargs):
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], **kwargs):
if llama_template is None:
llama_text = self.llama_template.format(text)
if len(images) > 0:
llama_text = self.llama_template_images.format(text)
else:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs)
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs)
key_name = next(iter(tokens))
embed_count = 0
qwen_tokens = tokens[key_name]
for r in qwen_tokens:
for i in range(len(r)):
if r[i][0] == 151655:
if len(images) > embed_count:
r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:]
embed_count += 1
return tokens


class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
Expand Down
Loading
Loading