Summary
When using kernelize on a Llama-based model, only the LigerRMSNorm kernel causes numerical differences. The SiLU kernel produces identical outputs.
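The reproduction further down swaps kernels layer by layer to isolate each one; kernelizing the whole model at once would look roughly like the sketch below (assuming the top-level `kernelize` entry point of the `kernels` package and that the Llama layers have registered kernel mappings):

```python
import torch
from transformers import AutoModelForCausalLM
from kernels import kernelize, Mode

# Sketch only: replace every supported layer (RMSNorm, SiLU, ...) with its
# Hub kernel in one call. The per-layer reproduction below does the same
# thing manually so that each kernel can be tested in isolation.
model = AutoModelForCausalLM.from_pretrained(
    "sign/utf8-lm-tiny", torch_dtype=torch.bfloat16
).to("cuda").eval()
model = kernelize(model, mode=Mode.INFERENCE)
```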
Results
| Config | Time | Speedup | Max Logit Diff | Argmax Match |
|---|---|---|---|---|
| Baseline (no kernelize) | 107.83ms | 1.00x | 0 | ✓ |
| Only RMSNorm | 97.69ms | 1.10x | 0.125 | ✓ |
| Only SiLU | 104.64ms | 1.03x | 0 | ✓ |
| Both | 94.40ms | 1.14x | 0.125 | ✓ |
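The timing harness behind these numbers is not included in this report; a rough sketch of how the Time column could have been measured (the `time_forward` helper below is hypothetical):

```python
import time
import torch

def time_forward(model, inputs, warmup=3, iters=20):
    """Hypothetical helper: mean wall-clock time of one forward pass, in ms."""
    with torch.no_grad():
        for _ in range(warmup):      # warm up CUDA kernels and caches
            model(**inputs)
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            model(**inputs)
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters * 1000.0
```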
Key Findings
- LigerRMSNorm (from `kernels-community/liger_kernels`):
  - ⚡ 10% speedup
  - ⚠️ Max logit diff: 0.125
  - ✓ Argmax still matches (generation unaffected)
- SiLU (from `kernels-community/activation`):
  - ⚡ 3% speedup
  - ✅ Exact match (0 diff)
Environment
- Model: `sign/utf8-lm-tiny` (Llama-based, ~70M params)
- dtype: `torch.bfloat16`
- GPU: NVIDIA GB10 (CUDA 12.1)
Minimal Reproduction
```python
from transformers import AutoModelForCausalLM
import torch
from utf8_tokenizer import UTF8Tokenizer
from kernels import Mode
from kernels.layer.layer import kernelize_layer
from kernels.layer.device import Device
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers.activations import SiLUActivation

model_id = "sign/utf8-lm-tiny"
device = "cuda"
dtype = torch.bfloat16

tokenizer = UTF8Tokenizer()
prompt = "Hello world! " * 9  # ~118 tokens
inputs = tokenizer([prompt], return_tensors="pt", padding=True, add_special_tokens=True)
inputs["input_ids"] = inputs["input_ids"].to(torch.long)[:, :-1].to(device)
inputs["attention_mask"] = inputs["attention_mask"][:, :-1].to(device)

# Baseline
model_base = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype).to(device).eval()
with torch.no_grad():
    logits_base = model_base(**inputs).logits

# Only RMSNorm kernelized
model_rms = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype).to(device)
for name, module in model_rms.named_modules():
    if isinstance(module, LlamaRMSNorm):
        kernelize_layer(module, mode=Mode.INFERENCE, device_type=Device(type="cuda"), use_fallback=True)
model_rms.eval()
with torch.no_grad():
    logits_rms = model_rms(**inputs).logits

# Only SiLU kernelized
model_silu = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype).to(device)
for name, module in model_silu.named_modules():
    if isinstance(module, SiLUActivation):
        kernelize_layer(module, mode=Mode.INFERENCE, device_type=Device(type="cuda"), use_fallback=True)
model_silu.eval()
with torch.no_grad():
    logits_silu = model_silu(**inputs).logits

print(f"RMSNorm diff: {(logits_base - logits_rms).abs().max().item()}")  # 0.125
print(f"SiLU diff: {(logits_base - logits_silu).abs().max().item()}")    # 0.0
```
Analysis
The LigerRMSNorm kernel uses a different numerical implementation than the original LlamaRMSNorm:
- It likely uses a fused kernel with a different reduction order
- At `bfloat16` precision, this causes small differences that accumulate through the 9 norm layers (see the toy example below)
The differences are small enough that argmax predictions are unaffected, so generation results remain identical.
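To see why reduction order alone is enough to change bfloat16 results, here is a small toy example (unrelated to the actual Liger implementation) that sums the same values sequentially versus in chunks:

```python
import torch

# At bfloat16 precision, the order in which values are reduced changes the
# rounding at each step, so two mathematically equivalent reductions can
# disagree slightly.
torch.manual_seed(0)
x = torch.randn(1024).to(torch.bfloat16)

# Left-to-right accumulation, rounded to bf16 after every add.
seq = torch.tensor(0.0, dtype=torch.bfloat16)
for v in x:
    seq = seq + v

# Chunked ("tree-like") accumulation: fp32 partial sums, combined in bf16.
partials = [chunk.float().sum().to(torch.bfloat16) for chunk in x.split(32)]
tree = torch.tensor(0.0, dtype=torch.bfloat16)
for p in partials:
    tree = tree + p

print(seq.item(), tree.item())
print("abs diff:", (seq.float() - tree.float()).abs().item())  # typically nonzero
```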
Recommendation
- ✅ SiLU kernel is bit-exact - safe to use, 3% speedup
- ❌ RMSNorm kernel causes differences - avoid if exact reproducibility is required (a SiLU-only setup is sketched below)
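If exact reproducibility matters, one option is to kernelize only the activations and leave RMSNorm in eager mode; a sketch reusing the per-layer pattern from the reproduction above:

```python
import torch
from transformers import AutoModelForCausalLM
from transformers.activations import SiLUActivation
from kernels import Mode
from kernels.layer.layer import kernelize_layer
from kernels.layer.device import Device

# Swap in the Hub SiLU kernel (bit-exact in the measurements above) and
# deliberately skip LlamaRMSNorm so logits stay identical to the baseline.
model = AutoModelForCausalLM.from_pretrained(
    "sign/utf8-lm-tiny", torch_dtype=torch.bfloat16
).to("cuda")
for _, module in model.named_modules():
    if isinstance(module, SiLUActivation):
        kernelize_layer(module, mode=Mode.INFERENCE,
                        device_type=Device(type="cuda"), use_fallback=True)
model.eval()
```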