33import comfy .patcher_extension
44import comfy .ldm .modules .attention
55import comfy .ldm .common_dit
6- from einops import rearrange
76import math
87from typing import Dict , Optional , Tuple
98
109from .symmetric_patchifier import SymmetricPatchifier , latent_to_pixel_coords
11-
10+ from comfy . ldm . flux . math import apply_rope1
1211
1312def get_timestep_embedding (
1413 timesteps : torch .Tensor ,
@@ -238,20 +237,6 @@ def forward(self, x):
238237 return self .net (x )
239238
240239
def apply_rotary_emb(input_tensor, freqs_cis):  # TODO: remove duplicate funcs and pick the best/fastest one
    """Rotate *input_tensor* by the rotary embedding given as a (cos, sin) pair.

    freqs_cis is a 2-tuple/list: freqs_cis[0] holds the cosine factors and
    freqs_cis[1] the sine factors, both broadcastable to input_tensor's shape.
    Adjacent channel pairs (2i, 2i+1) are treated as complex components and
    rotated: out = x * cos + rotate_half(x) * sin.
    """
    cos_freqs, sin_freqs = freqs_cis[0], freqs_cis[1]

    # Split the last dim into (pairs, 2), swap each pair to (-odd, even),
    # then flatten back — the classic "rotate half" construction.
    pairs = input_tensor.reshape(*input_tensor.shape[:-1], -1, 2)
    even, odd = pairs.unbind(dim=-1)
    rotated = torch.stack((-odd, even), dim=-1).flatten(-2)

    return input_tensor * cos_freqs + rotated * sin_freqs
253-
254-
255240class CrossAttention (nn .Module ):
256241 def __init__ (self , query_dim , context_dim = None , heads = 8 , dim_head = 64 , dropout = 0. , attn_precision = None , dtype = None , device = None , operations = None ):
257242 super ().__init__ ()
@@ -281,8 +266,8 @@ def forward(self, x, context=None, mask=None, pe=None, transformer_options={}):
281266 k = self .k_norm (k )
282267
283268 if pe is not None :
284- q = apply_rotary_emb ( q , pe )
285- k = apply_rotary_emb ( k , pe )
269+ q = apply_rope1 ( q . unsqueeze ( 1 ) , pe ). squeeze ( 1 )
270+ k = apply_rope1 ( k . unsqueeze ( 1 ) , pe ). squeeze ( 1 )
286271
287272 if mask is None :
288273 out = comfy .ldm .modules .attention .optimized_attention (q , k , v , self .heads , attn_precision = self .attn_precision , transformer_options = transformer_options )
@@ -306,12 +291,17 @@ def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None,
306291 def forward (self , x , context = None , attention_mask = None , timestep = None , pe = None , transformer_options = {}):
307292 shift_msa , scale_msa , gate_msa , shift_mlp , scale_mlp , gate_mlp = (self .scale_shift_table [None , None ].to (device = x .device , dtype = x .dtype ) + timestep .reshape (x .shape [0 ], timestep .shape [1 ], self .scale_shift_table .shape [0 ], - 1 )).unbind (dim = 2 )
308293
309- x += self .attn1 (comfy .ldm .common_dit .rms_norm (x ) * (1 + scale_msa ) + shift_msa , pe = pe , transformer_options = transformer_options ) * gate_msa
294+ norm_x = comfy .ldm .common_dit .rms_norm (x )
295+ attn1_input = torch .addcmul (norm_x , norm_x , scale_msa ).add_ (shift_msa )
296+ attn1_result = self .attn1 (attn1_input , pe = pe , transformer_options = transformer_options )
297+ x .addcmul_ (attn1_result , gate_msa )
310298
311299 x += self .attn2 (x , context = context , mask = attention_mask , transformer_options = transformer_options )
312300
313- y = comfy .ldm .common_dit .rms_norm (x ) * (1 + scale_mlp ) + shift_mlp
314- x += self .ff (y ) * gate_mlp
301+ norm_x = comfy .ldm .common_dit .rms_norm (x )
302+ y = torch .addcmul (norm_x , norm_x , scale_mlp ).add_ (shift_mlp )
303+ ff_result = self .ff (y )
304+ x .addcmul_ (ff_result , gate_mlp )
315305
316306 return x
317307
@@ -327,41 +317,35 @@ def get_fractional_positions(indices_grid, max_pos):
327317
328318
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
    """Precompute rotary-embedding rotation matrices for fractional 3D positions.

    Args:
        indices_grid: integer/float position grid; fed to get_fractional_positions.
        dim: per-head channel dimension (assumed even; dim // 6 freqs per axis).
        out_dtype: dtype of the returned tensor.
        theta: RoPE base frequency.
        max_pos: per-axis maxima used to normalize positions to [0, 1].

    Returns:
        Tensor of shape [B, 1, N, dim // 2, 2, 2] holding per-pair rotation
        matrices [[cos, -sin], [sin, cos]] (the flux apply_rope1 layout),
        cast to out_dtype. The singleton dim broadcasts over attention heads.
    """
    dtype = torch.float32
    device = indices_grid.device

    # Fractional positions in [0, 1] per spatial axis, mapped to [-1, 1] below.
    fractional_positions = get_fractional_positions(indices_grid, max_pos)
    indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2

    # One angle per rotated channel pair: [B, N, 3 * (dim // 6)].
    freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
    cos_vals = freqs.cos()
    sin_vals = freqs.sin()

    # If dim is not divisible by 6, prepend identity rotations (cos=1, sin=0)
    # to reach dim // 2 pairs. Working at pair granularity here avoids the
    # previous repeat_interleave(2) + reshape(..., -1, 2)[..., 0] round-trip,
    # which duplicated every value only to immediately discard the copies.
    if dim % 6 != 0:
        pad_pairs = (dim % 6) // 2
        cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :pad_pairs]), cos_vals], dim=-1)
        sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :pad_pairs]), sin_vals], dim=-1)

    # Build 2x2 rotation matrices and add a broadcastable heads dimension.
    freqs_cis = torch.stack([
        torch.stack([cos_vals, -sin_vals], dim=-1),
        torch.stack([sin_vals, cos_vals], dim=-1)
    ], dim=-2).unsqueeze(1)  # [B, 1, N, dim // 2, 2, 2]

    return freqs_cis.to(out_dtype)
365349
366350
367351class LTXVModel (torch .nn .Module ):
@@ -501,7 +485,7 @@ def block_wrap(args):
501485 shift , scale = scale_shift_values [:, :, 0 ], scale_shift_values [:, :, 1 ]
502486 x = self .norm_out (x )
503487 # Modulation
504- x = x * ( 1 + scale ) + shift
488+ x = torch . addcmul ( x , x , scale ). add_ ( shift )
505489 x = self .proj_out (x )
506490
507491 x = self .patchifier .unpatchify (
0 commit comments