Skip to content

Commit ed43784

Browse files
WIP Qwen edit model: The diffusion model part. (Comfy-Org#9383)
1 parent 0f2b852 commit ed43784

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

comfy/ldm/qwen_image/model.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,7 @@ def forward(
360360
context,
361361
attention_mask=None,
362362
guidance: torch.Tensor = None,
363+
ref_latents=None,
363364
transformer_options={},
364365
**kwargs
365366
):
@@ -370,6 +371,31 @@ def forward(
370371
hidden_states, img_ids, orig_shape = self.process_img(x)
371372
num_embeds = hidden_states.shape[1]
372373

374+
if ref_latents is not None:
375+
h = 0
376+
w = 0
377+
index = 0
378+
index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
379+
for ref in ref_latents:
380+
if index_ref_method:
381+
index += 1
382+
h_offset = 0
383+
w_offset = 0
384+
else:
385+
index = 1
386+
h_offset = 0
387+
w_offset = 0
388+
if ref.shape[-2] + h > ref.shape[-1] + w:
389+
w_offset = w
390+
else:
391+
h_offset = h
392+
h = max(h, ref.shape[-2] + h_offset)
393+
w = max(w, ref.shape[-1] + w_offset)
394+
395+
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
396+
hidden_states = torch.cat([hidden_states, kontext], dim=1)
397+
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
398+
373399
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size), ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size)))
374400
txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
375401
ids = torch.cat((txt_ids, img_ids), dim=1)

comfy/model_base.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1331,4 +1331,14 @@ def extra_conds(self, **kwargs):
13311331
cross_attn = kwargs.get("cross_attn", None)
13321332
if cross_attn is not None:
13331333
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
1334+
ref_latents = kwargs.get("reference_latents", None)
1335+
if ref_latents is not None:
1336+
latents = []
1337+
for lat in ref_latents:
1338+
latents.append(self.process_latent_in(lat))
1339+
out['ref_latents'] = comfy.conds.CONDList(latents)
1340+
1341+
ref_latents_method = kwargs.get("reference_latents_method", None)
1342+
if ref_latents_method is not None:
1343+
out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
13341344
return out

0 commit comments

Comments
 (0)