@@ -247,6 +247,60 @@ def forward(self, c, x, **kwargs):
         return c_skip, c
 
 
+class WanCamAdapter(nn.Module):
+    def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1, operation_settings={}):
+        super(WanCamAdapter, self).__init__()
+
+        # Pixel unshuffle: reduce spatial dimensions by a factor of 8, folding each 8x8 neighborhood into channels
+        self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8)
+
+        # Convolution: reduce spatial dimensions by a further factor of 2
+        # (non-overlapping, since stride equals kernel size)
+        self.conv = operation_settings.get("operations").Conv2d(in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+        # Residual blocks for feature extraction
+        self.residual_blocks = nn.Sequential(
+            *[WanCamResidualBlock(out_dim, operation_settings=operation_settings) for _ in range(num_residual_blocks)]
+        )
+
+    def forward(self, x):
+        # Merge the frame dimension into the batch dimension
+        bs, c, f, h, w = x.size()
+        x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w)
+
+        # Pixel unshuffle operation
+        x_unshuffled = self.pixel_unshuffle(x)
+
+        # Convolution operation
+        x_conv = self.conv(x_unshuffled)
+
+        # Feature extraction with residual blocks
+        out = self.residual_blocks(x_conv)
+
+        # Split the merged batch back into separate batch and frame dimensions
+        out = out.view(bs, f, out.size(1), out.size(2), out.size(3))
+
+        # Move channels back in front of frames: [bs, f, c, h, w] -> [bs, c, f, h, w]
+        out = out.permute(0, 2, 1, 3, 4)
+
+        return out
+
+
+class WanCamResidualBlock(nn.Module):
+    def __init__(self, dim, operation_settings={}):
+        super(WanCamResidualBlock, self).__init__()
+        self.conv1 = operation_settings.get("operations").Conv2d(dim, dim, kernel_size=3, padding=1, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = operation_settings.get("operations").Conv2d(dim, dim, kernel_size=3, padding=1, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+    def forward(self, x):
+        residual = x
+        out = self.relu(self.conv1(x))
+        out = self.conv2(out)
+        out += residual
+        return out
+
+
 class Head(nn.Module):
 
     def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}):
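
For orientation, the adapter above turns per-frame camera maps into features on the transformer's patch grid: the pixel unshuffle folds an 8x spatial downsample into channels, and the non-overlapping conv then applies the spatial patch size. Below is a minimal shape check, a sketch only: it assumes the two classes above are in scope and substitutes plain torch.nn for the `operations` wrapper the real model passes in; the sizes mirror the CameraWanModel defaults from the second hunk.

import torch
import torch.nn as nn

# Plain torch.nn stands in for the `operations` wrapper (an assumption
# made only for this sketch).
settings = {"operations": nn, "device": None, "dtype": torch.float32}
adapter = WanCamAdapter(in_dim=24, out_dim=2048, kernel_size=(2, 2), stride=(2, 2), operation_settings=settings)

cond = torch.randn(1, 24, 5, 64, 64)  # [bs, c, f, h, w]
out = adapter(cond)
# PixelUnshuffle(8): 64x64 -> 8x8 with 24 * 64 = 1536 channels per frame,
# then the stride-2 conv halves the grid to 4x4 with 2048 channels.
print(out.shape)  # torch.Size([1, 2048, 5, 4, 4])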
@@ -637,3 +691,92 @@ def block_wrap(args):
         # unpatchify
         x = self.unpatchify(x, grid_sizes)
         return x
+
+class CameraWanModel(WanModel):
+    r"""
+    Wan diffusion backbone with a camera-control adapter, built on the image-to-video variant.
+    """
+
+    def __init__(self,
+                 model_type='camera',
+                 patch_size=(1, 2, 2),
+                 text_len=512,
+                 in_dim=16,
+                 dim=2048,
+                 ffn_dim=8192,
+                 freq_dim=256,
+                 text_dim=4096,
+                 out_dim=16,
+                 num_heads=16,
+                 num_layers=32,
+                 window_size=(-1, -1),
+                 qk_norm=True,
+                 cross_attn_norm=True,
+                 eps=1e-6,
+                 flf_pos_embed_token_number=None,
+                 image_model=None,
+                 in_dim_control_adapter=24,
+                 device=None,
+                 dtype=None,
+                 operations=None,
+                 ):
+
+        super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
+        operation_settings = {"operations": operations, "device": device, "dtype": dtype}
+
+        self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings)
+
+
+    def forward_orig(
+        self,
+        x,
+        t,
+        context,
+        clip_fea=None,
+        freqs=None,
+        camera_conditions=None,
+        transformer_options={},
+        **kwargs,
+    ):
+        # embeddings
+        x = self.patch_embedding(x.float()).to(x.dtype)
+        if self.control_adapter is not None and camera_conditions is not None:
+            x_camera = self.control_adapter(camera_conditions).to(x.dtype)
+            x = x + x_camera
+        grid_sizes = x.shape[2:]
+        x = x.flatten(2).transpose(1, 2)
+
+        # time embeddings
+        e = self.time_embedding(
+            sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
+        e0 = self.time_projection(e).unflatten(1, (6, self.dim))
+
+        # context
+        context = self.text_embedding(context)
+
+        context_img_len = None
+        if clip_fea is not None:
+            if self.img_emb is not None:
+                context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
+                context = torch.concat([context_clip, context], dim=1)
+            context_img_len = clip_fea.shape[-2]
+
+        patches_replace = transformer_options.get("patches_replace", {})
+        blocks_replace = patches_replace.get("dit", {})
+        for i, block in enumerate(self.blocks):
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
+                    return out
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
+                x = out["img"]
+            else:
+                x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
+
+        # head
+        x = self.head(x, e)
+
+        # unpatchify
+        x = self.unpatchify(x, grid_sizes)
+        return x
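
A note on shapes in forward_orig: the addition `x = x + x_camera` only works when the adapter output lands on the same grid as the patch embedding. With a (1, 2, 2) patch on the latent and the adapter's overall 16x spatial reduction, camera_conditions must be 8x the latent's spatial size, i.e. pixel resolution for the usual 8x spatial VAE. A quick sanity sketch with hypothetical sizes:

import torch

# Hypothetical sizes, chosen only to illustrate the constraint.
bs, f, h, w, dim = 1, 5, 32, 32, 2048
x = torch.randn(bs, dim, f, h // 2, w // 2)  # latent after patch_embedding, patch (1, 2, 2)
x_camera = torch.randn(bs, dim, f, (8 * h) // 16, (8 * w) // 16)  # adapter output
assert x.shape == x_camera.shape  # required by x = x + x_camera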