155 changes: 155 additions & 0 deletions lightllm/models/llama/triton_kernel/fused_add_rmsnorm_inplace.py
@@ -0,0 +1,155 @@
import torch
import triton
import triton.language as tl


@triton.jit
def _fwd_fused_add_rmsnorm(
original,
residual,
weight,
original_stride0,
original_stride1,
residual_stride0,
residual_stride1,
N, # number of columns in X
eps,
BLOCK_SIZE: tl.constexpr,
):
block_id = tl.program_id(0)
# data's base address of this block
_original = original + block_id * original_stride0
_residual = residual + block_id * residual_stride0

    # When the whole row fits in one block, load it only once instead of
    # re-reading it from global memory; this performs better for large rows.
    if N <= BLOCK_SIZE:
        # column offsets of this row
        cols = tl.arange(0, BLOCK_SIZE)
        _original_offset = cols * original_stride1
        _residual_offset = cols * residual_stride1
        _weight_offset = cols

        # data pointers of this row
        _original_ptr = _original + _original_offset
        _residual_ptr = _residual + _residual_offset
        _weight_ptr = weight + _weight_offset

        # load data from memory
        mask = cols < N
        original_cache = tl.load(_original_ptr, mask=mask, other=0.0).to(tl.float32)
        residual_cache = tl.load(_residual_ptr, mask=mask, other=0.0).to(tl.float32)
        weight_cache = tl.load(_weight_ptr, mask=mask, other=0.0).to(tl.float32)

        # store (original + residual) to original
        original_cache = original_cache + residual_cache
        tl.store(_original_ptr, original_cache.to(original.dtype.element_ty), mask=mask)

        # compute the mean of squares and the reciprocal RMS
        var = tl.sum(original_cache * original_cache) / N
        rstd = 1 / tl.sqrt(var + eps)
        residual_cache = original_cache * rstd * weight_cache

        # store rmsnorm(original + residual) back to residual
        tl.store(_residual_ptr, residual_cache.to(residual.dtype.element_ty), mask=mask)
    else:
        sum_of_squares = tl.zeros([], dtype=tl.float32)
        # first pass: write (original + residual) and accumulate its sum of squares
        for block_offset in range(0, N, BLOCK_SIZE):
            # column offsets of this tile (named `cols`, not `range`, so the Python builtin stays usable)
            cols = tl.arange(0, BLOCK_SIZE) + block_offset
            _original_offset = cols * original_stride1
            _residual_offset = cols * residual_stride1

            # data pointers of this tile
            _original_ptr = _original + _original_offset
            _residual_ptr = _residual + _residual_offset

            # load data from memory
            mask = cols < N
            original_cache = tl.load(_original_ptr, mask=mask, other=0.0).to(tl.float32)
            residual_cache = tl.load(_residual_ptr, mask=mask, other=0.0).to(tl.float32)

            # store (original + residual) to original
            original_cache = original_cache + residual_cache
            tl.store(_original_ptr, original_cache.to(original.dtype.element_ty), mask=mask)

            # accumulate sum of squares
            sum_of_squares += tl.sum(original_cache * original_cache)

        # compute the mean of squares and the reciprocal RMS
        var = sum_of_squares / N
        rstd = 1 / tl.sqrt(var + eps)

        # second pass: normalize and write the result to residual
        for block_offset in range(0, N, BLOCK_SIZE):
            # column offsets of this tile
            cols = tl.arange(0, BLOCK_SIZE) + block_offset
            _original_offset = cols * original_stride1
            _residual_offset = cols * residual_stride1
            _weight_offset = cols

            # data pointers of this tile
            _original_ptr = _original + _original_offset
            _residual_ptr = _residual + _residual_offset
            _weight_ptr = weight + _weight_offset

            # load data from memory
            mask = cols < N
            original_cache = tl.load(_original_ptr, mask=mask, other=0.0).to(tl.float32)
            weight_cache = tl.load(_weight_ptr, mask=mask, other=0.0).to(tl.float32)

            # apply rmsnorm using the pre-computed rstd
            original_cache = original_cache * rstd * weight_cache

            # store rmsnorm(original + residual) back to residual
            tl.store(_residual_ptr, original_cache.to(residual.dtype.element_ty), mask=mask)


def fused_add_rmsnorm_inplace(
original: torch.Tensor, # [num_tokens, hidden_size]
residual: torch.Tensor,
weight: torch.Tensor,
eps: float,
):
"""
Perform fused add & rmsnorm

suppose the skip connection result is H(x) = F(x) + x,
then F(x) is the residual, x is the original.
Here original will be (residual + original), residual will be rmsnorm(residual + original)
At first Layer, residual should be all zeros.
"""
# reshape input data into 2D tensor
original_arg = original.view(-1, original.shape[-1])
residual_arg = residual.view(-1, residual.shape[-1])

assert original.data_ptr() == original_arg.data_ptr()
assert residual.data_ptr() == residual_arg.data_ptr()

M, N = original_arg.shape
    # Feature dims larger than 64KB per row are not supported.
    MAX_FUSED_SIZE = 65536 // original.element_size()
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))

    if N > BLOCK_SIZE:
        raise RuntimeError("This rmsnorm kernel doesn't support feature dims >= 64KB.")

# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 4)
num_warps = triton.next_power_of_2(num_warps)
    # Cap the block size; rows wider than this take the kernel's multi-pass path.
    if BLOCK_SIZE > 16384:
        BLOCK_SIZE = 16384

# enqueue kernel
_fwd_fused_add_rmsnorm[(M,)](
original_arg,
residual_arg,
weight,
original_arg.stride(0),
original_arg.stride(1),
residual_arg.stride(0),
residual_arg.stride(1),
N, # number of columns in X
eps,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
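For reference, a minimal PyTorch sketch of what the kernel above computes per row (illustrative only; tensor names mirror the wrapper's arguments, and the real kernel does all of this in a single fused launch):

import torch

def add_rmsnorm_reference(original: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float):
    # original <- original + residual, accumulated in fp32 like the kernel
    x = original.float() + residual.float()
    original.copy_(x.to(original.dtype))
    # residual <- rmsnorm(original + residual) * weight
    rstd = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    residual.copy_((x * rstd * weight.float()).to(residual.dtype))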
125 changes: 124 additions & 1 deletion lightllm/utils/custom_kernel_utis.py
@@ -1,7 +1,7 @@
import torch
import triton
import triton.language as tl
from typing import List
from typing import List, Callable


def custom_cat(tensors):
@@ -125,3 +125,126 @@ def pad2dim_tensor_to_new_batch(input: torch.Tensor, new_batch_size: int):
out[0:origin_batch_size, :] = input
out[origin_batch_size:, :] = input[0:1, :]
return out


def error(y_pred: torch.Tensor, y_real: torch.Tensor) -> float:
    """
    Compute the SNR-style relative error between y_pred and y_real.

    Both tensors are flattened and the error is computed over all elements:

        SNR(pred, real) = sum((pred - real) ^ 2) / sum(real ^ 2)

    Args:
        y_pred (torch.Tensor): predicted values (e.g. a custom kernel's output).
        y_real (torch.Tensor): reference values (e.g. a PyTorch reference output).

    Raises:
        ValueError: if the two tensors have different shapes.

    Returns:
        float: the error value; smaller means the outputs match more closely.
    """
y_pred = torch.flatten(y_pred).float()
y_real = torch.flatten(y_real).float()

if y_pred.shape != y_real.shape:
raise ValueError(
f"Can not compute snr loss for tensors with different shape. ({y_pred.shape} and {y_real.shape})"
)

noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1)
signal_power = torch.pow(y_real, 2).sum(dim=-1)
snr = (noise_power) / (signal_power + 1e-7)
return snr.item()


def benchmark(func: Callable, shape: List[int], tflops: float, steps: int, *args, **kwargs):
"""
A decorator function to assist in performance testing of CUDA operations.

This function will:
1. Automatically determine whether any parameters in the argument list,
or the output of the `func`, are of type `torch.Tensor`.
2. If so, calculate the memory usage of the input and output tensors
on the GPU (based on their data type and `torch.numel()`).
3. Establish a CUDA graph and attempt to execute `func` repeatedly for `steps` iterations.
4. Record the execution time during these iterations.
5. Use the information above to compute the compute performance (TFLOPS) and memory throughput.

Args:
func (function): The function to benchmark.
shape (list of int): The problem shape.
tflops (float): The computational workload (in TFLOPS) per call of `func`.
steps (int): The number of times the function is executed during benchmarking.
*args: Positional arguments to be passed to the `func`.
**kwargs: Keyword arguments to be passed to the `func`.

Returns:
function result
"""

# Ensure CUDA is available
if not torch.cuda.is_available():
raise RuntimeError("CUDA is required for benchmarking.")

# Check for torch.Tensor in inputs and outputs
input_tensors = [arg for arg in args if isinstance(arg, torch.Tensor)]
input_tensors += [value for value in kwargs.values() if isinstance(value, torch.Tensor)]

def calculate_memory(tensor: torch.Tensor):
"""Calculate memory usage in bytes for a tensor."""
return tensor.numel() * tensor.element_size()

input_memory = sum(calculate_memory(t) for t in input_tensors)

# Execute the function to inspect outputs
with torch.no_grad():
output = func(*args, **kwargs)

output_memory = 0
if isinstance(output, torch.Tensor):
output_memory = calculate_memory(output)
elif isinstance(output, (list, tuple)):
output_memory = sum(calculate_memory(o) for o in output if isinstance(o, torch.Tensor))

total_memory = input_memory + output_memory

    # Warm-up runs, so one-time compilation and caching costs are excluded from the timing
for _ in range(10): # Warm-up
func(*args, **kwargs)

torch.cuda.synchronize() # Ensure no pending operations

# Benchmark the function
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
for _ in range(steps):
func(*args, **kwargs)
end_event.record()

torch.cuda.synchronize() # Ensure all operations are finished
elapsed_time_ms = start_event.elapsed_time(end_event) # Time in milliseconds

# Calculate performance metrics
elapsed_time_s = elapsed_time_ms / 1000 # Convert to seconds
avg_time_per_step = elapsed_time_s / steps
compute_performance = tflops / avg_time_per_step # TFLOPS
memory_throughput = (total_memory * steps / (1024 ** 3)) / elapsed_time_s # GB/s

# Print performance metrics
print(f"Function: {func.__name__}{shape}")
# print(f"Function: {func.__ne__}{shape}")
print(f"Elapsed Time (total): {elapsed_time_s:.4f} seconds")
print(f"Average Time Per Step: {avg_time_per_step * 1000:.3f} ms")
print(f"Compute Performance: {compute_performance:.2f} TFLOPS")
print(f"Memory Throughput: {memory_throughput:.2f} GB/s")
print("") # print a blank line.
63 changes: 63 additions & 0 deletions unit_tests/models/llama/test_fused_add_rmsnorm.py
@@ -0,0 +1,63 @@
import unittest
import torch
from lightllm.models.llama.triton_kernel.fused_add_rmsnorm_inplace import fused_add_rmsnorm_inplace
from lightllm.utils.custom_kernel_utis import benchmark, error


class TestFusedAddRmsNormInplace(unittest.TestCase):
def setUp(self):
"""Set up common test parameters."""
self.tokens = [1, 2, 3, 1024, 2048, 4096, 8192, 16384]
self.dims = [1, 2, 3, 512, 1024, 1025, 3200, 16384, 32768] # [512, 1024, 1032, 1536, 3200, 6144, 12800]
self.device = "cuda"
self.dtype = torch.bfloat16

def torch_add_rmsnorm(self, X, R, W):
X.add_(R)
return torch.nn.functional.rms_norm(X, (X.shape[1],), W, eps=1e-6)

def test_accuracy(self):
"""Test the accuracy of fused_add_rmsnorm_inplace against torch.rmsnorm."""
for token_num in self.tokens:
for dim in self.dims:
with self.subTest(shape=[token_num, dim]):
X = torch.randn(size=[token_num, dim], device=self.device, dtype=self.dtype)
_X = X.clone()
R = torch.randn(size=[token_num, dim], device=self.device, dtype=self.dtype)
_R = R.clone()
W = torch.randn(size=[dim], device=self.device, dtype=self.dtype)

r_real = self.torch_add_rmsnorm(_X, _R, W)
fused_add_rmsnorm_inplace(X, R, W, eps=1e-6)
r_pred = R
self.assertTrue(
error(r_pred, r_real) < 0.01,
f"Accuracy test failed for size {token_num}, {dim}. r_real={r_real}, r_pred={r_pred}",
)
print(f"{error(r_pred, r_real) = }")

x_real = _X
x_pred = X
self.assertTrue(
error(x_pred, x_real) < 0.01,
f"Accuracy test failed for size {token_num}, {dim}. x_real={x_real}, x_pred={x_pred}",
)
print(f"{error(x_pred, x_real) = }")

def test_performance(self):
"""Test the performance of rmsnorm using benchmark."""
for token_num in self.tokens:
for dim in self.dims:
with self.subTest(shape=[token_num, dim]):
X = torch.randn(size=[token_num, dim], device=self.device, dtype=self.dtype)
R = torch.randn(size=[token_num, dim], device=self.device, dtype=self.dtype)
W = torch.randn(size=[dim], device=self.device, dtype=self.dtype)

shape = [token_num, dim]
tflops = 0.0
benchmark(self.torch_add_rmsnorm, shape, tflops, 100, X, R, W)
benchmark(fused_add_rmsnorm_inplace, shape, tflops, 100, X, R, W, eps=1e-6)


if __name__ == "__main__":
unittest.main()
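The tests assume a CUDA device. A sketch of how this module could be run from the repository root (so the `lightllm` and `unit_tests` imports resolve; adjust paths if your layout differs):

# Run the file directly, which falls through to unittest.main():
#   python unit_tests/models/llama/test_fused_add_rmsnorm.py
# or via unittest discovery:
#   python -m unittest discover -s unit_tests/models/llama -p "test_fused_add_rmsnorm.py"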