@@ -356,6 +356,7 @@ def forward(
356356 context ,
357357 attention_mask = None ,
358358 guidance : torch .Tensor = None ,
359+ transformer_options = {},
359360 ** kwargs
360361 ):
361362 timestep = timesteps
@@ -383,14 +384,26 @@ def forward(
383384 else self .time_text_embed (timestep , guidance , hidden_states )
384385 )
385386
386- for block in self .transformer_blocks :
387- encoder_hidden_states , hidden_states = block (
388- hidden_states = hidden_states ,
389- encoder_hidden_states = encoder_hidden_states ,
390- encoder_hidden_states_mask = encoder_hidden_states_mask ,
391- temb = temb ,
392- image_rotary_emb = image_rotary_emb ,
393- )
387+ patches_replace = transformer_options .get ("patches_replace" , {})
388+ blocks_replace = patches_replace .get ("dit" , {})
389+
390+ for i , block in enumerate (self .transformer_blocks ):
391+ if ("double_block" , i ) in blocks_replace :
392+ def block_wrap (args ):
393+ out = {}
394+ out ["txt" ], out ["img" ] = block (hidden_states = args ["img" ], encoder_hidden_states = args ["txt" ], encoder_hidden_states_mask = encoder_hidden_states_mask , temb = args ["vec" ], image_rotary_emb = args ["pe" ])
395+ return out
396+ out = blocks_replace [("double_block" , i )]({"img" : hidden_states , "txt" : encoder_hidden_states , "vec" : temb , "pe" : image_rotary_emb }, {"original_block" : block_wrap })
397+ hidden_states = out ["img" ]
398+ encoder_hidden_states = out ["txt" ]
399+ else :
400+ encoder_hidden_states , hidden_states = block (
401+ hidden_states = hidden_states ,
402+ encoder_hidden_states = encoder_hidden_states ,
403+ encoder_hidden_states_mask = encoder_hidden_states_mask ,
404+ temb = temb ,
405+ image_rotary_emb = image_rotary_emb ,
406+ )
394407
395408 hidden_states = self .norm_out (hidden_states , temb )
396409 hidden_states = self .proj_out (hidden_states )
0 commit comments