Skip to content

Commit dfa791e

Browse files
Rope fix for qwen vl. (Comfy-Org#9435)
1 parent bddd696 commit dfa791e

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

comfy/text_encoders/llama.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class Llama2Config:
2727
rms_norm_add = False
2828
mlp_activation = "silu"
2929
qkv_bias = False
30+
rope_dims = None
3031

3132
@dataclass
3233
class Qwen25_3BConfig:
@@ -44,6 +45,7 @@ class Qwen25_3BConfig:
4445
rms_norm_add = False
4546
mlp_activation = "silu"
4647
qkv_bias = True
48+
rope_dims = None
4749

4850
@dataclass
4951
class Qwen25_7BVLI_Config:
@@ -61,6 +63,7 @@ class Qwen25_7BVLI_Config:
6163
rms_norm_add = False
6264
mlp_activation = "silu"
6365
qkv_bias = True
66+
rope_dims = [16, 24, 24]
6467

6568
@dataclass
6669
class Gemma2_2B_Config:
@@ -78,6 +81,7 @@ class Gemma2_2B_Config:
7881
rms_norm_add = True
7982
mlp_activation = "gelu_pytorch_tanh"
8083
qkv_bias = False
84+
rope_dims = None
8185

8286
class RMSNorm(nn.Module):
8387
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
@@ -102,7 +106,7 @@ def rotate_half(x):
102106
return torch.cat((-x2, x1), dim=-1)
103107

104108

105-
def precompute_freqs_cis(head_dim, position_ids, theta, device=None):
109+
def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=None):
106110
theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
107111
inv_freq = 1.0 / (theta ** (theta_numerator / head_dim))
108112

@@ -112,12 +116,20 @@ def precompute_freqs_cis(head_dim, position_ids, theta, device=None):
112116
emb = torch.cat((freqs, freqs), dim=-1)
113117
cos = emb.cos()
114118
sin = emb.sin()
119+
if rope_dims is not None and position_ids.shape[0] > 1:
120+
mrope_section = rope_dims * 2
121+
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
122+
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
123+
else:
124+
cos = cos.unsqueeze(1)
125+
sin = sin.unsqueeze(1)
126+
115127
return (cos, sin)
116128

117129

118130
def apply_rope(xq, xk, freqs_cis):
119-
cos = freqs_cis[0].unsqueeze(1)
120-
sin = freqs_cis[1].unsqueeze(1)
131+
cos = freqs_cis[0]
132+
sin = freqs_cis[1]
121133
q_embed = (xq * cos) + (rotate_half(xq) * sin)
122134
k_embed = (xk * cos) + (rotate_half(xk) * sin)
123135
return q_embed, k_embed
@@ -292,6 +304,7 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed
292304
freqs_cis = precompute_freqs_cis(self.config.head_dim,
293305
position_ids,
294306
self.config.rope_theta,
307+
self.config.rope_dims,
295308
device=x.device)
296309

297310
mask = None

0 commit comments

Comments
 (0)