@@ -254,18 +254,17 @@ def block_wrap(args):
254254
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
    """Patchify the latent, run the inner transformer, and un-patchify the result.

    Args:
        x: latent image tensor; indexed here as (batch, channels, height, width).
        timestep: diffusion timestep(s), forwarded to ``forward_orig``.
        context: text conditioning tokens; ``context.shape[1]`` is the token count.
        guidance: guidance value, forwarded to ``forward_orig``.
        control: optional controlnet signals, forwarded unchanged.
        transformer_options: options dict, forwarded unchanged.
        **kwargs: may carry ``attention_mask``, passed as ``attn_mask``.

    Returns:
        Tensor reassembled to the input spatial layout, cropped back to (h, w)
        to undo any padding added by ``pad_to_patch_size``.
    """
    batch, channels, height, width = x.shape
    p = self.patch_size

    # Pad so both spatial dims divide evenly by the patch size, then fold
    # each p x p patch into the token feature dimension.
    x = comfy.ldm.common_dit.pad_to_patch_size(x, (p, p))
    img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=p, pw=p)

    # Token-grid extent after padding (rounds height/width up to a multiple of p
    # for the even padding used above).
    h_len = (height + (p // 2)) // p
    w_len = (width + (p // 2)) // p

    # Positional ids: channel 1 carries the row index, channel 2 the column
    # index, channel 0 stays zero; broadcast across the batch.
    img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
    img_ids[:, :, 1] += torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
    img_ids[:, :, 2] += torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch)

    # Text tokens get all-zero positional ids.
    txt_ids = torch.zeros((batch, context.shape[1], 3), device=x.device, dtype=x.dtype)

    out = self.forward_orig(
        img, img_ids, context, txt_ids, timestep, guidance, control,
        transformer_options, attn_mask=kwargs.get("attention_mask", None),
    )

    # Unfold tokens back into p x p spatial patches and crop off the padding.
    return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=p, pw=p)[:, :, :height, :width]
0 commit comments