Skip to content

Commit 4cd8818

Browse files
authored
Use single apply_rope function across models (Comfy-Org#10547)
1 parent 265adad commit 4cd8818

File tree

5 files changed

+58
-79
lines changed

5 files changed

+58
-79
lines changed

comfy/ldm/flux/layers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,8 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=N
195195
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
196196

197197
# calculate the img bloks
198-
img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
199-
img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
198+
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
199+
img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
200200

201201
# calculate the txt bloks
202202
txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)

comfy/ldm/flux/math.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,7 @@
77

88

99
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
10-
q_shape = q.shape
11-
k_shape = k.shape
12-
13-
if pe is not None:
14-
q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
15-
k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
16-
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
17-
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
18-
10+
q, k = apply_rope(q, k, pe)
1911
heads = q.shape[1]
2012
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
2113
return x

comfy/ldm/lightricks/model.py

Lines changed: 35 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@
33
import comfy.patcher_extension
44
import comfy.ldm.modules.attention
55
import comfy.ldm.common_dit
6-
from einops import rearrange
76
import math
87
from typing import Dict, Optional, Tuple
98

109
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
11-
10+
from comfy.ldm.flux.math import apply_rope1
1211

1312
def get_timestep_embedding(
1413
timesteps: torch.Tensor,
@@ -238,20 +237,6 @@ def forward(self, x):
238237
return self.net(x)
239238

240239

241-
def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one
242-
cos_freqs = freqs_cis[0]
243-
sin_freqs = freqs_cis[1]
244-
245-
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
246-
t1, t2 = t_dup.unbind(dim=-1)
247-
t_dup = torch.stack((-t2, t1), dim=-1)
248-
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
249-
250-
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
251-
252-
return out
253-
254-
255240
class CrossAttention(nn.Module):
256241
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
257242
super().__init__()
@@ -281,8 +266,8 @@ def forward(self, x, context=None, mask=None, pe=None, transformer_options={}):
281266
k = self.k_norm(k)
282267

283268
if pe is not None:
284-
q = apply_rotary_emb(q, pe)
285-
k = apply_rotary_emb(k, pe)
269+
q = apply_rope1(q.unsqueeze(1), pe).squeeze(1)
270+
k = apply_rope1(k.unsqueeze(1), pe).squeeze(1)
286271

287272
if mask is None:
288273
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
@@ -306,12 +291,17 @@ def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None,
306291
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
307292
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
308293

309-
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa
294+
norm_x = comfy.ldm.common_dit.rms_norm(x)
295+
attn1_input = torch.addcmul(norm_x, norm_x, scale_msa).add_(shift_msa)
296+
attn1_result = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
297+
x.addcmul_(attn1_result, gate_msa)
310298

311299
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
312300

313-
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
314-
x += self.ff(y) * gate_mlp
301+
norm_x = comfy.ldm.common_dit.rms_norm(x)
302+
y = torch.addcmul(norm_x, norm_x, scale_mlp).add_(shift_mlp)
303+
ff_result = self.ff(y)
304+
x.addcmul_(ff_result, gate_mlp)
315305

316306
return x
317307

@@ -327,41 +317,35 @@ def get_fractional_positions(indices_grid, max_pos):
327317

328318

329319
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
330-
dtype = torch.float32 #self.dtype
320+
dtype = torch.float32
321+
device = indices_grid.device
331322

323+
# Get fractional positions and compute frequency indices
332324
fractional_positions = get_fractional_positions(indices_grid, max_pos)
325+
indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2
333326

334-
start = 1
335-
end = theta
336-
device = fractional_positions.device
337-
338-
indices = theta ** (
339-
torch.linspace(
340-
math.log(start, theta),
341-
math.log(end, theta),
342-
dim // 6,
343-
device=device,
344-
dtype=dtype,
345-
)
346-
)
347-
indices = indices.to(dtype=dtype)
327+
# Compute frequencies and apply cos/sin
328+
freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
329+
cos_vals = freqs.cos().repeat_interleave(2, dim=-1)
330+
sin_vals = freqs.sin().repeat_interleave(2, dim=-1)
331+
332+
# Pad if dim is not divisible by 6
333+
if dim % 6 != 0:
334+
padding_size = dim % 6
335+
cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1)
336+
sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1)
348337

