From 09af7413345c99560a83b1fb83127beef18bd21e Mon Sep 17 00:00:00 2001
From: hem9984 <147702557+hem9984@users.noreply.github.com>
Date: Mon, 22 Jul 2024 21:30:24 -0400
Subject: [PATCH] fixed layer_norm

fixed error that occurs for some models
---
 clip/model.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/clip/model.py b/clip/model.py
index 232b7792e..e388d402d 100644
--- a/clip/model.py
+++ b/clip/model.py
@@ -154,15 +154,16 @@ def stem(x):
         return x
 
 
-class LayerNorm(nn.LayerNorm):
-    """Subclass torch's LayerNorm to handle fp16."""
+class LayerNorm(nn.Module):
+    def __init__(self, *args, **kwargs):
+        super(LayerNorm, self).__init__()
+        self.inner_layernorm = nn.LayerNorm(*args, **kwargs)
 
     def forward(self, x: torch.Tensor):
         orig_type = x.dtype
-        ret = super().forward(x.type(torch.float32))
+        ret = self.inner_layernorm(x.type(torch.float32))
         return ret.type(orig_type)
 
-
 class QuickGELU(nn.Module):
     def forward(self, x: torch.Tensor):
         return x * torch.sigmoid(1.702 * x)
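
The hunk above replaces the fp16-handling LayerNorm subclass with a thin wrapper around a plain nn.LayerNorm that still casts to fp32 and back. Below is a minimal, self-contained sketch of the patched class: the class body is copied from the diff, while the hidden size (512) and the half-precision input are illustrative assumptions, not part of the patch.

# Minimal, runnable sketch of the patched LayerNorm from the hunk above.
# The class body is copied from the diff; the hidden size (512) and the
# fp16 input below are illustrative assumptions, not part of the patch.
import torch
import torch.nn as nn


class LayerNorm(nn.Module):
    def __init__(self, *args, **kwargs):
        super(LayerNorm, self).__init__()
        # Wrap a plain nn.LayerNorm instead of subclassing it.
        self.inner_layernorm = nn.LayerNorm(*args, **kwargs)

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        # Normalize in fp32, then cast back to the input dtype (e.g. fp16).
        ret = self.inner_layernorm(x.type(torch.float32))
        return ret.type(orig_type)


if __name__ == "__main__":
    ln = LayerNorm(512)                 # constructor arguments are forwarded to nn.LayerNorm
    x = torch.randn(2, 77, 512).half()  # half-precision activations
    y = ln(x)
    print(y.dtype)                      # torch.float16 -- input dtype is preserved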