@@ -360,6 +360,7 @@ def forward(
360360 context ,
361361 attention_mask = None ,
362362 guidance : torch .Tensor = None ,
363+ ref_latents = None ,
363364 transformer_options = {},
364365 ** kwargs
365366 ):
@@ -370,6 +371,31 @@ def forward(
370371 hidden_states , img_ids , orig_shape = self .process_img (x )
371372 num_embeds = hidden_states .shape [1 ]
372373
374+ if ref_latents is not None :
375+ h = 0
376+ w = 0
377+ index = 0
378+ index_ref_method = kwargs .get ("ref_latents_method" , "index" ) == "index"
379+ for ref in ref_latents :
380+ if index_ref_method :
381+ index += 1
382+ h_offset = 0
383+ w_offset = 0
384+ else :
385+ index = 1
386+ h_offset = 0
387+ w_offset = 0
388+ if ref .shape [- 2 ] + h > ref .shape [- 1 ] + w :
389+ w_offset = w
390+ else :
391+ h_offset = h
392+ h = max (h , ref .shape [- 2 ] + h_offset )
393+ w = max (w , ref .shape [- 1 ] + w_offset )
394+
395+ kontext , kontext_ids , _ = self .process_img (ref , index = index , h_offset = h_offset , w_offset = w_offset )
396+ hidden_states = torch .cat ([hidden_states , kontext ], dim = 1 )
397+ img_ids = torch .cat ([img_ids , kontext_ids ], dim = 1 )
398+
373399 txt_start = round (max (((x .shape [- 1 ] + (self .patch_size // 2 )) // self .patch_size ), ((x .shape [- 2 ] + (self .patch_size // 2 )) // self .patch_size )))
374400 txt_ids = torch .linspace (txt_start , txt_start + context .shape [1 ], steps = context .shape [1 ], device = x .device , dtype = x .dtype ).reshape (1 , - 1 , 1 ).repeat (x .shape [0 ], 1 , 3 )
375401 ids = torch .cat ((txt_ids , img_ids ), dim = 1 )
0 commit comments