Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions comfy/text_encoders/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class Llama2Config:
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = False
rope_dims = None

@dataclass
class Qwen25_3BConfig:
Expand All @@ -44,6 +45,7 @@ class Qwen25_3BConfig:
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = True
rope_dims = None

@dataclass
class Qwen25_7BVLI_Config:
Expand All @@ -61,6 +63,7 @@ class Qwen25_7BVLI_Config:
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = True
rope_dims = [16, 24, 24]

@dataclass
class Gemma2_2B_Config:
Expand All @@ -78,6 +81,7 @@ class Gemma2_2B_Config:
rms_norm_add = True
mlp_activation = "gelu_pytorch_tanh"
qkv_bias = False
rope_dims = None

class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
Expand All @@ -102,7 +106,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


def precompute_freqs_cis(head_dim, position_ids, theta, device=None):
def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=None):
theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
inv_freq = 1.0 / (theta ** (theta_numerator / head_dim))

Expand All @@ -112,12 +116,20 @@ def precompute_freqs_cis(head_dim, position_ids, theta, device=None):
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
if rope_dims is not None and position_ids.shape[0] > 1:
mrope_section = rope_dims * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
else:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)

return (cos, sin)


def apply_rope(xq, xk, freqs_cis):
cos = freqs_cis[0].unsqueeze(1)
sin = freqs_cis[1].unsqueeze(1)
cos = freqs_cis[0]
sin = freqs_cis[1]
q_embed = (xq * cos) + (rotate_half(xq) * sin)
k_embed = (xk * cos) + (rotate_half(xk) * sin)
return q_embed, k_embed
Expand Down Expand Up @@ -292,6 +304,7 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed
freqs_cis = precompute_freqs_cis(self.config.head_dim,
position_ids,
self.config.rope_theta,
self.config.rope_dims,
device=x.device)

mask = None
Expand Down
Loading