349-
indices = indices * math.pi / 2
338+
# Reshape and extract one value per pair (since repeat_interleave duplicates each value)
339+
cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0] # [B, N, dim//2]
340+
sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0] # [B, N, dim//2]
350341

351-
freqs = (
352-
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
353-
.transpose(-1, -2)
354-
.flatten(2)
355-
)
342+
# Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension
343+
freqs_cis = torch.stack([
344+
torch.stack([cos_vals, -sin_vals], dim=-1),
345+
torch.stack([sin_vals, cos_vals], dim=-1)
346+
], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2]
356347

357-
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
358-
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
359-
if dim % 6 != 0:
360-
cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
361-
sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
362-
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
363-
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
364-
return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
348+
return freqs_cis.to(out_dtype)
365349

366350

367351
class LTXVModel(torch.nn.Module):
@@ -501,7 +485,7 @@ def block_wrap(args):
501485
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
502486
x = self.norm_out(x)
503487
# Modulation
504-
x = x * (1 + scale) + shift
488+
x = torch.addcmul(x, x, scale).add_(shift)
505489
x = self.proj_out(x)
506490

507491
x = self.patchifier.unpatchify(

comfy/ldm/qwen_image/model.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from comfy.ldm.flux.layers import EmbedND
1111
import comfy.ldm.common_dit
1212
import comfy.patcher_extension
13+
from comfy.ldm.flux.math import apply_rope1
1314

1415
class GELU(nn.Module):
1516
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
@@ -134,33 +135,34 @@ def forward(
134135
image_rotary_emb: Optional[torch.Tensor] = None,
135136
transformer_options={},
136137
) -> Tuple[torch.Tensor, torch.Tensor]:
138+
batch_size = hidden_states.shape[0]
139+
seq_img = hidden_states.shape[1]
137140
seq_txt = encoder_hidden_states.shape[1]
138141

139-
img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1))
140-
img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1))
141-
img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1))
142+
# Project and reshape to BHND format (batch, heads, seq, dim)
143+
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
144+
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
145+
img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2)
142146

143-
txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
144-
txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
145-
txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
147+
txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
148+
txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
149+
txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2)
146150

147151
img_query = self.norm_q(img_query)
148152
img_key = self.norm_k(img_key)
149153
txt_query = self.norm_added_q(txt_query)
150154
txt_key = self.norm_added_k(txt_key)
151155

152-
joint_query = torch.cat([txt_query, img_query], dim=1)
153-
joint_key = torch.cat([txt_key, img_key], dim=1)
154-
joint_value = torch.cat([txt_value, img_value], dim=1)
156+
joint_query = torch.cat([txt_query, img_query], dim=2)
157+
joint_key = torch.cat([txt_key, img_key], dim=2)
158+
joint_value = torch.cat([txt_value, img_value], dim=2)
155159

156-
joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
157-
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
160+
joint_query = apply_rope1(joint_query, image_rotary_emb)
161+
joint_key = apply_rope1(joint_key, image_rotary_emb)
158162

159-
joint_query = joint_query.flatten(start_dim=2)
160-
joint_key = joint_key.flatten(start_dim=2)
161-
joint_value = joint_value.flatten(start_dim=2)
162-
163-
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
163+
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
164+
attention_mask, transformer_options=transformer_options,
165+
skip_reshape=True)
164166

165167
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
166168
img_attn_output = joint_hidden_states[:, seq_txt:, :]
@@ -413,7 +415,7 @@ def _forward(
413415
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
414416
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
415417
ids = torch.cat((txt_ids, img_ids), dim=1)
416-
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
418+
image_rotary_emb = self.pe_embedder(ids).to(torch.float32).contiguous()
417419
del ids, txt_ids, img_ids
418420

419421
hidden_states = self.img_in(hidden_states)

comfy/ldm/wan/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,7 @@ def forward(
232232
# assert e[0].dtype == torch.float32
233233

234234
# self-attention
235+
x = x.contiguous() # otherwise implicit in LayerNorm
235236
y = self.self_attn(
236237
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
237238
freqs, transformer_options=transformer_options)

0 commit comments

Comments (0)