14 changes: 14 additions & 0 deletions benchmarks/run.py
@@ -131,6 +131,11 @@ class RunResult:
"examples.kl_div",
"kl_div_tritonbench",
),
"kl_div-bwd": (
"tritonbench.operators.kl_div.operator",
"examples.kl_div",
"kl_div_tritonbench",
),
"ragged_attention": (
"tritonbench.operators.ragged_attention.operator",
"examples.jagged_hstu_attn",
@@ -410,6 +415,15 @@ class RunResult:
"helion_kl_div_tritonbench-speedup": "helion_speedup",
"helion_kl_div_tritonbench-accuracy": "helion_accuracy",
},
"kl_div-bwd": {
"torch_kl_div": "baseline",
"liger_kl_div-speedup": "triton_speedup",
"liger_kl_div-accuracy": "triton_accuracy",
"torch_compile_kl_div-speedup": "torch_compile_speedup",
"torch_compile_kl_div-accuracy": "torch_compile_accuracy",
"helion_kl_div_tritonbench-speedup": "helion_speedup",
"helion_kl_div_tritonbench-accuracy": "helion_accuracy",
},
"gather_gemv": {
"eager_gather_gemv": "baseline",
"triton_gather_gemv-speedup": "triton_speedup",
149 changes: 131 additions & 18 deletions examples/kl_div.py
@@ -23,7 +23,9 @@
# -------
from __future__ import annotations

import math
from typing import TYPE_CHECKING
from typing import Any

import torch
from torch import Tensor
@@ -117,6 +119,106 @@ def kl_div_forward(
return final_loss


@helion.kernel
def kl_div_backward(
grad_out: Tensor,
y_pred: Tensor, # input predictions in log-space, shape (BT, V)
y_true: Tensor, # target values, shape (BT, V)
log_target: hl.constexpr = False, # type: ignore[arg-type]
reduction: hl.constexpr = "batchmean", # type: ignore[arg-type]
eps: hl.constexpr = 1e-10, # type: ignore[arg-type]
compute_y_true_grad: hl.constexpr = True, # type: ignore[arg-type]
) -> tuple[Tensor, Tensor | None]:
BT, V = y_pred.shape
assert y_true.shape == y_pred.shape, (
f"Shape mismatch: {y_true.shape} != {y_pred.shape}"
)

grad_y_pred = torch.empty_like(y_pred)
if compute_y_true_grad:
grad_y_true = torch.empty_like(y_true)
else:
grad_y_true = None

if reduction == "none":
grad_out_expanded = grad_out
else:
grad_out_expanded = grad_out.expand(y_true.shape)

log_eps = math.log(eps)
for tile_bt in hl.tile(BT):
for tile_v in hl.tile(V):
grad_out_val = grad_out_expanded[tile_bt, tile_v]
y_true_val = y_true[tile_bt, tile_v]

if log_target:
y_true_exp = torch.exp(y_true_val)

if reduction == "batchmean":
div = BT
elif reduction == "mean":
div = BT * V
else: # reduction == "sum" or "none"
div = 1.0

if log_target:
grad_y_pred[tile_bt, tile_v] = -grad_out_val * y_true_exp / div # type: ignore[possibly-undefined]
else:
grad_y_pred[tile_bt, tile_v] = -grad_out_val * y_true_val / div

if compute_y_true_grad:
y_pred_val = y_pred[tile_bt, tile_v]
if log_target:
tmp = y_true_exp * (y_true_val - y_pred_val + 1) # type: ignore[possibly-undefined]
else:
lt_eps = log_eps - y_pred_val
gt_eps = torch.log(y_true_val) - y_pred_val + 1
tmp = torch.where(y_true_val < eps, lt_eps, gt_eps)

grad_y_true[tile_bt, tile_v] = grad_out_val * tmp / div # type: ignore[index]

return grad_y_pred, grad_y_true
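
# %%
# Reference Gradients (plain PyTorch)
# -----------------------------------
# Illustrative sketch, not part of the kernel API: the same gradients that
# kl_div_backward computes, written with eager PyTorch ops so a single case can
# be spot-checked against autograd. The helper name is arbitrary; the reduction
# scaling mirrors the kernel.
def _kl_div_backward_reference(
    grad_out: Tensor,
    y_pred: Tensor,
    y_true: Tensor,
    log_target: bool = False,
    reduction: str = "batchmean",
    eps: float = 1e-10,
) -> tuple[Tensor, Tensor]:
    BT, V = y_pred.shape
    if reduction == "batchmean":
        div = BT
    elif reduction == "mean":
        div = BT * V
    else:  # "sum" or "none"
        div = 1.0
    grad = grad_out if reduction == "none" else grad_out.expand(y_true.shape)
    if log_target:
        y_true_exp = torch.exp(y_true)
        grad_y_pred = -grad * y_true_exp / div
        grad_y_true = grad * y_true_exp * (y_true - y_pred + 1) / div
    else:
        grad_y_pred = -grad * y_true / div
        # Clamp small targets to log(eps) for numerical stability, as in the kernel.
        tmp = torch.where(
            y_true < eps, math.log(eps) - y_pred, torch.log(y_true) - y_pred + 1
        )
        grad_y_true = grad * tmp / div
    return grad_y_pred, grad_y_true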


class KLDivFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx: Any, # noqa: ANN401
y_pred: Tensor, # input predictions in log-space, shape (BT, V)
y_true: Tensor, # target values, shape (BT, V)
log_target: bool,
reduction: str,
eps: float,
) -> Tensor:
"""Forward pass for KL divergence."""
loss = kl_div_forward(y_pred, y_true, log_target, reduction, eps)
ctx.save_for_backward(y_pred, y_true) # type: ignore[arg-type]
ctx.log_target = log_target
ctx.reduction = reduction
ctx.eps = eps
return loss

@staticmethod
def backward( # type: ignore[override]
ctx: Any, # noqa: ANN401
grad_out: Tensor,
) -> tuple[Tensor, Tensor | None, None, None, None]:
"""Backward pass for KL divergence."""
y_pred, y_true = ctx.saved_tensors # type: ignore[attr-defined]

grad_y_pred, grad_y_true = kl_div_backward(
grad_out,
y_pred,
y_true,
ctx.log_target,
ctx.reduction,
ctx.eps,
y_true.requires_grad,
)

return grad_y_pred, grad_y_true, None, None, None
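
# %%
# Autograd Wrapper Usage
# ----------------------
# Illustrative sketch (assumes a CUDA device); the helper name and shapes are
# arbitrary. KLDivFunction.apply runs the Helion forward kernel and wires up
# kl_div_backward for the subsequent .backward() call.
def _demo_autograd_wrapper() -> None:
    y_pred = torch.randn(32, 128, device="cuda").log_softmax(dim=-1).requires_grad_()
    y_true = torch.randn(32, 128, device="cuda").softmax(dim=-1)
    loss = KLDivFunction.apply(y_pred, y_true, False, "batchmean", 1e-10)
    loss.backward()  # populates y_pred.grad via kl_div_backward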


# %%
# KL Divergence Loss Module
# -------------------------
@@ -154,7 +256,7 @@ def forward(self, input_tensor: Tensor, target_tensor: Tensor) -> Tensor:
Returns:
KL divergence loss
"""
return kl_div_forward(
return KLDivFunction.apply( # type: ignore[no-any-return]
input_tensor, target_tensor, self.log_target, self.reduction, self.eps
)

@@ -181,16 +283,26 @@ def check_kl_div_kernel(
log_target: Whether target is in log-space
eps: Small value for numerical stability
"""
# Create test tensors following tritonbench pattern
input_tensor = torch.randn(B * T, V, requires_grad=True, device="cuda").log_softmax(
dim=-1
)

target_tensor = torch.randn(B * T, V, device="cuda").softmax(dim=-1)

# Test forward pass
# Create test tensors following tritonbench pattern
def create_inputs() -> tuple[Tensor, Tensor]:
input_tensor = torch.randn(
B * T, V, requires_grad=True, device="cuda"
).log_softmax(dim=-1)
input_tensor.retain_grad()

target_tensor = torch.randn(B * T, V, requires_grad=True, device="cuda")
if log_target:
target_tensor = target_tensor.log_softmax(dim=-1)
else:
target_tensor = target_tensor.softmax(dim=-1)
target_tensor.retain_grad()

return input_tensor, target_tensor

# Test forward + backward pass
helion_kl = HelionKLDivLoss(reduction=reduction, log_target=log_target, eps=eps)
torch_kl_div = torch.nn.KLDivLoss(reduction="batchmean", log_target=log_target).to(
torch_kl_div = torch.nn.KLDivLoss(reduction=reduction, log_target=log_target).to(
"cuda"
)

@@ -200,7 +312,8 @@ def helion_wrapper(input_tensor: Tensor, target_tensor: Tensor) -> Tensor:
def baseline_wrapper(input_tensor: Tensor, target_tensor: Tensor) -> Tensor:
return torch_kl_div(input_tensor, target_tensor)

run_example(helion_wrapper, baseline_wrapper, (input_tensor, target_tensor))
run_example(helion_wrapper, baseline_wrapper, create_inputs())
run_example(helion_wrapper, baseline_wrapper, create_inputs(), bwd=True)


# %%
@@ -240,17 +353,17 @@ def main() -> None:
print("Testing KL divergence kernel...")
B = 8
T = 512
reduction = "batchmean"
log_target = False
eps = 1e-10

# Test with vocabulary sizes from tritonbench (2^12 to 2^17)
for V in [2**i for i in range(12, 18)]:
print(
f"Testing KL Div: B={B}, T={T}, V={V}, reduction={reduction}, log_target={log_target}"
)
check_kl_div_kernel(B, T, V, reduction, log_target, eps)
print("✓ KL Div passed")
for log_target in (True, False):
for reduction in ("batchmean", "mean", "sum"):
for V in [2**i for i in range(12, 17)]:
print(
f"Testing KL Div: B={B}, T={T}, V={V}, reduction={reduction}, log_target={log_target}"
)
check_kl_div_kernel(B, T, V, reduction, log_target, eps)
print("✓ KL Div passed")


# %%
58 changes: 58 additions & 0 deletions test/test_examples.expected
@@ -2352,6 +2352,64 @@ def kl_div_forward(y_pred: Tensor, y_true: Tensor, log_target: bool=False, reduc
final_loss = loss
return final_loss

--- assertExpectedJournal(TestExamples.test_kl_div_bwd)
from __future__ import annotations

import torch
import helion.language as hl
import triton
import triton.language as tl
from torch._inductor.runtime.triton_helpers import math as tl_math
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_kl_div_backward(grad_out_expanded, y_true, grad_y_pred, y_pred, grad_y_true, grad_out_expanded_stride_0, grad_out_expanded_stride_1, grad_y_pred_stride_0, grad_y_pred_stride_1, grad_y_true_stride_0, grad_y_true_stride_1, y_pred_stride_0, y_pred_stride_1, y_true_stride_0, y_true_stride_1, BT, V, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < BT
for offset_1 in tl.range(0, V.to(tl.int32), _BLOCK_SIZE_1):
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
mask_1 = indices_1 < V
grad_out_val = tl.load(grad_out_expanded + (indices_0[:, None] * grad_out_expanded_stride_0 + indices_1[None, :] * grad_out_expanded_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
y_true_val = tl.load(y_true + (indices_0[:, None] * y_true_stride_0 + indices_1[None, :] * y_true_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
v_0 = -grad_out_val
v_1 = v_0 * y_true_val
v_2 = tl.cast(BT, tl.float32)
v_3 = v_1 / v_2
tl.store(grad_y_pred + (indices_0[:, None] * grad_y_pred_stride_0 + indices_1[None, :] * grad_y_pred_stride_1), v_3, mask_0[:, None] & mask_1[None, :])
y_pred_val = tl.load(y_pred + (indices_0[:, None] * y_pred_stride_0 + indices_1[None, :] * y_pred_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
v_4 = -23.025850929940457
v_5 = v_4 - y_pred_val
v_6 = tl_math.log(y_true_val)
v_7 = v_6 - y_pred_val
v_8 = 1.0
v_9 = v_7 + v_8
v_10 = 1e-10
v_11 = y_true_val < v_10
v_12 = tl.where(v_11, v_5, v_9)
v_13 = grad_out_val * v_12
v_14 = tl.cast(BT, tl.float32)
v_15 = v_13 / v_14
tl.store(grad_y_true + (indices_0[:, None] * grad_y_true_stride_0 + indices_1[None, :] * grad_y_true_stride_1), v_15, mask_0[:, None] & mask_1[None, :])

def kl_div_backward(grad_out: Tensor, y_pred: Tensor, y_true: Tensor, log_target: hl.constexpr=False, reduction: hl.constexpr='batchmean', eps: hl.constexpr=1e-10, compute_y_true_grad: hl.constexpr=True, *, _launcher=_default_launcher):
BT, V = y_pred.shape
assert y_true.shape == y_pred.shape, f'Shape mismatch: {y_true.shape} != {y_pred.shape}'
grad_y_pred = torch.empty_like(y_pred)
if True:
grad_y_true = torch.empty_like(y_true)
else:
grad_y_true = None
if 'batchmean' == 'none':
grad_out_expanded = grad_out
else:
grad_out_expanded = grad_out.expand(y_true.shape)
_BLOCK_SIZE_0 = 64
_BLOCK_SIZE_1 = 64
_launcher(_helion_kl_div_backward, (triton.cdiv(BT, _BLOCK_SIZE_0),), grad_out_expanded, y_true, grad_y_pred, y_pred, grad_y_true, grad_out_expanded.stride(0), grad_out_expanded.stride(1), grad_y_pred.stride(0), grad_y_pred.stride(1), grad_y_true.stride(0), grad_y_true.stride(1), y_pred.stride(0), y_pred.stride(1), y_true.stride(0), y_true.stride(1), BT, V, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
return (grad_y_pred, grad_y_true)
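
Note on the generated code above (editorial commentary, not part of the journal file): because log_target, reduction, eps, and compute_y_true_grad are hl.constexpr parameters and the test invokes the kernel with their defaults, those values are specialized into the generated code. That is why the host wrapper contains "if True:" and "'batchmean' == 'none'", and why the Triton body carries the literal -23.025850929940457, which is math.log(1e-10) evaluated on the host. A quick sanity check, as a sketch:

import math

# log(eps) for eps=1e-10; matches the constant v_4 baked into the Triton kernel above
assert math.isclose(math.log(1e-10), -23.025850929940457)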

--- assertExpectedJournal(TestExamples.test_layernorm_bwd_dwdb)
from __future__ import annotations

48 changes: 47 additions & 1 deletion test/test_examples.py
@@ -1122,7 +1122,7 @@ def test_jsd(self):
)
)

def test_kl_div(self):
def test_kl_div_fwd(self):
args = (
torch.randn(
[8 * 512, 4096], device=DEVICE, dtype=torch.float32
@@ -1146,6 +1146,52 @@ def test_kl_div(self):
)
)

def test_kl_div_bwd(self):
y_pred = torch.randn(
[8 * 512, 4096], device=DEVICE, dtype=torch.float32
).log_softmax(dim=-1)
y_true = torch.randn(
[8 * 512, 4096], device=DEVICE, dtype=torch.float32
).softmax(dim=-1)
grad_out = torch.randn([], device=DEVICE, dtype=torch.float32)
log_target = False
reduction = "batchmean"
eps = 1e-10

# Run the forward pass once with an explicitly configured kernel (result unused)
from examples.kl_div import kl_div_forward

# Create configured kernel with explicit config
config = helion.Config(block_size=32, num_warps=4, num_stages=3)
configured_kernel = helion.kernel(kl_div_forward.fn, config=config)
_ = configured_kernel(y_pred, y_true, log_target, reduction, eps)

# Compute expected gradients with PyTorch
y_pred_torch = y_pred.detach().clone().requires_grad_(True)
y_true_torch = y_true.detach().clone().requires_grad_(True)
loss_torch = torch.nn.functional.kl_div(
y_pred_torch, y_true_torch, log_target=log_target, reduction=reduction
)
loss_torch.backward(grad_out)

args = (
grad_out,
y_pred,
y_true,
)

self.assertExpectedJournal(
check_example(
"kl_div",
args,
(y_pred_torch.grad, y_true_torch.grad),
fn_name="kl_div_backward",
block_sizes=[64, 64],
num_warps=4,
num_stages=3,
)
)

def test_gather_gemv(self):
args = (
torch.randn([8, 1024, 1024], device=DEVICE, dtype=torch.float32),