[Kernel] Support rms_norm kernel for Gemma #29810
Conversation
Signed-off-by: Xin Yang <[email protected]>
Code Review
This pull request adds support for Gemma's RMSNorm by introducing a `weight_bias` parameter to the existing RMSNorm kernels. This is a clean approach to extending the functionality. The changes are well-integrated across the CUDA kernels, Python bindings, and layer implementations, and are accompanied by appropriate benchmarks and tests. The performance improvement for Gemma is substantial, as it now leverages the custom CUDA kernel instead of the native PyTorch implementation. I have identified one potential performance regression in the vectorized `fused_add_rms_norm` kernel for the standard RMSNorm case and have provided a suggestion to mitigate it.
```diff
 if vllm_is_batch_invariant():
     return rms_norm_batch_invariant(x, weight, variance_epsilon)
-out = torch.empty_like(x)
+out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
```
This change is to make sure the `out` tensor is contiguous.
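For context, a small illustration (not code from the PR) of why the allocation matters: `torch.empty_like` defaults to `memory_format=torch.preserve_format`, so a non-contiguous input produces a non-contiguous output buffer, while `torch.empty(x.shape, ...)` always allocates a contiguous row-major buffer.

```python
import torch

x = torch.randn(4, 8).t()        # transposed view: dense but non-contiguous
print(x.is_contiguous())         # False

a = torch.empty_like(x)          # preserve_format copies the source strides
print(a.is_contiguous())         # False

b = torch.empty(x.shape, dtype=x.dtype, device=x.device)  # fresh row-major buffer
print(b.is_contiguous())         # True
```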
I think you should add this as a comment in the code.
```diff
 variance = x.pow(2).mean(dim=-1, keepdim=True)
 x = x * torch.rsqrt(variance + self.variance_epsilon)
-x = x.to(orig_dtype) * self.weight
+x = x.to(orig_dtype) * (self.weight_bias + self.weight)
```
Randomly stumbled upon this PR, and without a full in-depth review, I just wanted to mention that it is essential that `weight_bias` and `weight` are added in fp32. In certain Gemma models, the normalization layers have weights that produce incorrect results when the weight bias (1.0) is added in bf16.
For example:
```python
>>> import torch
>>> weight = torch.tensor([15.8125], dtype=torch.bfloat16)
>>> weight + 1.0
tensor([16.7500], dtype=torch.bfloat16)
>>> weight.float() + 1.0
tensor([16.8125])
```

Not sure if this is considered here already, but just wanted to flag in case.
Makes sense. I'll make the change. Thanks for the comment!
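A minimal sketch of what the fp32-safe variant of the native path could look like (an illustration only, assuming `x` is still in fp32 at this point and that the layer exposes `weight` and `weight_bias` as in the diff above; not the exact patch):

```python
# Do the bias addition and the scaling in fp32, then cast the product back,
# so bf16 rounding never touches the intermediate (weight_bias + weight) value.
x = x * (self.weight_bias + self.weight.float())
return x.to(orig_dtype)
```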
@xyang16 I was wondering if
Purpose
This PR adds rms_norm CUDA kernel support for Gemma:
- RMSNorm: `y = x / sqrt(mean(x^2) + epsilon) * weight`
- GemmaRMSNorm: `y = x / sqrt(mean(x^2) + epsilon) * (weight + 1.0)`
- The kernel takes a `weight_bias` argument: `weight_bias = 0.0` for RMSNorm, `weight_bias = 1.0` for GemmaRMSNorm.
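For reference, a minimal pure-PyTorch sketch of the unified formulation (an illustration of the math above, not the CUDA kernel; the `weight_bias` name follows the PR's convention):

```python
import torch

def rms_norm_ref(
    x: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
    weight_bias: float = 0.0,  # 0.0 -> RMSNorm, 1.0 -> GemmaRMSNorm
) -> torch.Tensor:
    x_f32 = x.float()
    variance = x_f32.pow(2).mean(dim=-1, keepdim=True)
    x_f32 = x_f32 * torch.rsqrt(variance + eps)
    # Add the bias to the weight in fp32 to avoid bf16 rounding (see review above).
    out = x_f32 * (weight_bias + weight.float())
    return out.to(x.dtype)
```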
Test Plan

Test Result
Unit tests passed.
Microbenchmark
Microbenchmark numbers show more than a 2x improvement, because previously GemmaRMSNorm's `forward_cuda` called `forward_native`.

gemma rms_norm
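As a rough illustration of how such a comparison can be made (a sketch only; the import path, constructor arguments, and method names are assumptions about vLLM's layout and may differ between versions):

```python
import torch
from torch.utils import benchmark

# Assumed import path; adjust to the installed vLLM version.
from vllm.model_executor.layers.layernorm import GemmaRMSNorm

layer = GemmaRMSNorm(hidden_size=4096, eps=1e-6).to("cuda", dtype=torch.bfloat16)
x = torch.randn(32, 4096, device="cuda", dtype=torch.bfloat16)

for label in ("forward_native", "forward_cuda"):
    fn = getattr(layer, label)
    timer = benchmark.Timer(stmt="fn(x)", globals={"fn": fn, "x": x})
    print(label, timer.timeit(100))
```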
Accuracy Testing
Baseline:
PR:
Essential Elements of an Effective PR Description Checklist
`supported_models.md` and `examples` for a new model.