Skip to content

Commit 3041e5c

Browse files
Switch mochi and wan models to use pytorch RMSNorm. (Comfy-Org#7925)
* Switch genmo model to native RMSNorm.
* Switch WAN to native RMSNorm.
1 parent 7689917 commit 3041e5c

File tree

3 files changed

+7
-20
lines changed

3 files changed

+7
-20
lines changed

comfy/ldm/genmo/joint_model/asymm_models_joint.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from .layers import (
1414
FeedForward,
1515
PatchEmbed,
16-
RMSNorm,
1716
TimestepEmbedder,
1817
)
1918

@@ -90,10 +89,10 @@ def __init__(
9089

9190
# Query and key normalization for stability.
9291
assert qk_norm
93-
self.q_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
94-
self.k_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
95-
self.q_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
96-
self.k_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
92+
self.q_norm_x = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
93+
self.k_norm_x = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
94+
self.q_norm_y = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
95+
self.k_norm_y = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
9796

9897
# Output layers. y features go back down from dim_x -> dim_y.
9998
self.proj_x = operations.Linear(dim_x, dim_x, bias=out_bias, device=device, dtype=dtype)

comfy/ldm/genmo/joint_model/layers.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -151,14 +151,3 @@ def forward(self, x):
151151

152152
x = self.norm(x)
153153
return x
154-
155-
156-
class RMSNorm(torch.nn.Module):
157-
def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
158-
super().__init__()
159-
self.eps = eps
160-
self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device, dtype=dtype))
161-
self.register_parameter("bias", None)
162-
163-
def forward(self, x):
164-
return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)

comfy/ldm/wan/model.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from comfy.ldm.modules.attention import optimized_attention
1010
from comfy.ldm.flux.layers import EmbedND
1111
from comfy.ldm.flux.math import apply_rope
12-
from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
1312
import comfy.ldm.common_dit
1413
import comfy.model_management
1514

@@ -49,8 +48,8 @@ def __init__(self,
4948
self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
5049
self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
5150
self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
52-
self.norm_q = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
53-
self.norm_k = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
51+
self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
52+
self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
5453

5554
def forward(self, x, freqs):
5655
r"""
@@ -114,7 +113,7 @@ def __init__(self,
114113
self.k_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
115114
self.v_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
116115
# self.alpha = nn.Parameter(torch.zeros((1, )))
117-
self.norm_k_img = RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
116+
self.norm_k_img = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
118117

119118
def forward(self, x, context, context_img_len):
120119
r"""

0 commit comments

Comments (0)