Add matmul/addmm bwd examples and add test coverage #748
Conversation
Force-pushed from a7958ba to 0960861.
Review comment on `examples/matmul.py` (outdated):
```python
m, k = mat1.size()
k2, n = mat2.size()
bias = torch.broadcast_to(bias, [m, n])
return lambda: matmul(mat1, mat2, lambda acc, tile: acc + bias[tile[0], tile[1]])
```
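For context, the epilogue closure receives the accumulator and the output tile indices, so the call above is semantically equivalent to this eager-mode expression (a sketch, not the Helion kernel):

```python
# Eager-mode equivalent of the call above: the epilogue adds the broadcast
# bias tile to each accumulator tile, i.e. addmm without alpha/beta scaling.
out = mat1 @ mat2 + bias  # bias has shape (m, n) after broadcast_to
```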
Would you like to also add integration in `benchmarks/run.py` (similar to `rms_norm-bwd`), and test accuracy via `tritonbench --metrics accuracy`?

Also, I believe these two `*_tritonbench` functions should probably just call the `*_autograd` functions (see the sketch below).
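A minimal sketch of that suggestion; the wrapper signatures and import path are assumptions here, not the repo's exact code:

```python
# Sketch only: signatures and import path are assumptions.
from examples.matmul import matmul_autograd, addmm_autograd

def matmul_tritonbench(tb_op, mat1, mat2):
    # Delegating to the autograd entry point keeps fwd and bwd on one code
    # path, so tritonbench's bwd mode exercises the Helion backward kernels.
    return lambda: matmul_autograd(mat1, mat2)

def addmm_tritonbench(tb_op, bias, mat1, mat2):
    return lambda: addmm_autograd(bias, mat1, mat2)
```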
The Tritonbench PR adding addmm-bwd and gemm-bwd has landed: meta-pytorch/tritonbench#531

All tests passed except for gemm with partition-k; it seems partition-k is already broken in the forward pass.
**Addmm fwd**

| (M, N, K) | triton_addmm-accuracy | pt2_triton_matmul-accuracy | helion_addmm_tritonbench-accuracy |
|---|---|---|---|
| (20120, 512, 1536) | 1 | 1 | 1 |
| (34579, 512, 1536) | 1 | 1 | 1 |
| (34839, 512, 1536) | 1 | 1 | 1 |
| average | 1 | 1 | 1 |
**Addmm bwd**

| (M, N, K) | triton_addmm-accuracy | pt2_triton_matmul-accuracy | helion_addmm_tritonbench-accuracy |
|---|---|---|---|
| (20120, 512, 1536) | 1 | 1 | 1 |
| (34579, 512, 1536) | 1 | 1 | 1 |
| (34839, 512, 1536) | 1 | 1 | 1 |
| average | 1 | 1 | 1 |
**gemm fwd**

| (M, N, K) | triton_tutorial_matmul-accuracy | matmul_partition_k-accuracy | triton_ops_matmul-accuracy | aten_tunableop_matmul-accuracy | pt2_triton_matmul-accuracy | streamk_matmul-accuracy | pt2_cutlass_matmul-accuracy | helion_matmul_tritonbench-accuracy |
|---|---|---|---|---|---|---|---|---|
| (256, 256, 256) | 1 | 0 | 1 | 1 | 1 | 1 | 1 | 1 |
| (384, 384, 384) | 1 | 0 | 1 | 1 | 1 | 1 | 1 | 1 |
| (512, 512, 512) | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |
| average | 1 | 0.333333 | 1 | 1 | 1 | 1 | 1 | 1 |
**gemm bwd**

| (M, N, K) | triton_tutorial_matmul-accuracy | matmul_partition_k-accuracy | triton_ops_matmul-accuracy | aten_tunableop_matmul-accuracy | pt2_triton_matmul-accuracy | streamk_matmul-accuracy | pt2_cutlass_matmul-accuracy | helion_matmul_tritonbench-accuracy |
|---|---|---|---|---|---|---|---|---|
| (256, 256, 256) | 1 | 0 | 1 | 1 | 1 | 1 | 1 | 1 |
| (384, 384, 384) | 1 | 0 | 1 | 1 | 1 | 1 | 1 | 1 |
| (512, 512, 512) | 1 | 0 | 1 | 1 | 1 | 1 | 1 | 1 |
| average | 1 | 0 | 1 | 1 | 1 | 1 | 1 | 1 |
- Fixed `matmul_tritonbench` to use `addmm_autograd` for gemm-with-bias testing
- Updated tritonbench operators for proper `requires_grad` handling
- gemm-bwd shows 100% accuracy for the Helion implementation
- Some gradient mismatches on larger shapes are still being investigated

Current test results:

- addmm-bwd: 100% accuracy
- gemm-bwd: 100% accuracy for Helion; some framework gradient issues remain
Force-pushed from 04eb760 to ee4f1d9.
The test failure seems unrelated to my PR: it complains about illegal memory access in 500+ tests, and I saw similar failures in other PRs too.
This PR adds backward pass (gradient computation) support for matrix multiplication and addmm operations in Helion, with comprehensive unit tests to ensure correctness against PyTorch baselines.
**Key Changes**

**1. Backward Kernels**

- `matmul_bwd`: computes gradients for matrix multiplication (`grad_A = grad_C @ B.T`, `grad_B = A.T @ grad_C`; see the sanity-check sketch below)
- `addmm_bwd`: computes gradients for the addmm operation, with alpha/beta scaling support
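As a sanity check on those gradient identities, a small plain-PyTorch snippet (independent of the Helion kernels) that verifies them against autograd:

```python
import torch

A = torch.randn(4, 3, requires_grad=True)
B = torch.randn(3, 5, requires_grad=True)
grad_C = torch.randn(4, 5)

# C = A @ B; backpropagate an arbitrary upstream gradient grad_C.
(A @ B).backward(grad_C)

# The closed-form gradients match what autograd computed.
assert torch.allclose(A.grad, grad_C @ B.T)
assert torch.allclose(B.grad, A.T @ grad_C)
```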
**2. PyTorch Autograd**
- `MatMulFunction` & `AddMMFunction`: PyTorch autograd classes with proper `*grad_outputs` signatures
- `matmul_autograd` & `addmm_autograd`: user-friendly API functions

**3. Unit Tests**
- `test_matmul_bwd`: validates the matrix multiplication backward pass against the PyTorch baseline
- `test_addmm_bwd`: validates the addmm backward pass with gradient flow for all inputs

**Usage Example**
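The code block for this section did not survive extraction; below is a minimal sketch of the intended usage, assuming the `matmul_autograd`/`addmm_autograd` entry points named above (the import path and the `alpha`/`beta` keywords are assumptions based on the descriptions in this PR):

```python
import torch
from examples.matmul import matmul_autograd, addmm_autograd  # assumed path

mat1 = torch.randn(64, 32, device="cuda", requires_grad=True)
mat2 = torch.randn(32, 16, device="cuda", requires_grad=True)
bias = torch.randn(16, device="cuda", requires_grad=True)

out = matmul_autograd(mat1, mat2)
out.sum().backward()  # populates mat1.grad and mat2.grad via matmul_bwd

out2 = addmm_autograd(bias, mat1, mat2, alpha=1.0, beta=1.0)
out2.sum().backward()  # gradients also flow to bias via addmm_bwd
```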
**Files Modified**

- `examples/matmul.py`: added backward kernels, autograd classes, and enhanced testing
- `test/test_examples.py`: added `test_matmul_bwd` and `test_addmm_bwd` unit tests
- `test/test_examples.expected`: updated expected outputs for the new tests