Commit 4977f20

P2 of qwen edit model. (Comfy-Org#9412)
* P2 of qwen edit model.
* Typo.
* Fix normal qwen.
* Fix.
* Make the TextEncodeQwenImageEdit node also set the ref latent. If you don't want it to set the ref latent and want to use the ReferenceLatent node with your custom latent instead, just disconnect the VAE.
1 parent bd2ab73 commit 4977f20
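The last point in the commit message concerns the TextEncodeQwenImageEdit node, which lives in a file outside the hunks shown below. As a loose illustration of the described behavior only, not the actual node code, a conditional ref-latent step could look roughly like the sketch here; it assumes clip.tokenize forwards the images keyword to the tokenizer and that node_helpers.conditioning_set_values accepts append=True the way the ReferenceLatent node uses it.

import node_helpers  # ComfyUI helper module

# Illustrative sketch only, not the node implementation from this commit.
def encode_qwen_image_edit(clip, vae, image, prompt):
    # Tokenize with the image so the vision tokens end up inside the prompt
    # (assumes clip.tokenize forwards the images kwarg to the tokenizer below).
    tokens = clip.tokenize(prompt, images=[image])
    conditioning = clip.encode_from_tokens_scheduled(tokens)

    # With a VAE connected, the encoded image is also attached as a reference latent.
    # Disconnecting the VAE skips this step so a ReferenceLatent node can supply
    # a custom latent instead.
    if vae is not None:
        ref_latent = vae.encode(image[:, :, :, :3])
        conditioning = node_helpers.conditioning_set_values(
            conditioning, {"reference_latents": [ref_latent]}, append=True)
    return conditioning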

File tree

10 files changed, +565 -15 lines changed

comfy/clip_model.py

Lines changed: 1 addition & 1 deletion
@@ -97,7 +97,7 @@ def __init__(self, config_dict, dtype, device, operations):
         self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
         self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
 
-    def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
+    def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32, embeds_info=[]):
         if embeds is not None:
             x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device)
         else:

comfy/model_base.py

Lines changed: 8 additions & 0 deletions
@@ -1325,6 +1325,7 @@ def extra_conds_shapes(self, **kwargs):
 class QwenImage(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.qwen_image.model.QwenImageTransformer2DModel)
+        self.memory_usage_factor_conds = ("ref_latents",)
 
     def extra_conds(self, **kwargs):
         out = super().extra_conds(**kwargs)
@@ -1342,3 +1343,10 @@ def extra_conds(self, **kwargs):
         if ref_latents_method is not None:
             out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
         return out
+
+    def extra_conds_shapes(self, **kwargs):
+        out = {}
+        ref_latents = kwargs.get("reference_latents", None)
+        if ref_latents is not None:
+            out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
+        return out
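The new extra_conds_shapes hook only reports a shape for memory estimation: all reference latents are collapsed into a single [1, 16, N] entry, where N is their combined element count divided by the 16 latent channels. A quick standalone check of that arithmetic, with made-up latent sizes:

import math
import torch

# Two hypothetical reference latents in a 16-channel latent space.
ref_latents = [torch.zeros(1, 16, 64, 64), torch.zeros(1, 16, 32, 32)]

# Same reduction as extra_conds_shapes above: total elements // 16 channels.
n = sum(math.prod(a.size()) for a in ref_latents) // 16
print([1, 16, n])  # [1, 16, 5120], i.e. (65536 + 16384) // 16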

comfy/sd1_clip.py

Lines changed: 7 additions & 4 deletions
@@ -204,17 +204,19 @@ def process_tokens(self, tokens, device):
             tokens_embed = self.transformer.get_input_embeddings()(tokens_embed, out_dtype=torch.float32)
             index = 0
             pad_extra = 0
+            embeds_info = []
             for o in other_embeds:
                 emb = o[1]
                 if torch.is_tensor(emb):
                     emb = {"type": "embedding", "data": emb}
 
+                extra = None
                 emb_type = emb.get("type", None)
                 if emb_type == "embedding":
                     emb = emb.get("data", None)
                 else:
                     if hasattr(self.transformer, "preprocess_embed"):
-                        emb = self.transformer.preprocess_embed(emb, device=device)
+                        emb, extra = self.transformer.preprocess_embed(emb, device=device)
                     else:
                         emb = None
 
@@ -229,6 +231,7 @@ def process_tokens(self, tokens, device):
                     tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1)
                     attention_mask = attention_mask[:ind] + [1] * emb_shape + attention_mask[ind:]
                     index += emb_shape - 1
+                    embeds_info.append({"type": emb_type, "index": ind, "size": emb_shape, "extra": extra})
                 else:
                     index += -1
                     pad_extra += emb_shape
@@ -243,11 +246,11 @@ def process_tokens(self, tokens, device):
             attention_masks.append(attention_mask)
             num_tokens.append(sum(attention_mask))
 
-        return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens
+        return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info
 
     def forward(self, tokens):
        device = self.transformer.get_input_embeddings().weight.device
-        embeds, attention_mask, num_tokens = self.process_tokens(tokens, device)
+        embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device)
 
         attention_mask_model = None
         if self.enable_attention_masks:
@@ -258,7 +261,7 @@ def forward(self, tokens):
         else:
             intermediate_output = self.layer_idx
 
-        outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
+        outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32, embeds_info=embeds_info)
 
         if self.layer == "last":
             z = outputs[0].float()
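The sd1_clip changes are pure bookkeeping: when process_tokens splices a non-text embedding (such as an image) into the token embeddings, it now records where the splice landed and hands that list to the transformer as embeds_info. A toy standalone version of the splice-and-record step, with made-up shapes:

import torch

tokens_embed = torch.zeros(1, 10, 8)   # 10 text-token embeddings, hidden size 8
emb = torch.ones(1, 4, 8)              # 4 embeddings returned by preprocess_embed
ind, emb_shape = 3, emb.shape[1]       # splice position and length

tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1)
# "extra" would carry the vision grid for image embeds; None in this toy case.
embeds_info = [{"type": "image", "index": ind, "size": emb_shape, "extra": None}]
print(tokens_embed.shape, embeds_info)  # torch.Size([1, 14, 8]) ...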

comfy/text_encoders/bert.py

Lines changed: 1 addition & 1 deletion
@@ -116,7 +116,7 @@ def __init__(self, config_dict, dtype, device, operations):
         self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations)
         self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations)
 
-    def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
+    def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
         x = self.embeddings(input_tokens, embeds=embeds, dtype=dtype)
         mask = None
         if attention_mask is not None:

comfy/text_encoders/llama.py

Lines changed: 38 additions & 5 deletions
@@ -2,12 +2,14 @@
 import torch.nn as nn
 from dataclasses import dataclass
 from typing import Optional, Any
+import math
 
 from comfy.ldm.modules.attention import optimized_attention_for_device
 import comfy.model_management
 import comfy.ldm.common_dit
 
 import comfy.model_management
+from . import qwen_vl
 
 @dataclass
 class Llama2Config:
@@ -100,12 +102,10 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-def precompute_freqs_cis(head_dim, seq_len, theta, device=None):
+def precompute_freqs_cis(head_dim, position_ids, theta, device=None):
     theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
     inv_freq = 1.0 / (theta ** (theta_numerator / head_dim))
 
-    position_ids = torch.arange(0, seq_len, device=device).unsqueeze(0)
-
     inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
     position_ids_expanded = position_ids[:, None, :].float()
     freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
@@ -277,7 +277,7 @@ def __init__(self, config, device=None, dtype=None, ops=None):
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
         # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
 
-    def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
+    def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]):
         if embeds is not None:
             x = embeds
         else:
@@ -286,8 +286,11 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed
         if self.normalize_in:
             x *= self.config.hidden_size ** 0.5
 
+        if position_ids is None:
+            position_ids = torch.arange(0, x.shape[1], device=x.device).unsqueeze(0)
+
         freqs_cis = precompute_freqs_cis(self.config.head_dim,
-                                         x.shape[1],
+                                         position_ids,
                                          self.config.rope_theta,
                                          device=x.device)
 
@@ -372,8 +375,38 @@ def __init__(self, config_dict, dtype, device, operations):
         self.num_layers = config.num_hidden_layers
 
         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+        self.visual = qwen_vl.Qwen2VLVisionTransformer(hidden_size=1280, output_hidden_size=config.hidden_size, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype
 
+    def preprocess_embed(self, embed, device):
+        if embed["type"] == "image":
+            image, grid = qwen_vl.process_qwen2vl_images(embed["data"])
+            return self.visual(image.to(device, dtype=torch.float32), grid), grid
+        return None, None
+
+    def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
+        grid = None
+        for e in embeds_info:
+            if e.get("type") == "image":
+                grid = e.get("extra", None)
+                position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
+                start = e.get("index")
+                position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
+                end = e.get("size") + start
+                len_max = int(grid.max()) // 2
+                start_next = len_max + start
+                position_ids[:, end:] = torch.arange(start_next, start_next + (embeds.shape[1] - end), device=embeds.device)
+                position_ids[0, start:end] = start
+                max_d = int(grid[0][1]) // 2
+                position_ids[1, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
+                max_d = int(grid[0][2]) // 2
+                position_ids[2, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
+
+        if grid is None:
+            position_ids = None
+
+        return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids)
+
 class Gemma2_2B(BaseLlama, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
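The new forward() above builds three rows of position ids around the spliced image tokens: row 0 holds the image's temporal position constant, rows 1 and 2 walk the height and width of the (2x2-merged) vision grid, and the text after the image resumes counting from the widest image axis. The snippet below reruns the same arithmetic on a made-up sequence (4 text tokens, a 1x4x6 patch grid, 3 trailing text tokens) so the resulting layout can be inspected outside the model; it is a sketch, not part of the diff.

import math
import torch

grid = torch.tensor([[1, 4, 6]])   # (t, h, w) patch grid, as carried in embeds_info "extra"
start, size, seq_len = 4, 6, 13    # image tokens occupy positions 4..9 of a 13-token sequence
end = start + size

position_ids = torch.zeros((3, seq_len))
position_ids[:, :start] = torch.arange(0, start)          # text before the image: 0..3
len_max = int(grid.max()) // 2                             # widest image axis after 2x2 merge
start_next = len_max + start
position_ids[:, end:] = torch.arange(start_next, start_next + (seq_len - end))  # text after: 7..9

position_ids[0, start:end] = start                         # temporal axis stays at 4
max_d = int(grid[0][1]) // 2                               # height axis: 4 4 4 5 5 5
position_ids[1, start:end] = torch.arange(start, start + max_d).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
max_d = int(grid[0][2]) // 2                               # width axis: 4 5 6 4 5 6
position_ids[2, start:end] = torch.arange(start, start + max_d).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
print(position_ids.long())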

comfy/text_encoders/qwen_image.py

Lines changed: 17 additions & 3 deletions
@@ -15,13 +15,27 @@ class QwenImageTokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_7b", tokenizer=Qwen25_7BVLITokenizer)
         self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+        self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image \\(color, shape, size, texture, objects, background\\), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
 
-    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,**kwargs):
+    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], **kwargs):
         if llama_template is None:
-            llama_text = self.llama_template.format(text)
+            if len(images) > 0:
+                llama_text = self.llama_template_images.format(text)
+            else:
+                llama_text = self.llama_template.format(text)
         else:
             llama_text = llama_template.format(text)
-        return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs)
+        tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs)
+        key_name = next(iter(tokens))
+        embed_count = 0
+        qwen_tokens = tokens[key_name]
+        for r in qwen_tokens:
+            for i in range(len(r)):
+                if r[i][0] == 151655:
+                    if len(images) > embed_count:
+                        r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:]
+                        embed_count += 1
+        return tokens
 
 
 class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
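With the tokenizer change above, passing images selects the vision template and swaps every <|image_pad|> token (id 151655) for an image embed dict that preprocess_embed later turns into vision-transformer embeddings. A rough usage sketch; the tensor shape follows ComfyUI's batched HWC image convention, and constructing the tokenizer directly like this is for illustration only:

import torch

tokenizer = QwenImageTokenizer()
image = torch.rand(1, 512, 512, 3)   # one RGB image, values in 0..1
tokens = tokenizer.tokenize_with_weights("make the sky purple", images=[image])
# The token list under the tokenizer's key now contains
# ({"type": "image", "data": ..., "original_type": "image"}, ...) entries
# in place of the 151655 image-pad tokens.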
