File tree Expand file tree Collapse file tree 1 file changed +6
-0
lines changed
Expand file tree Collapse file tree 1 file changed +6
-0
lines changed Original file line number Diff line number Diff line change @@ -389,7 +389,10 @@ def forward_orig(
389389 attn_mask = None
390390
391391 blocks_replace = patches_replace .get ("dit" , {})
392+ transformer_options ["total_blocks" ] = len (self .double_blocks )
393+ transformer_options ["block_type" ] = "double"
392394 for i , block in enumerate (self .double_blocks ):
395+ transformer_options ["block_index" ] = i
393396 if ("double_block" , i ) in blocks_replace :
394397 def block_wrap (args ):
395398 out = {}
@@ -411,7 +414,10 @@ def block_wrap(args):
411414
412415 img = torch .cat ((img , txt ), 1 )
413416
417+ transformer_options ["total_blocks" ] = len (self .single_blocks )
418+ transformer_options ["block_type" ] = "single"
414419 for i , block in enumerate (self .single_blocks ):
420+ transformer_options ["block_index" ] = i
415421 if ("single_block" , i ) in blocks_replace :
416422 def block_wrap (args ):
417423 out = {}
You can’t perform that action at this time.
0 commit comments