File tree Expand file tree Collapse file tree 1 file changed +6
-0
lines changed Expand file tree Collapse file tree 1 file changed +6
-0
lines changed Original file line number Diff line number Diff line change @@ -385,6 +385,9 @@ def _forward(
385385 hidden_states , img_ids , orig_shape = self .process_img (x )
386386 num_embeds = hidden_states .shape [1 ]
387387
388+ prefetch_queue = comfy .ops .make_prefetch_queue (list (self .transformer_blocks ))
389+ comfy .ops .prefetch_queue_pop (prefetch_queue , x .device , None )
390+
388391 if ref_latents is not None :
389392 h = 0
390393 w = 0
@@ -434,6 +437,7 @@ def _forward(
434437 blocks_replace = patches_replace .get ("dit" , {})
435438
436439 for i , block in enumerate (self .transformer_blocks ):
440+ comfy .ops .prefetch_queue_pop (prefetch_queue , x .device , block )
437441 if ("double_block" , i ) in blocks_replace :
438442 def block_wrap (args ):
439443 out = {}
@@ -465,6 +469,8 @@ def block_wrap(args):
465469 if add is not None :
466470 hidden_states [:, :add .shape [1 ]] += add
467471
472+ comfy .ops .prefetch_queue_pop (prefetch_queue , x .device , block )
473+
468474 hidden_states = self .norm_out (hidden_states , temb )
469475 hidden_states = self .proj_out (hidden_states )
470476
You can’t perform that action at this time.
0 commit comments