
Conversation

@xyang16 (Contributor) commented Dec 1, 2025

Purpose

This PR adds rms_norm CUDA kernel support for Gemma.

  • Add a weight_bias param to rms_norm_kernel
    • RMSNorm computation: y = x / sqrt(mean(x^2) + epsilon) * weight
    • Gemma-style RMSNorm computation: y = x / sqrt(mean(x^2) + epsilon) * (weight + 1.0)
    • The new weight_bias param selects between the two: weight_bias = 0.0 for RMSNorm, weight_bias = 1.0 for GemmaRMSNorm
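A minimal PyTorch reference sketch of this unified formulation (illustrative only, not the CUDA kernel itself; the function name is made up for this description):

import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor,
                 eps: float, weight_bias: float = 0.0) -> torch.Tensor:
    # Normalize in fp32 for numerical stability, as the native path does.
    x32 = x.float()
    variance = x32.pow(2).mean(dim=-1, keepdim=True)
    x32 = x32 * torch.rsqrt(variance + eps)
    # weight_bias = 0.0 reproduces standard RMSNorm; 1.0 reproduces Gemma-style.
    return (x32 * (weight.float() + weight_bias)).to(x.dtype)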

Test Plan

  • Added test cases in test_layernorm.py
pytest -s -v tests/kernels/core/test_layernorm.py
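A sketch of the kind of test added (class name, constructor, and method names assumed from vLLM's layernorm module rather than copied from the PR): the CUDA path should match the eager native reference within dtype tolerance.

import pytest
import torch

from vllm.model_executor.layers.layernorm import GemmaRMSNorm  # assumed import path

@pytest.mark.parametrize("num_tokens,hidden_size", [(7, 768), (83, 4096)])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
def test_gemma_rms_norm(num_tokens, hidden_size, dtype):
    torch.manual_seed(0)
    layer = GemmaRMSNorm(hidden_size, eps=1e-6).to(device="cuda", dtype=dtype)
    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device="cuda")
    ref = layer.forward_native(x.clone())   # eager PyTorch reference
    out = layer.forward_cuda(x.clone())     # custom CUDA kernel path
    torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)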

Test Result

Unit tests passed.

Microbenchmark

  • The microbenchmark numbers show a more than 2x improvement because, prior to this PR, GemmaRMSNorm's forward_cuda fell back to forward_native (see the sketch after the tables below).

  • gemma rms_norm

python3 benchmarks/kernels/benchmark_rmsnorm.py --is-gemma
    head_num  batch_size  seq_len   HuggingFace    FlashInfer          vLLM
0       32.0         1.0     64.0     44.032000     13.312000     13.312000
1       32.0         1.0    128.0     59.392001     19.455999     18.432001
2       32.0         1.0    256.0     88.064000     34.784000     32.768000
3       32.0         1.0    512.0    152.575999     58.368001     58.368001
4       32.0         1.0   1024.0    280.575991    107.519999    109.568000
5       32.0         4.0     64.0     89.088000     33.792000     32.768000
6       32.0         4.0    128.0    152.575999     58.368001     58.368001
7       32.0         4.0    256.0    280.575991    108.543999    109.600000
8       32.0         4.0    512.0    903.168023    215.039998    216.064006
9       32.0         4.0   1024.0   2779.647946    497.664005    492.543995
10      32.0        16.0     64.0    280.575991    108.543999    109.568000
11      32.0        16.0    128.0    903.728008    215.039998    215.552002
12      32.0        16.0    256.0   2776.575923    498.688012    489.488006
13      32.0        16.0    512.0   5749.759912   1157.119989   1157.119989
14      32.0        16.0   1024.0  11470.415592   2311.167955   2309.119940
15      32.0        64.0     64.0   2776.639938    499.711990    493.568003
16      32.0        64.0    128.0   5747.712135   1160.192013   1159.168005
17      32.0        64.0    256.0  11487.744331   2308.095932   2309.119940
18      32.0        64.0    512.0  22841.343880   4592.639923   4606.976032
19      32.0        64.0   1024.0  45801.473618   9177.087784   9201.663971
20      48.0         1.0     64.0     55.296000     16.384000     16.224001
21      48.0         1.0    128.0     74.752003     27.648000     27.648000
22      48.0         1.0    256.0    122.879997     45.056000     45.056000
23      48.0         1.0    512.0    220.159993     86.015999     87.040000
24      48.0         1.0   1024.0    476.159990    162.816003    165.984005
25      48.0         4.0     64.0    122.879997     46.080001     45.056000
26      48.0         4.0    128.0    220.159993     84.991999     86.015999
27      48.0         4.0    256.0    476.159990    162.816003    165.887997
28      48.0         4.0    512.0   1896.960020    335.871994    334.847987
29      48.0         4.0   1024.0   4300.271988    871.424019    871.424019
30      48.0        16.0     64.0    476.159990    162.816003    165.887997
31      48.0        16.0    128.0   1895.936012    336.896002    333.824009
32      48.0        16.0    256.0   4297.264099    870.400012    870.400012
33      48.0        16.0    512.0   8615.935802   1736.703992   1737.280011
34      48.0        16.0   1024.0  17172.479630   3447.808027   3446.784019
35      48.0        64.0     64.0   4301.311970    871.424019    871.424019
36      48.0        64.0    128.0   8621.567726   1738.240004   1735.679984
37      48.0        64.0    256.0  17164.287567   3449.344039   3448.320031
38      48.0        64.0    512.0  34331.136703   6900.223970   6909.952164
39      48.0        64.0   1024.0  68999.038696  13754.367828  13806.080341
  • gemma fused_add_rms_norm
python3 benchmarks/kernels/benchmark_rmsnorm.py --is-gemma --use-residual
    head_num  batch_size  seq_len    HuggingFace    FlashInfer          vLLM
0       32.0         1.0     64.0      60.416002     19.455999     20.479999
1       32.0         1.0    128.0      89.088000     28.672000     28.704001
2       32.0         1.0    256.0     138.239995     53.247999     53.247999
3       32.0         1.0    512.0     273.407996     93.184002     94.208002
4       32.0         1.0   1024.0     557.056010    174.079999    174.079999
5       32.0         4.0     64.0     138.239995     55.296000     53.247999
6       32.0         4.0    128.0     270.336002     92.160001     92.160001
7       32.0         4.0    256.0     559.104025    173.056006    174.079999
8       32.0         4.0    512.0    1875.455976    360.448003    360.448003
9       32.0         4.0   1024.0    4798.463821   1131.520033   1128.448009
10      32.0        16.0     64.0     559.104025    173.056006    173.056006
11      32.0        16.0    128.0    1890.303969    360.448003    359.423995
12      32.0        16.0    256.0    4795.392036   1132.544041   1126.399994
13      32.0        16.0    512.0    9782.272339   2315.263987   2303.983927
14      32.0        16.0   1024.0   19501.055717   4631.552219   4587.520123
15      32.0        64.0     64.0    4793.344021   1133.056045   1128.960013
16      32.0        64.0    128.0    9764.863968   2314.239979   2306.047916
17      32.0        64.0    256.0   19505.151749   4628.479958   4584.447861
18      32.0        64.0    512.0   39071.231842   9263.104439   9165.823936
19      32.0        64.0   1024.0   77983.741760  18477.056503  18299.903870
20      48.0         1.0     64.0      77.823997     24.576001     25.599999
21      48.0         1.0    128.0     114.688002     41.983999     40.959999
22      48.0         1.0    256.0     210.943997     72.768003     73.728003
23      48.0         1.0    512.0     395.264000    133.120000    135.168001
24      48.0         1.0   1024.0    1080.384016    256.000012    256.000012
25      48.0         4.0     64.0     210.943997     72.704002     73.728003
26      48.0         4.0    128.0     394.239992    133.120000    134.143993
27      48.0         4.0    256.0    1081.344008    254.976004    256.000012
28      48.0         4.0    512.0    3387.904048    712.704003    718.847990
29      48.0         4.0   1024.0    7306.239843   1739.776015   1735.167980
30      48.0        16.0     64.0    1080.320001    256.000012    254.976004
31      48.0        16.0    128.0    3390.464067    712.704003    720.896006
32      48.0        16.0    256.0    7316.479921   1737.215996   1737.215996
33      48.0        16.0    512.0   14653.952122   3483.648062   3464.671969
34      48.0        16.0   1024.0   29324.287415   6965.247869   6934.528112
35      48.0        64.0     64.0    7311.360121   1738.752007   1741.824031
36      48.0        64.0    128.0   14651.904106   3487.232089   3457.536101
37      48.0        64.0    256.0   29300.735474   6975.488186   6922.239780
38      48.0        64.0    512.0   58565.631866  13866.496086  13803.519726
39      48.0        64.0   1024.0  116948.829651  27778.047562  27550.720215
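To illustrate the comparison above: before this PR the Gemma layer had no CUDA kernel of its own, so its CUDA path reused the eager PyTorch reference, which is where the >2x gap comes from. A heavily hedged sketch of the pre-PR behavior (structure assumed for illustration, not taken from the PR diff):

from vllm.model_executor.layers.layernorm import RMSNorm  # assumed import path

class GemmaRMSNormBefore(RMSNorm):
    # Pre-PR behavior (illustrative): no dedicated kernel, so the CUDA path
    # falls back to the unfused eager implementation.
    def forward_cuda(self, x, residual=None):
        return self.forward_native(x, residual)

With weight_bias available in the kernel, forward_cuda can instead dispatch to the same fused rms_norm / fused_add_rms_norm custom ops used by the standard RMSNorm layer, passing weight_bias = 1.0.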

Accuracy Testing

vllm serve unsloth/gemma-3-27b-it \
    --tensor-parallel-size 8 \
    --pipeline-parallel-size 1 \
    --max-num-seqs 32 \
    --compilation-config '{"custom_ops":["+gemma_rms_norm"]}'
python3 -m lm_eval --model local-completions \
  --model_args model=unsloth/gemma-3-27b-it,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=32 \
  --tasks gsm8k

Baseline:

local-completions (model=unsloth/gemma-3-27b-it,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=32), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8560|±  |0.0097|
|     |       |strict-match    |     5|exact_match|↑  |0.8484|±  |0.0099|

PR:

local-completions (model=unsloth/gemma-3-27b-it,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=32), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8560|±  |0.0097|
|     |       |strict-match    |     5|exact_match|↑  |0.8484|±  |0.0099|



@mergify bot added the performance label (Performance-related issues) on Dec 1, 2025
@gemini-code-assist bot left a comment

Code Review

This pull request adds support for Gemma's RMSNorm by introducing a weight_bias parameter to the existing RMSNorm kernels. This is a clean approach to extending the functionality. The changes are well integrated across the CUDA kernels, Python bindings, and layer implementations, and are accompanied by appropriate benchmarks and tests. The performance improvement for Gemma is substantial, as it now leverages the custom CUDA kernel instead of the native PyTorch implementation. I have identified one potential performance regression in the vectorized fused_add_rms_norm kernel for the standard RMSNorm case and have provided a suggestion to mitigate it.

 if vllm_is_batch_invariant():
     return rms_norm_batch_invariant(x, weight, variance_epsilon)
-out = torch.empty_like(x)
+out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
@xyang16 (Contributor, Author) replied:
This change is to make sure the out tensor is contiguous.
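For context, a small illustrative check (not part of the PR): torch.empty_like preserves the input's strides, so a non-contiguous input would yield a non-contiguous output buffer, whereas torch.empty(x.shape, ...) always allocates a contiguous one.

import torch

x = torch.randn(8, 16).t()  # non-contiguous (transposed) view
print(torch.empty_like(x).is_contiguous())   # False: strides are preserved
print(torch.empty(x.shape, dtype=x.dtype, device=x.device).is_contiguous())  # True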

A reviewer (Contributor) replied:
I think you should add this as a code comment.

@xyang16 changed the title from "[Kernle] Support rms_norm kernel for Gemma" to "[Kernel] Support rms_norm kernel for Gemma" on Dec 1, 2025
 variance = x.pow(2).mean(dim=-1, keepdim=True)
 x = x * torch.rsqrt(variance + self.variance_epsilon)
-x = x.to(orig_dtype) * self.weight
+x = x.to(orig_dtype) * (self.weight_bias + self.weight)
A reviewer commented:

I randomly stumbled upon this PR, and without a full in-depth review I just wanted to mention that it is essential that weight_bias and weight are added in fp32. In certain Gemma models, the normalization layers have weights that, when the weight bias (1.0) is added in bfloat16, produce incorrectly rounded results.

For example:

>>> import torch
>>> weight = torch.tensor([15.8125], dtype=torch.bfloat16)
>>> weight + 1.0
tensor([16.7500], dtype=torch.bfloat16)
>>> weight.float() + 1.0
tensor([16.8125])

Not sure if this is considered here already, but just wanted to flag in case.

https://github.com/huggingface/transformers/blob/0fa49db1205e0a2745161ccac46184e7e46b6e2b/src/transformers/models/gemma/modular_gemma.py#L204
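Applied to the forward_native diff above, a hedged sketch of the fp32-safe scaling step (modeled on the linked transformers code, not on the PR's final diff; the helper name is made up):

import torch

def gemma_scale(x_fp32: torch.Tensor, weight: torch.Tensor,
                weight_bias: float, orig_dtype: torch.dtype) -> torch.Tensor:
    # x_fp32 is the normalized activation, still in float32 at this point.
    # Adding weight_bias to the weight in float32 avoids the bf16 rounding
    # shown above; only the final product is cast back to orig_dtype.
    return (x_fp32 * (weight_bias + weight.float())).to(orig_dtype)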

@xyang16 (Contributor, Author) replied:
Makes sense. I'll make the change. Thanks for the comment!

@ZJY0516 (Contributor) commented Dec 2, 2025

@xyang16 I was wondering if torch.compile has the capability to produce efficient kernels?

