20 changes: 17 additions & 3 deletions benchmarks/run.py
@@ -159,7 +159,12 @@ class RunResult:
"softmax": (
"tritonbench.operators.softmax.operator",
"examples.softmax",
"softmax",
"softmax_tritonbench",
),
"softmax-bwd": (
"tritonbench.operators.softmax.operator",
"examples.softmax",
"softmax_tritonbench",
),
"jagged_mean": (
"tritonbench.operators.jagged_mean.operator",
@@ -325,8 +330,17 @@ class RunResult:
"triton_softmax-accuracy": "triton_accuracy",
"torch_compile_softmax-speedup": "torch_compile_speedup",
"torch_compile_softmax-accuracy": "torch_compile_accuracy",
"helion_softmax-speedup": "helion_speedup",
"helion_softmax-accuracy": "helion_accuracy",
"helion_softmax_tritonbench-speedup": "helion_speedup",
"helion_softmax_tritonbench-accuracy": "helion_accuracy",
},
"softmax-bwd": {
"naive_softmax": "baseline",
"triton_softmax-speedup": "triton_speedup",
"triton_softmax-accuracy": "triton_accuracy",
"torch_compile_softmax-speedup": "torch_compile_speedup",
"torch_compile_softmax-accuracy": "torch_compile_accuracy",
"helion_softmax_tritonbench-speedup": "helion_speedup",
"helion_softmax_tritonbench-accuracy": "helion_accuracy",
},
"rms_norm": {
"llama_rms": "baseline",
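Each entry in this registry maps a tritonbench operator name to a (tritonbench operator module, Helion example module, kernel function name) triple, and the per-operator dicts map tritonbench metric columns onto report column names. As a rough illustration of how such an entry might be consumed, here is a hypothetical sketch; the names KERNEL_MAPPINGS and load_helion_kernel are illustrative, and the actual resolution logic in benchmarks/run.py may differ.

import importlib
from typing import Callable

# Hypothetical sketch, not the actual benchmarks/run.py implementation.
KERNEL_MAPPINGS = {
    "softmax-bwd": (
        "tritonbench.operators.softmax.operator",  # tritonbench operator module
        "examples.softmax",                        # Helion example module
        "softmax_tritonbench",                     # wrapper function to benchmark
    ),
}

def load_helion_kernel(op_name: str) -> Callable[..., object]:
    """Resolve the Helion wrapper registered for a tritonbench operator."""
    _op_module, example_module, fn_name = KERNEL_MAPPINGS[op_name]
    module = importlib.import_module(example_module)
    return getattr(module, fn_name)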
87 changes: 87 additions & 0 deletions examples/softmax.py
@@ -11,6 +11,9 @@
# %%
from __future__ import annotations

from typing import Any
from typing import Callable

import torch

import helion
@@ -89,6 +92,79 @@ def softmax_two_pass(x: torch.Tensor) -> torch.Tensor:
return out


@helion.kernel()
def softmax_bwd(
grad_output: torch.Tensor, softmax_output: torch.Tensor
) -> torch.Tensor:
"""
Helion kernel implementing softmax backward pass.

dy/dx = softmax_output * (grad_output - sum(softmax_output * grad_output))

Args:
grad_output (torch.Tensor): Gradient from downstream layers of shape [m, n]
softmax_output (torch.Tensor): Output from forward softmax pass of shape [m, n]

Returns:
torch.Tensor: Gradient with respect to input of shape [m, n]
"""
m, n = grad_output.size()
grad_input = torch.empty_like(grad_output)

for tile_m in hl.tile(m):
sum_per_row = hl.zeros([tile_m], dtype=torch.float32)
for tile_n in hl.tile(n):
sum_per_row += torch.sum(
softmax_output[tile_m, tile_n] * grad_output[tile_m, tile_n], dim=1
)
for tile_n in hl.tile(n):
grad_input[tile_m, tile_n] = softmax_output[tile_m, tile_n] * (
grad_output[tile_m, tile_n] - sum_per_row[:, None]
)

return grad_input


class SoftmaxFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx: Any, # noqa: ANN401
x: torch.Tensor,
) -> torch.Tensor:
y = softmax_two_pass(x)
ctx.save_for_backward(y)
return y

@staticmethod
def backward( # type: ignore[override]
ctx: Any, # noqa: ANN401
grad_output: torch.Tensor,
) -> tuple[torch.Tensor | None]:
(softmax_output,) = ctx.saved_tensors
grad_x = softmax_bwd(grad_output, softmax_output)
return (grad_x,)


def softmax_fwd_bwd(
x: torch.Tensor,
) -> torch.Tensor:
"""Softmax with forward + backward support."""
return SoftmaxFunction.apply(x) # type: ignore[no-any-return]
Contributor:
Would you like to also add integration in benchmarks/run.py (similar to rms_norm-bwd), and test accuracy via tritonbench --metrics accuracy?

Contributor Author:
The backward pass for softmax is disabled in tritonbench. I enabled it in meta-pytorch/tritonbench#528 and added softmax-bwd in benchmarks/run.py.

Contributor Author (@karthickai, Oct 8, 2025):
HELION_USE_DEFAULT_CONFIG=1 python benchmarks/run.py --op softmax-bwd --num-inputs 3 --metrics accuracy

     (M, N)    triton_softmax-accuracy    quack-accuracy    torch_compile_softmax-accuracy    helion_softmax_tritonbench-accuracy
-----------  -------------------------  ----------------  --------------------------------  -------------------------------------
(4096, 256)                          1                 1                                 1                                      1
(4096, 384)                          1                 1                                 1                                      1
(4096, 512)                          1                 1                                 1                                      1
    average                          1                 1                                 1                                      1
HELION_USE_DEFAULT_CONFIG=1 python benchmarks/run.py --op softmax --num-inputs 3 --metrics accuracy

     (M, N)    triton_softmax-accuracy    quack-accuracy    torch_compile_softmax-accuracy    helion_softmax_tritonbench-accuracy
-----------  -------------------------  ----------------  --------------------------------  -------------------------------------
(4096, 256)                          1                 1                                 1                                      1
(4096, 384)                          1                 1                                 1                                      1
(4096, 512)                          1                 1                                 1                                      1
    average                          1                 1                                 1                                      1

Contributor Author:
Update: the tritonbench softmax-bwd change has landed: meta-pytorch/tritonbench#528



def softmax_tritonbench(tb_op: object, x: torch.Tensor) -> Callable[[], torch.Tensor]:
"""
Wrapper for tritonbench that returns softmax with backward support.

Args:
tb_op: TritonBench operator instance
x: Input tensor

Returns:
Callable that returns the output tensor
"""
return lambda: softmax_fwd_bwd(x)


# %%
def check(m: int, n: int) -> None:
"""
@@ -105,6 +181,17 @@ def check(m: int, n: int) -> None:
}
run_example(kernels, lambda x: torch.nn.functional.softmax(x, dim=1), (x,))

print("\n\n=== Forward + Backward Pass Test ===")
x_grad = torch.randn([m, n], device="cuda", dtype=torch.float16, requires_grad=True)
run_example(
softmax_fwd_bwd,
torch.nn.functional.softmax,
(x_grad,),
rtol=1e-3,
atol=1e-3,
bwd=True,
)


# %%
def main() -> None:
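The identity implemented by softmax_bwd above can be sanity-checked against PyTorch autograd in eager mode. The following is a minimal standalone sketch (not part of this PR), run on CPU in float32 for clarity:

import torch

x = torch.randn(4096, 256, dtype=torch.float32, requires_grad=True)
g = torch.randn_like(x)

y = torch.nn.functional.softmax(x, dim=-1)
y.backward(g)  # x.grad = y * (g - sum(y * g, dim=-1, keepdim=True))

yd = y.detach()
manual_dx = yd * (g - (yd * g).sum(dim=-1, keepdim=True))
torch.testing.assert_close(manual_dx, x.grad, rtol=1e-5, atol=1e-5)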
62 changes: 62 additions & 0 deletions test/test_examples.expected
@@ -3420,6 +3420,68 @@ def softmax(x: torch.Tensor, *, _launcher=_default_launcher):
_launcher(_helion_softmax, (n,), x, out, out.size(0), out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _m, _RDIM_SIZE_1, 1, num_warps=4, num_stages=1)
return out

--- assertExpectedJournal(TestExamples.test_softmax_bwd)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_softmax_bwd(softmax_output, grad_output, grad_input, grad_input_stride_0, grad_input_stride_1, grad_output_stride_0, grad_output_stride_1, softmax_output_stride_0, softmax_output_stride_1, m, n, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < m
sum_per_row = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
for offset_1 in tl.range(0, n.to(tl.int32), _BLOCK_SIZE_1):
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
mask_1 = indices_1 < n
sum_per_row_copy = sum_per_row
sum_per_row_copy_0 = sum_per_row_copy
load = tl.load(softmax_output + (indices_0[:, None] * softmax_output_stride_0 + indices_1[None, :] * softmax_output_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
load_1 = tl.load(grad_output + (indices_0[:, None] * grad_output_stride_0 + indices_1[None, :] * grad_output_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
v_0 = load * load_1
sum_1 = tl.cast(tl.sum(v_0, 1), tl.float16)
v_1 = tl.cast(sum_1, tl.float32)
sum_per_row = sum_per_row_copy_0 + v_1
for offset_2 in tl.range(0, n.to(tl.int32), _BLOCK_SIZE_2):
indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
mask_2 = indices_2 < n
sum_per_row_copy_1 = sum_per_row
sum_per_row_copy_1_0 = sum_per_row_copy_1
load_2 = tl.load(softmax_output + (indices_0[:, None] * softmax_output_stride_0 + indices_2[None, :] * softmax_output_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
load_3 = tl.load(grad_output + (indices_0[:, None] * grad_output_stride_0 + indices_2[None, :] * grad_output_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
subscript = sum_per_row_copy_1_0[:, None]
v_3 = tl.cast(load_3, tl.float32)
v_4 = v_3 - subscript
v_5 = tl.cast(load_2, tl.float32)
v_6 = v_5 * v_4
v_7 = tl.cast(v_6, tl.float16)
tl.store(grad_input + (indices_0[:, None] * grad_input_stride_0 + indices_2[None, :] * grad_input_stride_1), v_7, mask_0[:, None] & mask_2[None, :])

def softmax_bwd(grad_output: torch.Tensor, softmax_output: torch.Tensor, *, _launcher=_default_launcher):
"""
Helion kernel implementing softmax backward pass.

dy/dx = softmax_output * (grad_output - sum(softmax_output * grad_output))

Args:
grad_output (torch.Tensor): Gradient from downstream layers of shape [m, n]
softmax_output (torch.Tensor): Output from forward softmax pass of shape [m, n]

Returns:
torch.Tensor: Gradient with respect to input of shape [m, n]
"""
m, n = grad_output.size()
grad_input = torch.empty_like(grad_output)
_BLOCK_SIZE_0 = 16
_BLOCK_SIZE_1 = 16
_BLOCK_SIZE_2 = 16
_launcher(_helion_softmax_bwd, (triton.cdiv(m, _BLOCK_SIZE_0),), softmax_output, grad_output, grad_input, grad_input.stride(0), grad_input.stride(1), grad_output.stride(0), grad_output.stride(1), softmax_output.stride(0), softmax_output.stride(1), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
return grad_input

--- assertExpectedJournal(TestExamples.test_softmax_decomposed)
from __future__ import annotations

26 changes: 26 additions & 0 deletions test/test_examples.py
@@ -881,6 +881,32 @@ def test_layernorm_bwd_dx(self):
)
)

def test_softmax_bwd(self):
m, n = 2048, 2048
x = torch.randn([m, n], device=DEVICE, dtype=torch.float16, requires_grad=True)
grad_out = torch.randn([m, n], device=DEVICE, dtype=torch.float16)

from examples.softmax import softmax_two_pass

config = helion.Config(block_size=[128, 128], num_warps=4, num_stages=3)
configured_kernel = helion.kernel(softmax_two_pass.fn, config=config)
y = configured_kernel(x)

x_torch = x.detach().clone().requires_grad_(True)
y_torch = torch.nn.functional.softmax(x_torch, dim=-1)
y_torch.backward(grad_out)

self.assertExpectedJournal(
check_example(
"softmax",
(grad_out, y),
x_torch.grad,
fn_name="softmax_bwd",
rtol=1e-3,
atol=1e-3,
)
)

def test_layernorm_without_bias(self):
x = -2.3 + 0.5 * torch.randn([32, 64], device=DEVICE, dtype=torch.float16)
weight = torch.randn([64], device=DEVICE, dtype=torch.float16)
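Assuming a pytest-compatible runner and the repository root on PYTHONPATH, the new expected-journal test can be run in isolation with:

python -m pytest test/test_examples.py -k test_softmax_bwd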