48
48
from vllm .sequence import SamplerOutput
49
49
50
50
51
@torch.compile
def layer_norm_func(hidden_states: torch.Tensor, weight: torch.Tensor,
                    variance_epsilon: float) -> torch.Tensor:
    """Apply weight-only layer normalization over the last dimension.

    The statistics are computed in float32 regardless of the input dtype,
    and the result is cast back to the dtype of ``hidden_states``.

    Args:
        hidden_states: Tensor to normalize along its last dimension.
        weight: Multiplicative scale applied after normalization; it is
            upcast to float32 and broadcast against ``hidden_states``.
        variance_epsilon: Constant added to the variance before the
            reciprocal square root.

    Returns:
        The normalized and scaled tensor, in the original input dtype.
    """
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    mean = hidden_states.mean(-1, keepdim=True)
    # Centre once and reuse the result for both the variance and the
    # normalized output instead of recomputing the subtraction.
    centered = hidden_states - mean
    # Biased variance: plain mean of squared deviations.
    variance = centered.pow(2).mean(-1, keepdim=True)
    normalized = centered * torch.rsqrt(variance + variance_epsilon)
    scaled = weight.to(torch.float32) * normalized
    return scaled.to(input_dtype)
61
+
62
+
51
63
class LayerNorm (nn .Module ):
52
64
53
65
def __init__ (self , param_shape = None , eps = 1e-5 ):
@@ -57,14 +69,9 @@ def __init__(self, param_shape=None, eps=1e-5):
57
69
set_weight_attrs (self .weight , {"weight_loader" : self .weight_loader })
58
70
59
71
def forward(self, hidden_states, residuals=None):
    """Normalize ``hidden_states`` and pass ``residuals`` through unchanged.

    The normalization itself is delegated to the module-level
    ``layer_norm_func``, using this module's ``weight`` and
    ``variance_epsilon``.

    Returns:
        A ``(normalized_hidden_states, residuals)`` tuple.
    """
    normalized = layer_norm_func(
        hidden_states,
        self.weight,
        self.variance_epsilon,
    )
    return normalized, residuals
68
75
69
76
def weight_loader (self , param : Parameter , loaded_weight : torch .Tensor ):
70
77
tp_rank = get_tensor_model_parallel_rank ()
0 commit comments