@@ -143,17 +143,26 @@ def apply_rotary_emb_qwen(
 
 
 class QwenTimestepProjEmbeddings(nn.Module):
-    def __init__(self, embedding_dim):
+    def __init__(self, embedding_dim, use_additional_t_cond=False):
         super().__init__()
 
         self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
         self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+        self.use_additional_t_cond = use_additional_t_cond
+        if use_additional_t_cond:
+            self.addition_t_embedding = nn.Embedding(2, embedding_dim)
 
-    def forward(self, timestep, hidden_states):
+    def forward(self, timestep, hidden_states, addition_t_cond=None):
         timesteps_proj = self.time_proj(timestep)
         timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))  # (N, D)
 
         conditioning = timesteps_emb
+        if self.use_additional_t_cond:
+            if addition_t_cond is None:
+                raise ValueError("When `use_additional_t_cond` is True, `addition_t_cond` must be provided.")
+            addition_t_emb = self.addition_t_embedding(addition_t_cond)
+            addition_t_emb = addition_t_emb.to(dtype=hidden_states.dtype)
+            conditioning = conditioning + addition_t_emb
 
         return conditioning
 
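The added branch is just a two-row lookup table whose vector is summed into the timestep embedding. A minimal, self-contained sketch of that math (toy sizes and stand-in tensors, not the library's API):

```python
import torch
import torch.nn as nn

embedding_dim = 8
addition_t_embedding = nn.Embedding(2, embedding_dim)  # one learned vector per binary state

timesteps_emb = torch.randn(4, embedding_dim)  # stand-in for the timestep_embedder output
addition_t_cond = torch.tensor([0, 1, 1, 0])   # one 0/1 flag per batch element

conditioning = timesteps_emb + addition_t_embedding(addition_t_cond)
print(conditioning.shape)  # torch.Size([4, 8])
```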
@@ -259,6 +268,120 @@ def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0
         return freqs.clone().contiguous()
 
 
+class QwenEmbedLayer3DRope(nn.Module):
+    def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
+        super().__init__()
+        self.theta = theta
+        self.axes_dim = axes_dim
+        pos_index = torch.arange(4096)
+        neg_index = torch.arange(4096).flip(0) * -1 - 1
+        self.pos_freqs = torch.cat(
+            [
+                self.rope_params(pos_index, self.axes_dim[0], self.theta),
+                self.rope_params(pos_index, self.axes_dim[1], self.theta),
+                self.rope_params(pos_index, self.axes_dim[2], self.theta),
+            ],
+            dim=1,
+        )
+        self.neg_freqs = torch.cat(
+            [
+                self.rope_params(neg_index, self.axes_dim[0], self.theta),
+                self.rope_params(neg_index, self.axes_dim[1], self.theta),
+                self.rope_params(neg_index, self.axes_dim[2], self.theta),
+            ],
+            dim=1,
+        )
+
+        self.scale_rope = scale_rope
+
+    def rope_params(self, index, dim, theta=10000):
+        """
+        Args:
+            index: 1D tensor of token position indices (e.g., [0, 1, 2, 3])
+        """
+        assert dim % 2 == 0
+        freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
+        freqs = torch.polar(torch.ones_like(freqs), freqs)
+        return freqs
+
+    def forward(self, video_fhw, txt_seq_lens, device):
+        """
+        Args:
+            video_fhw: a list of (frame, height, width) triples, one per image layer, with the condition image last
+            txt_seq_lens: [bs], a list of integers giving the length of each text sequence
+        """
+        if self.pos_freqs.device != device:
+            self.pos_freqs = self.pos_freqs.to(device)
+            self.neg_freqs = self.neg_freqs.to(device)
+
+        if isinstance(video_fhw, list):
+            video_fhw = video_fhw[0]
+        if not isinstance(video_fhw, list):
+            video_fhw = [video_fhw]
+
+        vid_freqs = []
+        max_vid_index = 0
+        layer_num = len(video_fhw) - 1
+        for idx, fhw in enumerate(video_fhw):
+            frame, height, width = fhw
+            if idx != layer_num:
+                video_freq = self._compute_video_freqs(frame, height, width, idx)
+            else:
+                # For the condition image, we set the layer index to -1
+                video_freq = self._compute_condition_freqs(frame, height, width)
+            video_freq = video_freq.to(device)
+            vid_freqs.append(video_freq)
+
+            if self.scale_rope:
+                max_vid_index = max(height // 2, width // 2, max_vid_index)
+            else:
+                max_vid_index = max(height, width, max_vid_index)
+
+        max_vid_index = max(max_vid_index, layer_num)
+        max_len = max(txt_seq_lens)
+        txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
+        vid_freqs = torch.cat(vid_freqs, dim=0)
+
+        return vid_freqs, txt_freqs
+
+    @functools.lru_cache(maxsize=None)
+    def _compute_video_freqs(self, frame, height, width, idx=0):
+        seq_lens = frame * height * width
+        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+
+        freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
+        if self.scale_rope:
+            freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
+            freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
+            freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
+            freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
+        else:
+            freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
+            freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
+
+        freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
+        return freqs.clone().contiguous()
+
+    @functools.lru_cache(maxsize=None)
+    def _compute_condition_freqs(self, frame, height, width):
+        seq_lens = frame * height * width
+        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+
+        freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1)
+        if self.scale_rope:
+            freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
+            freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
+            freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
+            freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
+        else:
+            freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
+            freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
+
+        freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
+        return freqs.clone().contiguous()
+
+
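The new module repurposes the temporal RoPE axis as a layer index: layer `idx` takes positive frame position `idx` from `pos_freqs`, while the last entry of `video_fhw` (the condition image) takes position -1 from `neg_freqs`. A standalone sketch of that indexing, reimplementing `rope_params` with toy sizes:

```python
import torch

def rope_params(index, dim, theta=10000.0):
    # Complex phasors exp(i * pos * theta^(-2k/dim)), one per (position, channel pair).
    assert dim % 2 == 0
    inv_freq = 1.0 / torch.pow(theta, torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    return torch.polar(torch.ones(len(index), dim // 2), torch.outer(index.float(), inv_freq))

pos_freqs = rope_params(torch.arange(8), dim=4)                   # positions 0..7
neg_freqs = rope_params(torch.arange(8).flip(0) * -1 - 1, dim=4)  # positions -8..-1

layer_idx = 2
layer_phase = pos_freqs[layer_idx : layer_idx + 1]  # frame-axis phase for layer 2
cond_phase = neg_freqs[-1:]                         # frame-axis phase for the condition image (-1)
print(layer_phase, cond_phase)
```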
 class QwenDoubleStreamAttnProcessor2_0:
     """
     Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor
@@ -578,14 +701,21 @@ def __init__(
         guidance_embeds: bool = False,  # TODO: this should probably be removed
         axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
         zero_cond_t: bool = False,
+        use_additional_t_cond: bool = False,
+        use_layer3d_rope: bool = False,
     ):
         super().__init__()
         self.out_channels = out_channels or in_channels
         self.inner_dim = num_attention_heads * attention_head_dim
 
-        self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
+        if not use_layer3d_rope:
+            self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
+        else:
+            self.pos_embed = QwenEmbedLayer3DRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
 
-        self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)
+        self.time_text_embed = QwenTimestepProjEmbeddings(
+            embedding_dim=self.inner_dim, use_additional_t_cond=use_additional_t_cond
+        )
 
         self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)
 
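A hypothetical instantiation with the new config flags. The class name `QwenImageTransformer2DModel` and the tiny sizes are assumptions based on the surrounding file and upstream diffusers, and the two flags exist only with this patch applied:

```python
from diffusers.models.transformers.transformer_qwenimage import QwenImageTransformer2DModel

model = QwenImageTransformer2DModel(
    num_layers=1,                # tiny config purely for illustration
    attention_head_dim=64,
    num_attention_heads=2,
    joint_attention_dim=64,
    use_layer3d_rope=True,       # select QwenEmbedLayer3DRope instead of QwenEmbedRope
    use_additional_t_cond=True,  # enable the binary extra timestep condition
)
print(type(model.pos_embed).__name__)  # QwenEmbedLayer3DRope
```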
@@ -621,6 +751,7 @@ def forward(
         guidance: torch.Tensor = None,  # TODO: this should probably be removed
         attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_block_samples=None,
+        additional_t_cond=None,
         return_dict: bool = True,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         """
@@ -683,9 +814,9 @@ def forward(
             guidance = guidance.to(hidden_states.dtype) * 1000
 
         temb = (
-            self.time_text_embed(timestep, hidden_states)
+            self.time_text_embed(timestep, hidden_states, additional_t_cond)
             if guidance is None
-            else self.time_text_embed(timestep, guidance, hidden_states)
+            else self.time_text_embed(timestep, guidance, hidden_states, additional_t_cond)
         )
 
         image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
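Assuming a checkout with this patch applied, the layered rope can be exercised on its own; the import path mirrors where `QwenEmbedRope` lives upstream and is an assumption:

```python
import torch
from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope

rope = QwenEmbedLayer3DRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True)
# Two edited layers plus the condition image (the last entry gets layer index -1).
img_shapes = [[(1, 32, 32), (1, 32, 32), (1, 64, 64)]]
vid_freqs, txt_freqs = rope(img_shapes, txt_seq_lens=[77], device=torch.device("cpu"))
print(vid_freqs.shape, txt_freqs.shape)  # (6144, 64) and (77, 64), complex64
```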