
Commit caada5e

[Core][Model] Use torch.compile to accelerate layernorm in commandr (vllm-project#3985)

1 parent 67b4221 commit caada5e

File tree

1 file changed: +15 −8 lines

vllm/model_executor/models/commandr.py

@@ -48,6 +48,18 @@
 from vllm.sequence import SamplerOutput


+@torch.compile
+def layer_norm_func(hidden_states, weight, variance_epsilon):
+    input_dtype = hidden_states.dtype
+    hidden_states = hidden_states.to(torch.float32)
+    mean = hidden_states.mean(-1, keepdim=True)
+    variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
+    hidden_states = (hidden_states - mean) * torch.rsqrt(variance +
+                                                         variance_epsilon)
+    hidden_states = weight.to(torch.float32) * hidden_states
+    return hidden_states.to(input_dtype)
+
+
 class LayerNorm(nn.Module):

     def __init__(self, param_shape=None, eps=1e-5):
@@ -57,14 +69,9 @@ def __init__(self, param_shape=None, eps=1e-5):
         set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})

     def forward(self, hidden_states, residuals=None):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        mean = hidden_states.mean(-1, keepdim=True)
-        variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
-        hidden_states = (hidden_states -
-                         mean) * torch.rsqrt(variance + self.variance_epsilon)
-        hidden_states = self.weight.to(torch.float32) * hidden_states
-        return hidden_states.to(input_dtype), residuals
+        hidden_states = layer_norm_func(hidden_states, self.weight,
+                                        self.variance_epsilon)
+        return hidden_states, residuals

     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         tp_rank = get_tensor_model_parallel_rank()
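
The change hoists the layernorm math out of the module's forward into a free function so torch.compile can fuse the chain of elementwise ops and reductions into a single kernel. For readers who want to try this outside vLLM, below is a minimal, self-contained sketch (not part of this commit) that reproduces the compiled helper and checks it against the same math run eagerly. The tensor shapes, dtype, and eps value are illustrative assumptions; torch.compile requires PyTorch 2.0+.

# Sketch only: verify the @torch.compile'd layernorm matches eager math.
import torch


@torch.compile
def layer_norm_func(hidden_states, weight, variance_epsilon):
    # Compute layernorm in float32 for numerical stability, then cast back.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    mean = hidden_states.mean(-1, keepdim=True)
    variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
    hidden_states = (hidden_states - mean) * torch.rsqrt(variance +
                                                         variance_epsilon)
    hidden_states = weight.to(torch.float32) * hidden_states
    return hidden_states.to(input_dtype)


def eager_layer_norm(hidden_states, weight, eps):
    # Same math without the decorator, for comparison.
    x = hidden_states.to(torch.float32)
    mean = x.mean(-1, keepdim=True)
    var = (x - mean).pow(2).mean(-1, keepdim=True)
    out = weight.to(torch.float32) * (x - mean) * torch.rsqrt(var + eps)
    return out.to(hidden_states.dtype)


if __name__ == "__main__":
    # Hypothetical shapes: (batch, seq_len, hidden_size).
    x = torch.randn(4, 128, 4096, dtype=torch.float16)
    w = torch.ones(4096, dtype=torch.float16)
    out_compiled = layer_norm_func(x, w, 1e-5)   # first call triggers compilation
    out_eager = eager_layer_norm(x, w, 1e-5)
    torch.testing.assert_close(out_compiled, out_eager)
    print("compiled and eager outputs match")

Note that the first call pays a one-time compilation cost; the speedup comes from subsequent calls reusing the fused kernel, which is why this pays off in a serving loop that invokes the layernorm many times per request.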
