@@ -347,7 +347,7 @@ def process_img(self, x, index=0, h_offset=0, w_offset=0):
347 347        h_offset = ((h_offset + (patch_size // 2)) // patch_size)
348 348        w_offset = ((w_offset + (patch_size // 2)) // patch_size)
349 349
350    -       img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
    350+       img_ids = torch.zeros((h_len, w_len, 3), device=x.device)
351 351        img_ids[:, :, 0] = img_ids[:, :, 1] + index
352 352        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2)
353 353        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
@@ -397,9 +397,10 @@ def forward(
397 397        img_ids = torch.cat([img_ids, kontext_ids], dim=1)
398 398
399 399        txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
400    -       txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
    400+       txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
401 401        ids = torch.cat((txt_ids, img_ids), dim=1)
402 402        image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
    403+       del ids, txt_ids, img_ids
403 404
404 405        hidden_states = self.img_in(hidden_states)
405 406        encoder_hidden_states = self.txt_norm(encoder_hidden_states)
0 commit comments