File tree Expand file tree Collapse file tree 1 file changed +6
-0
lines changed
Expand file tree Collapse file tree 1 file changed +6
-0
lines changed Original file line number Diff line number Diff line change @@ -179,7 +179,10 @@ def forward_orig(
179179 pe = self .pe_embedder (ids )
180180
181181 blocks_replace = patches_replace .get ("dit" , {})
182+ transformer_options ["total_blocks" ] = len (self .double_blocks )
183+ transformer_options ["block_type" ] = "double"
182184 for i , block in enumerate (self .double_blocks ):
185+ transformer_options ["block_index" ] = i
183186 if i not in self .skip_mmdit :
184187 double_mod = (
185188 self .get_modulations (mod_vectors , "double_img" , idx = i ),
@@ -222,7 +225,10 @@ def block_wrap(args):
222225
223226 img = torch .cat ((txt , img ), 1 )
224227
228+ transformer_options ["total_blocks" ] = len (self .single_blocks )
229+ transformer_options ["block_type" ] = "single"
225230 for i , block in enumerate (self .single_blocks ):
231+ transformer_options ["block_index" ] = i
226232 if i not in self .skip_dit :
227233 single_mod = self .get_modulations (mod_vectors , "single" , idx = i )
228234 if ("single_block" , i ) in blocks_replace :
You can’t perform that action at this time.
0 commit comments