Conversation

@tianrengao commented Oct 1, 2025

This PR adds backward pass (gradient computation) support for matrix multiplication and addmm operations in Helion, with comprehensive unit tests to ensure correctness against PyTorch baselines.

Key Changes

1. Backward Kernels

  • matmul_bwd: Computes gradients for matrix multiplication (grad_A = grad_C @ B.T, grad_B = A.T @ grad_C)
  • addmm_bwd: Computes gradients for the addmm operation with alpha/beta scaling support (the gradient math for both is sketched after this list)
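
For reference, the gradient math these kernels implement can be written in plain PyTorch (a minimal sketch for checking only, not the Helion kernels themselves; the *_reference names are made up here):

def matmul_bwd_reference(grad_c, a, b):
    # C = A @ B  =>  grad_A = grad_C @ B^T, grad_B = A^T @ grad_C
    return grad_c @ b.T, a.T @ grad_c

def addmm_bwd_reference(grad_out, a, b, alpha=1.0, beta=1.0):
    # out = beta * bias + alpha * (A @ B)
    grad_bias = beta * grad_out  # sum over broadcast dims if bias was broadcast
    grad_a = alpha * (grad_out @ b.T)
    grad_b = alpha * (a.T @ grad_out)
    return grad_bias, grad_a, grad_b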

2. PyTorch Autograd

  • MatMulFunction & AddMMFunction: torch.autograd.Function subclasses with proper *grad_outputs signatures
  • matmul_autograd & addmm_autograd: User-friendly API functions (see the wiring sketch after this list)
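
A rough sketch of the wiring (the actual code lives in examples/matmul.py and may differ; matmul and matmul_bwd below stand in for the Helion kernels):

import torch

class MatMulFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, mat1, mat2):
        ctx.save_for_backward(mat1, mat2)
        return matmul(mat1, mat2)  # Helion forward kernel

    @staticmethod
    def backward(ctx, *grad_outputs):
        (grad_c,) = grad_outputs
        mat1, mat2 = ctx.saved_tensors
        # Helion backward kernel returns (grad_A, grad_B)
        return matmul_bwd(grad_c, mat1, mat2)

def matmul_autograd(mat1, mat2):
    return MatMulFunction.apply(mat1, mat2)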

3. Unit Tests

  • test_matmul_bwd: Validates the matrix multiplication backward pass against the PyTorch baseline
  • test_addmm_bwd: Validates the addmm backward pass with gradient flow for all inputs (a test sketch follows this list)
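
In spirit, each test compares Helion gradients to the eager PyTorch baseline along these lines (a sketch; the real tests run through the repo's test harness and expected-output files):

a = torch.randn([64, 32], device='cuda', requires_grad=True)
b = torch.randn([32, 16], device='cuda', requires_grad=True)
matmul_autograd(a, b).sum().backward()

a_ref = a.detach().clone().requires_grad_(True)
b_ref = b.detach().clone().requires_grad_(True)
(a_ref @ b_ref).sum().backward()

torch.testing.assert_close(a.grad, a_ref.grad)
torch.testing.assert_close(b.grad, b_ref.grad)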

Usage Example

# Matrix multiplication with gradients
mat1 = torch.randn([128, 256], requires_grad=True, device='cuda')
mat2 = torch.randn([256, 128], requires_grad=True, device='cuda')
result = matmul_autograd(mat1, mat2)
result.sum().backward()  # Gradients available in mat1.grad, mat2.grad

# AddMM with scaling: result = 0.5 * bias + 2.0 * (mat1 @ mat2)
bias = torch.randn([128, 128], requires_grad=True, device='cuda')
result = addmm_autograd(bias, mat1, mat2, alpha=2.0, beta=0.5)
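
The result can be checked directly against the eager baseline:

expected = torch.addmm(bias, mat1, mat2, alpha=2.0, beta=0.5)
torch.testing.assert_close(result, expected)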

Files Modified

  • examples/matmul.py: Added backward kernels, autograd classes, and enhanced testing
  • test/test_examples.py: Added test_matmul_bwd and test_addmm_bwd unit tests
  • test/test_examples.expected: Updated expected outputs for new tests

meta-cla bot added the CLA Signed label Oct 1, 2025
@tianrengao changed the title from "add matmul bwd" to "Add matmul/addmm bwd and add test coverage" Oct 1, 2025
@tianrengao changed the title from "Add matmul/addmm bwd and add test coverage" to "Add matmul/addmm bwd examples and add test coverage" Oct 1, 2025
@tianrengao marked this pull request as ready for review October 2, 2025 00:22
@tianrengao requested a review from yf225 October 2, 2025 02:07
Review thread on the following hunk:

m, k = mat1.size()
k2, n = mat2.size()  # k2 should equal k
bias = torch.broadcast_to(bias, [m, n])  # expand bias to the full [m, n] output shape
return lambda: matmul(mat1, mat2, lambda acc, tile: acc + bias[tile[0], tile[1]])  # fuse the bias add into the matmul epilogue
Contributor:

Would you like to also add integration in benchmarks/run.py (similar to rms_norm-bwd), and test accuracy via tritonbench --metrics accuracy?

Also, I believe these two *_tritonbench functions should probably just call the *_autograd functions.
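
A sketch of that suggestion (assuming the addmm_autograd signature shown earlier; the actual wrapper signatures in examples/matmul.py may differ):

def addmm_tritonbench(bias, mat1, mat2):
    # Route through the autograd wrapper so backward is exercised under tritonbench too
    return lambda: addmm_autograd(bias, mat1, mat2)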

Author (tianrengao):

Tritonbench PR adding addmm-bwd and gemm-bwd landed: meta-pytorch/tritonbench#531

All tests passed except gemm with partition-k; partition-k appears to be broken in the forward pass.

Addmm fwd

         (M, N, K)    triton_addmm-accuracy    pt2_triton_matmul-accuracy    helion_addmm_tritonbench-accuracy
------------------  -----------------------  ----------------------------  -----------------------------------
(20120, 512, 1536)                        1                             1                                    1
(34579, 512, 1536)                        1                             1                                    1
(34839, 512, 1536)                        1                             1                                    1
           average                        1                             1                                    1

Addmm bwd

         (M, N, K)    triton_addmm-accuracy    pt2_triton_matmul-accuracy    helion_addmm_tritonbench-accuracy
------------------  -----------------------  ----------------------------  -----------------------------------
(20120, 512, 1536)                        1                             1                                    1
(34579, 512, 1536)                        1                             1                                    1
(34839, 512, 1536)                        1                             1                                    1
           average                        1                             1                                    1

gemm fwd

      (M, N, K)    triton_tutorial_matmul-accuracy    matmul_partition_k-accuracy    triton_ops_matmul-accuracy    aten_tunableop_matmul-accuracy    pt2_triton_matmul-accuracy    streamk_matmul-accuracy    pt2_cutlass_matmul-accuracy    helion_matmul_tritonbench-accuracy
---------------  ---------------------------------  -----------------------------  ----------------------------  --------------------------------  ----------------------------  -------------------------  -----------------------------  ------------------------------------
(256, 256, 256)                                  1                              0                             1                                 1                             1                          1                              1                                     1
(384, 384, 384)                                  1                              0                             1                                 1                             1                          1                              1                                     1
(512, 512, 512)                                  1                              1                             1                                 1                             1                          1                              1                                     1
        average                                  1                       0.333333                             1                                 1                             1                          1                              1                                     1

gemm bwd

      (M, N, K)    triton_tutorial_matmul-accuracy    matmul_partition_k-accuracy    triton_ops_matmul-accuracy    aten_tunableop_matmul-accuracy    pt2_triton_matmul-accuracy    streamk_matmul-accuracy    pt2_cutlass_matmul-accuracy    helion_matmul_tritonbench-accuracy
---------------  ---------------------------------  -----------------------------  ----------------------------  --------------------------------  ----------------------------  -------------------------  -----------------------------  ------------------------------------
(256, 256, 256)                                  1                              0                             1                                 1                             1                          1                              1                                     1
(384, 384, 384)                                  1                              0                             1                                 1                             1                          1                              1                                     1
(512, 512, 512)                                  1                              0                             1                                 1                             1                          1                              1                                     1
        average                                  1                              0                             1                                 1                             1                          1                              1                                     1

- Fixed matmul_tritonbench to use addmm_autograd for gemm-with-bias testing
- Updated the tritonbench operators for proper requires_grad handling
- gemm-bwd shows 100% accuracy for the Helion implementation
- Some gradient mismatches on larger shapes are still being investigated

Current test results:
- addmm-bwd: 100% accuracy
- gemm-bwd: 100% accuracy for Helion; some framework gradient issues remain
@tianrengao (Author) commented:

The test failure seems unrelated to this PR: it reports illegal memory access for 500+ tests, and I have seen similar failures in other PRs.

@tianrengao requested a review from yf225 October 9, 2025 22:56