@@ -195,20 +195,50 @@ def block_wrap(args):
195195 img = self .final_layer (img , vec ) # (N, T, patch_size ** 2 * out_channels)
196196 return img
197197
198- def forward (self , x , timestep , context , y = None , guidance = None , control = None , transformer_options = {}, ** kwargs ):
198+ def process_img (self , x , index = 0 , h_offset = 0 , w_offset = 0 ):
199199 bs , c , h , w = x .shape
200200 patch_size = self .patch_size
201201 x = comfy .ldm .common_dit .pad_to_patch_size (x , (patch_size , patch_size ))
202202
203203 img = rearrange (x , "b c (h ph) (w pw) -> b (h w) (c ph pw)" , ph = patch_size , pw = patch_size )
204-
205204 h_len = ((h + (patch_size // 2 )) // patch_size )
206205 w_len = ((w + (patch_size // 2 )) // patch_size )
206+
207+ h_offset = ((h_offset + (patch_size // 2 )) // patch_size )
208+ w_offset = ((w_offset + (patch_size // 2 )) // patch_size )
209+
207210 img_ids = torch .zeros ((h_len , w_len , 3 ), device = x .device , dtype = x .dtype )
208- img_ids [:, :, 1 ] = img_ids [:, :, 1 ] + torch .linspace (0 , h_len - 1 , steps = h_len , device = x .device , dtype = x .dtype ).unsqueeze (1 )
209- img_ids [:, :, 2 ] = img_ids [:, :, 2 ] + torch .linspace (0 , w_len - 1 , steps = w_len , device = x .device , dtype = x .dtype ).unsqueeze (0 )
210- img_ids = repeat (img_ids , "h w c -> b (h w) c" , b = bs )
211+ img_ids [:, :, 0 ] = img_ids [:, :, 1 ] + index
212+ img_ids [:, :, 1 ] = img_ids [:, :, 1 ] + torch .linspace (h_offset , h_len - 1 + h_offset , steps = h_len , device = x .device , dtype = x .dtype ).unsqueeze (1 )
213+ img_ids [:, :, 2 ] = img_ids [:, :, 2 ] + torch .linspace (w_offset , w_len - 1 + w_offset , steps = w_len , device = x .device , dtype = x .dtype ).unsqueeze (0 )
214+ return img , repeat (img_ids , "h w c -> b (h w) c" , b = bs )
215+
    def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
        """Run the diffusion transformer on a latent image, optionally
        conditioned on extra reference latents (Kontext-style editing).

        Args:
            x: latent image, shape (batch, channels, height, width).
            timestep: diffusion timestep conditioning.
            context: text/conditioning tokens; shape (batch, tokens, dim)
                (only ``context.shape[1]`` is read here).
            y: pooled conditioning vector, passed through to forward_orig.
            guidance: guidance embedding value, passed through.
            ref_latents: optional iterable of reference latent tensors that
                are packed into the token sequence alongside the image.
            control, transformer_options, **kwargs: passed through;
                ``kwargs["attention_mask"]`` becomes the attention mask.

        Returns:
            Denoised latent with the same spatial size as the input
            (reference tokens are stripped before unpatchifying).
        """
        bs, c, h_orig, w_orig = x.shape
        patch_size = self.patch_size

        # Patch-grid size of the *main* image only — used to unpatchify the
        # output after the reference tokens have been dropped.
        h_len = ((h_orig + (patch_size // 2)) // patch_size)
        w_len = ((w_orig + (patch_size // 2)) // patch_size)
        img, img_ids = self.process_img(x)
        img_tokens = img.shape[1]
        if ref_latents is not None:
            # (h, w) track the bounding box of everything placed so far on
            # the shared position-id grid; each reference is offset either to
            # the right of or below that box so its ids never overlap.
            h = 0
            w = 0
            for ref in ref_latents:
                h_offset = 0
                w_offset = 0
                # Offset along whichever axis keeps the combined layout more
                # square: if stacking below would be taller than stacking to
                # the right would be wide, place it to the right, else below.
                if ref.shape[-2] + h > ref.shape[-1] + w:
                    w_offset = w
                else:
                    h_offset = h

                # index=1 tags reference tokens in position-id channel 0 so
                # they are distinguishable from the main image (index 0).
                kontext, kontext_ids = self.process_img(ref, index=1, h_offset=h_offset, w_offset=w_offset)
                img = torch.cat([img, kontext], dim=1)
                img_ids = torch.cat([img_ids, kontext_ids], dim=1)
                h = max(h, ref.shape[-2] + h_offset)
                w = max(w, ref.shape[-1] + w_offset)

        # Text tokens get all-zero position ids.
        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
        out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
        # Keep only the main image's tokens; reference tokens are discarded.
        out = out[:, :img_tokens]
        # ph=pw=2 matches the model's patch size; crop removes patch padding.
        return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:, :, :h_orig, :w_orig]
0 commit comments