diff --git a/clip/model.py b/clip/model.py
index 232b7792e..e388d402d 100644
--- a/clip/model.py
+++ b/clip/model.py
@@ -154,15 +154,20 @@ def stem(x):
         return x
 
 
-class LayerNorm(nn.LayerNorm):
-    """Subclass torch's LayerNorm to handle fp16."""
+class LayerNorm(nn.Module):
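+    """Wrap torch's LayerNorm to handle fp16."""
+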
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        self.inner_layernorm = nn.LayerNorm(*args, **kwargs)
 
     def forward(self, x: torch.Tensor):
         orig_type = x.dtype
-        ret = super().forward(x.type(torch.float32))
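+        # run the wrapped LayerNorm in fp32, then cast back to the input dtype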
+        ret = self.inner_layernorm(x.type(torch.float32))
         return ret.type(orig_type)
 
 
 class QuickGELU(nn.Module):
     def forward(self, x: torch.Tensor):
         return x * torch.sigmoid(1.702 * x)