diff --git a/benchmarks/run.py b/benchmarks/run.py
index 750833422..af7e396ec 100644
--- a/benchmarks/run.py
+++ b/benchmarks/run.py
@@ -534,6 +534,15 @@ class RunResult:
             "helion_addmm_tritonbench-speedup": "helion_speedup",
             "helion_addmm_tritonbench-accuracy": "helion_accuracy",
         },
+        "addmm-bwd": {
+            "aten_addmm": "baseline",
+            "triton_addmm-speedup": "triton_speedup",
+            "triton_addmm-accuracy": "triton_accuracy",
+            "pt2_addmm_maxautotune-speedup": "torch_compile_speedup",
+            "pt2_addmm_maxautotune-accuracy": "torch_compile_accuracy",
+            "helion_addmm_tritonbench-speedup": "helion_speedup",
+            "helion_addmm_tritonbench-accuracy": "helion_accuracy",
+        },
         # "ragged_attention": {
         #     "triton_ragged_attention-speedup": "triton_speedup",
         #     "triton_ragged_attention-accuracy": "triton_accuracy",
@@ -603,6 +612,15 @@ class RunResult:
             "helion_matmul_tritonbench-speedup": "helion_speedup",
             "helion_matmul_tritonbench-accuracy": "helion_accuracy",
         },
+        "gemm-bwd": {
+            "aten_matmul": "baseline",
+            "triton_tutorial_matmul-speedup": "triton_speedup",
+            "triton_tutorial_matmul-accuracy": "triton_accuracy",
+            "pt2_triton_matmul-speedup": "torch_compile_speedup",
+            "pt2_triton_matmul-accuracy": "torch_compile_accuracy",
+            "helion_matmul_tritonbench-speedup": "helion_speedup",
+            "helion_matmul_tritonbench-accuracy": "helion_accuracy",
+        },
         "fp8_gemm": {
             "torch_fp8_gemm": "baseline",
             f"{'blackwell_persistent_tma' if IS_B200 else 'triton_tma_persistent'}_fp8_gemm-speedup": "triton_speedup",
diff --git a/examples/matmul.py b/examples/matmul.py
index b3c3ca4d8..b799bf04e 100644
--- a/examples/matmul.py
+++ b/examples/matmul.py
@@ -64,6 +64,116 @@ def matmul(
     return out
 
 
+# %%
+class MatMulFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx: Any,  # noqa: ANN401
+        mat1: Tensor,
+        mat2: Tensor,
+    ) -> Tensor:
+        """Forward pass for matrix multiplication."""
+        result = matmul(mat1, mat2)
+        ctx.save_for_backward(mat1, mat2)
+        return result
+
+    @staticmethod
+    def backward(
+        ctx: Any,  # noqa: ANN401
+        *grad_outputs: Tensor,
+    ) -> tuple[Tensor | None, Tensor | None]:
+        """
+        Backward pass for matrix multiplication.
+
+        For C = A @ B, given grad_C:
+        - grad_A = grad_C @ B.T
+        - grad_B = A.T @ grad_C
+
+        We reuse the forward matmul kernel for both computations.
+ """ + grad_out = grad_outputs[0] + mat1, mat2 = ctx.saved_tensors + + # grad_mat1 = grad_out @ mat2.T + grad_mat1 = matmul(grad_out, mat2.T) + + # grad_mat2 = mat1.T @ grad_out + grad_mat2 = matmul(mat1.T, grad_out) + + return grad_mat1, grad_mat2 + + +def matmul_autograd(mat1: Tensor, mat2: Tensor) -> Tensor: + """Matrix multiplication with forward + backward support.""" + return MatMulFunction.apply(mat1, mat2) # type: ignore[no-any-return] + + +class AddMMFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, # noqa: ANN401 + bias: Tensor, + mat1: Tensor, + mat2: Tensor, + alpha: float = 1.0, + beta: float = 1.0, + ) -> Tensor: + """Forward pass for addmm operation using helion matmul with epilogue.""" + m, k = mat1.size() + k2, n = mat2.size() + input_broadcasted = torch.broadcast_to(bias, [m, n]) + + # Define epilogue that adds bias: alpha * acc + beta * bias + def addmm_epilogue(acc: Tensor, tile: tuple[Tensor, ...]) -> Tensor: + return alpha * acc + beta * input_broadcasted[tile[0], tile[1]] + + result = matmul(mat1, mat2, addmm_epilogue) + ctx.save_for_backward(bias, mat1, mat2) + ctx.alpha = alpha + ctx.beta = beta + return result + + @staticmethod + def backward( + ctx: Any, # noqa: ANN401 + *grad_outputs: Tensor, + ) -> tuple[Tensor | None, Tensor | None, Tensor | None, None, None]: + """ + Backward pass for addmm operation. + + Forward: output = beta * bias + alpha * (mat1 @ mat2) + + Given grad_out: + - grad_bias = beta * grad_out + - grad_mat1 = alpha * (grad_out @ mat2.T) + - grad_mat2 = alpha * (mat1.T @ grad_out) + + We reuse the forward matmul kernel for both matrix gradient computations. + """ + grad_out = grad_outputs[0] + bias, mat1, mat2 = ctx.saved_tensors + alpha = ctx.alpha + beta = ctx.beta + + # grad_bias = beta * grad_out + grad_bias = beta * grad_out + + # grad_mat1 = alpha * (grad_out @ mat2.T) + grad_mat1 = alpha * matmul(grad_out, mat2.T) + + # grad_mat2 = alpha * (mat1.T @ grad_out) + grad_mat2 = alpha * matmul(mat1.T, grad_out) + + return grad_bias, grad_mat1, grad_mat2, None, None + + +def addmm_autograd( + bias: Tensor, mat1: Tensor, mat2: Tensor, alpha: float = 1.0, beta: float = 1.0 +) -> Tensor: + """AddMM operation with forward + backward support.""" + return AddMMFunction.apply(bias, mat1, mat2, alpha, beta) # type: ignore[no-any-return] + + @helion.kernel def matmul_bwd( grad_out: Tensor, # [m, n] gradient w.r.t output @@ -188,84 +298,6 @@ def addmm_bwd( return grad_input, grad_mat1, grad_mat2 -# %% -class MatMulFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx: Any, # noqa: ANN401 - mat1: Tensor, - mat2: Tensor, - ) -> Tensor: - """Forward pass for matrix multiplication.""" - result = matmul(mat1, mat2) - ctx.save_for_backward(mat1, mat2) - return result - - @staticmethod - def backward( - ctx: Any, # noqa: ANN401 - *grad_outputs: Tensor, - ) -> tuple[Tensor | None, Tensor | None]: - """Backward pass for matrix multiplication.""" - grad_out = grad_outputs[0] - mat1, mat2 = ctx.saved_tensors - grad_mat1, grad_mat2 = matmul_bwd(grad_out, mat1, mat2) - return grad_mat1, grad_mat2 - - -def matmul_autograd(mat1: Tensor, mat2: Tensor) -> Tensor: - """Matrix multiplication with forward + backward support.""" - return MatMulFunction.apply(mat1, mat2) # type: ignore[no-any-return] - - -class AddMMFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx: Any, # noqa: ANN401 - bias: Tensor, - mat1: Tensor, - mat2: Tensor, - alpha: float = 1.0, - beta: float = 1.0, - ) -> Tensor: - """Forward pass 
-        """Forward pass for addmm operation using helion matmul with epilogue."""
-        m, k = mat1.size()
-        k2, n = mat2.size()
-        input_broadcasted = torch.broadcast_to(bias, [m, n])
-
-        # Define epilogue that adds bias: alpha * acc + beta * bias
-        def addmm_epilogue(acc: Tensor, tile: tuple[Tensor, ...]) -> Tensor:
-            return alpha * acc + beta * input_broadcasted[tile[0], tile[1]]
-
-        result = matmul(mat1, mat2, addmm_epilogue)
-        ctx.save_for_backward(bias, mat1, mat2)
-        ctx.alpha = alpha
-        ctx.beta = beta
-        return result
-
-    @staticmethod
-    def backward(
-        ctx: Any,  # noqa: ANN401
-        *grad_outputs: Tensor,
-    ) -> tuple[Tensor | None, Tensor | None, Tensor | None, None, None]:
-        """Backward pass for addmm operation."""
-        grad_out = grad_outputs[0]
-        bias, mat1, mat2 = ctx.saved_tensors
-        alpha = ctx.alpha
-        beta = ctx.beta
-        grad_input, grad_mat1, grad_mat2 = addmm_bwd(
-            grad_out, bias, mat1, mat2, alpha, beta
-        )
-        return grad_input, grad_mat1, grad_mat2, None, None
-
-
-def addmm_autograd(
-    bias: Tensor, mat1: Tensor, mat2: Tensor, alpha: float = 1.0, beta: float = 1.0
-) -> Tensor:
-    """AddMM operation with forward + backward support."""
-    return AddMMFunction.apply(bias, mat1, mat2, alpha, beta)  # type: ignore[no-any-return]
-
-
 # %%
 def autotune(m: int, k: int, n: int) -> None:
     """
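As a quick sanity check, the new wrappers can be compared against PyTorch's reference gradients along these lines (a minimal sketch, assuming examples/matmul.py is importable and a CUDA device is available, since the Helion kernels target the GPU; matmul_autograd can be checked the same way against torch.matmul):

import torch

from examples.matmul import addmm_autograd

# Inputs for out = beta * bias + alpha * (mat1 @ mat2).
mat1 = torch.randn(64, 32, device="cuda", requires_grad=True)
mat2 = torch.randn(32, 48, device="cuda", requires_grad=True)
bias = torch.randn(64, 48, device="cuda", requires_grad=True)

# Forward + backward through the Helion kernels.
addmm_autograd(bias, mat1, mat2, alpha=2.0, beta=0.5).sum().backward()

# Reference forward + backward through eager PyTorch.
refs = [t.detach().clone().requires_grad_(True) for t in (bias, mat1, mat2)]
torch.addmm(refs[0], refs[1], refs[2], alpha=2.0, beta=0.5).sum().backward()

# Loose tolerances absorb accumulation-order differences between kernels.
for ours, ref in zip((bias, mat1, mat2), refs):
    torch.testing.assert_close(ours.grad, ref.grad, rtol=1e-3, atol=1e-3)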