Skip to content

Commit e05d78f

Browse files
committed
qwen: Implement transformer block prefetching
1 parent b27d3d8 commit e05d78f

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

comfy/ldm/qwen_image/model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,9 @@ def _forward(
385385
hidden_states, img_ids, orig_shape = self.process_img(x)
386386
num_embeds = hidden_states.shape[1]
387387

388+
prefetch_queue = comfy.ops.make_prefetch_queue(list(self.transformer_blocks))
389+
comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, None)
390+
388391
if ref_latents is not None:
389392
h = 0
390393
w = 0
@@ -434,6 +437,7 @@ def _forward(
434437
blocks_replace = patches_replace.get("dit", {})
435438

436439
for i, block in enumerate(self.transformer_blocks):
440+
comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, block)
437441
if ("double_block", i) in blocks_replace:
438442
def block_wrap(args):
439443
out = {}
@@ -465,6 +469,8 @@ def block_wrap(args):
465469
if add is not None:
466470
hidden_states[:, :add.shape[1]] += add
467471

472+
comfy.ops.prefetch_queue_pop(prefetch_queue, x.device, block)
473+
468474
hidden_states = self.norm_out(hidden_states, temb)
469475
hidden_states = self.proj_out(hidden_states)
470476

0 commit comments

Comments (0)