@@ -9,7 +9,6 @@
 from comfy.ldm.modules.attention import optimized_attention
 from comfy.ldm.flux.layers import EmbedND
 from comfy.ldm.flux.math import apply_rope
-from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
 import comfy.ldm.common_dit
 import comfy.model_management
 
@@ -49,8 +48,8 @@ def __init__(self,
         self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-        self.norm_q = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
-        self.norm_k = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
+        self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
+        self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
 
     def forward(self, x, freqs):
         r"""
@@ -114,7 +113,7 @@ def __init__(self,
         self.k_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.v_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         # self.alpha = nn.Parameter(torch.zeros((1, )))
-        self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
+        self.norm_k_img = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
 
     def forward(self, x, context, context_img_len):
         r"""
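
The change routes the q/k RMSNorm layers through the same "operations" container that already supplies the Linear layers, instead of importing a fixed RMSNorm from the mmdit module. Below is a minimal sketch of that pattern, assuming a hypothetical SimpleOps container and build_qk_norm helper (not ComfyUI's actual comfy.ops classes) and a PyTorch version that ships nn.RMSNorm.

# Minimal, hypothetical sketch of the pattern in the diff above: every layer is
# constructed through the "operations" object in operation_settings, so swapping
# that object changes how Linear and RMSNorm are built (e.g. skipping weight
# init or casting weights) without touching the model code.
import torch
import torch.nn as nn


class SimpleOps:
    # Hypothetical ops container exposing the layer classes the model asks for.
    class Linear(nn.Linear):
        pass

    class RMSNorm(nn.RMSNorm):  # nn.RMSNorm requires a recent PyTorch (>= 2.4)
        pass


def build_qk_norm(dim, eps, qk_norm, operation_settings):
    # Mirrors the diff: the norm comes from the same ops object as the Linears.
    ops = operation_settings.get("operations")
    if not qk_norm:
        return nn.Identity()
    return ops.RMSNorm(dim, eps=eps, elementwise_affine=True,
                       device=operation_settings.get("device"),
                       dtype=operation_settings.get("dtype"))


# Usage: the caller picks the ops implementation once; it applies to every layer.
settings = {"operations": SimpleOps, "device": "cpu", "dtype": torch.float32}
print(build_qk_norm(128, 1e-6, True, settings))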