refactor: backend_requirement + supported_compute_capability decorator for gemm #2000
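In short, the PR replaces inline assert-based validation in the GEMM paths with standalone problem-size check functions and wires them to the public entry points through decorators. The following is a condensed sketch of that pattern, not code from the PR itself: the function names are placeholders, and it assumes the two helpers are importable from flashinfer.utils, as the diff does via the relative .utils import.

from flashinfer.utils import backend_requirement, supported_compute_capability

@supported_compute_capability([100, 103])
def _check_problem_size(*args, **kwargs) -> bool:
    # Raise ValueError on invalid shapes/dtypes; return True when everything is valid.
    return True

@backend_requirement({}, common_check=_check_problem_size)
def grouped_gemm_api(*args, **kwargs) -> None:
    ...  # kernel dispatch happens here; validation runs via the decorator

The actual checkers and entry points appear in the diff below.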
@@ -45,7 +45,12 @@
 from .cuda_utils import checkCudaErrors
 from .jit.cubin_loader import get_cubin
 from .jit.env import FLASHINFER_CUBIN_DIR
-from .utils import ceil_div, round_up
+from .utils import (
+    ceil_div,
+    round_up,
+    supported_compute_capability,
+    backend_requirement,
+)
 
 
 class GemmType(enum.Enum):
@@ -1358,24 +1363,27 @@ def m_grouped_fp8_gemm_nt_masked_sm10x(
     runtime(**all_kwargs)
 
 
-def m_grouped_fp8_gemm_nt_contiguous(
+@supported_compute_capability([100, 103])
+def _check_group_deepgemm_fp8_nt_contiguous_problem_size(
     a_fp8: Tuple[torch.Tensor, torch.Tensor],
     b_fp8: Tuple[torch.Tensor, torch.Tensor],
     d: torch.Tensor,
     m_indices: torch.Tensor,
     recipe: Optional[Tuple[int, int, int]] = None,
     compiled_dims: str = "nk",
-) -> None:
-    # Compiled dims can be upper cases
-    compiled_dims = compiled_dims.lower()
-
+) -> bool:
     # NOTES: shape must be `[M, K] @ [G, N, K].mT`
     major_a = get_major_type_ab(a_fp8[0])
     major_b = get_major_type_ab(b_fp8[0])
-    assert major_a == MajorTypeAB.KMajor
-    if must_be_k_major():
-        assert major_b == MajorTypeAB.KMajor
-    assert m_indices.is_contiguous()
+    if major_a != MajorTypeAB.KMajor:
+        raise ValueError(f"major_a must be KMajor, but got {major_a}")
+    if must_be_k_major() and (major_b != MajorTypeAB.KMajor):
+        raise ValueError(f"major_b must be KMajor, but got {major_b}")
+
+    if not m_indices.is_contiguous():
+        raise ValueError(
+            f"m_indices must be contiguous, but got {m_indices.is_contiguous()}"
+        )
 
     a, sfa = a_fp8
     b, sfb = b_fp8
@@ -1385,15 +1393,48 @@ def m_grouped_fp8_gemm_nt_contiguous(
     m__ = m_indices.numel()
 
Contributor
The check for positive dimensions (n > 0 and k > 0 and num_groups > 0) is dropped here.
Suggested change
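The suggested change itself is not shown above. As a rough illustration only (the helper name below is hypothetical and not part of this PR), restoring that check in the raise-based style of the rest of the diff could look like:

def _check_positive_dims(n: int, k: int, num_groups: int) -> None:
    # Hypothetical helper: re-adds the dropped positivity check, raising
    # ValueError instead of asserting, in line with the rest of this PR.
    if n <= 0 or k <= 0 or num_groups <= 0:
        raise ValueError(
            f"n, k, num_groups must be greater than 0, "
            f"but got n = {n}, k = {k}, num_groups = {num_groups}"
        )

Inside _check_group_deepgemm_fp8_nt_contiguous_problem_size this would simply be an equivalent inline if/raise next to the shape-mismatch check.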
     # Type and shape checks
-    assert m == m_ == m__ and n == n_ and k == k_
-    assert n > 0 and k > 0 and num_groups > 0
-    assert a.dtype == torch.float8_e4m3fn
-    assert b.dtype == torch.float8_e4m3fn
-    assert d.dtype == torch.bfloat16
-    assert m_indices.dtype == torch.int32
+    if m != m_ or k != k_ or n != n_ or m__ != m_ or num_groups != m__:
+        raise ValueError(
+            f"Shape mismatch. m = {m}, m_ = {m_}, k = {k}, k_ = {k_}, n = {n}, n_ = {n_}, m__ = {m__}"
+        )
+    if a.dtype != torch.float8_e4m3fn:
+        raise ValueError(f"a must be float8_e4m3fn, but got {a.dtype}")
+    if b.dtype != torch.float8_e4m3fn:
+        raise ValueError(f"b must be float8_e4m3fn, but got {b.dtype}")
+    if d.dtype != torch.bfloat16:
+        raise ValueError(f"d must be bfloat16, but got {d.dtype}")
+    if m_indices.dtype != torch.int32:
+        raise ValueError(f"m_indices must be int32, but got {m_indices.dtype}")
 
     # D must be N-major
-    assert get_major_type_cd(d) == MajorTypeCD.NMajor
+    if get_major_type_cd(d) != MajorTypeCD.NMajor:
+        raise ValueError(f"d must be N-major, but got {get_major_type_cd(d)}")
+
+    return True
+
+
+@backend_requirement(
+    {},
+    common_check=_check_group_deepgemm_fp8_nt_contiguous_problem_size,
+)
+def m_grouped_fp8_gemm_nt_contiguous(
+    a_fp8: Tuple[torch.Tensor, torch.Tensor],
+    b_fp8: Tuple[torch.Tensor, torch.Tensor],
+    d: torch.Tensor,
+    m_indices: torch.Tensor,
+    recipe: Optional[Tuple[int, int, int]] = None,
+    compiled_dims: str = "nk",
+) -> None:
+    # Compiled dims can be upper cases
+    compiled_dims = compiled_dims.lower()
+
+    major_a = get_major_type_ab(a_fp8[0])
+    major_b = get_major_type_ab(b_fp8[0])
+
+    a, sfa = a_fp8
+    b, sfb = b_fp8
+    m, k = a.shape
+    num_groups, n, k_ = b.shape
 
     # Do nothing if the problem is empty
     if m == 0:
@@ -1423,6 +1464,72 @@ def m_grouped_fp8_gemm_nt_contiguous(
     impl(a, sfa, b, sfb, d, m_indices)
 
 
+@supported_compute_capability([100, 103])
Contributor
Same here: the compute capabilities will be ignored.
+def _check_m_grouped_fp8_gemm_nt_masked_problem_size(
+    a_fp8: Tuple[torch.Tensor, torch.Tensor],
+    b_fp8: Tuple[torch.Tensor, torch.Tensor],
+    d: torch.Tensor,
+    masked_m: torch.Tensor,
+    expected_m: int,
+    recipe: Optional[Tuple[int, int, int]] = None,
+    compiled_dims: str = "nk",
+) -> bool:
+    major_a = get_major_type_ab(a_fp8[0])
+    major_b = get_major_type_ab(b_fp8[0])
+    if major_a != MajorTypeAB.KMajor:
+        raise ValueError(f"major_a must be KMajor, but got {major_a}")
+    if major_b != MajorTypeAB.KMajor:
+        raise ValueError(f"major_b must be KMajor, but got {major_b}")
+
+    if not masked_m.is_contiguous():
+        raise ValueError(
+            f"masked_m must be contiguous, but got {masked_m.is_contiguous()}"
+        )
+
+    a, sfa = a_fp8
+    b, sfb = b_fp8
+    num_groups, m, k = a.shape
+    num_groups_, n, k_ = b.shape
+    num_groups__, m_, n_ = d.shape
+    num_groups___ = masked_m.numel()
+
+    # Type and shape checks
+    if (
+        num_groups != num_groups_
+        or num_groups != num_groups__
+        or num_groups != num_groups___
+    ):
+        raise ValueError(
+            f"num_groups mismatch. num_groups = {num_groups}, num_groups_ = {num_groups_}, num_groups__ = {num_groups__}, num_groups___ = {num_groups___}"
+        )
+    if m != m_ or n != n_ or k != k_:
+        raise ValueError(
+            f"m, n, k mismatch. m = {m}, m_ = {m_}, n = {n}, n_ = {n_}, k = {k}, k_ = {k_}"
+        )
+    if expected_m <= 0 or m <= 0 or n <= 0 or k <= 0 or num_groups <= 0:
+        raise ValueError(
+            f"expected_m, m, n, k, num_groups must be greater than 0, but got expected_m = {expected_m}, m = {m}, n = {n}, k = {k}, num_groups = {num_groups}"
+        )
+    if a.dtype != torch.float8_e4m3fn:
+        raise ValueError(f"a must be float8_e4m3fn, but got {a.dtype}")
+    if b.dtype != torch.float8_e4m3fn:
+        raise ValueError(f"b must be float8_e4m3fn, but got {b.dtype}")
+    if d.dtype != torch.bfloat16:
+        raise ValueError(f"d must be bfloat16, but got {d.dtype}")
+    if masked_m.dtype != torch.int32:
+        raise ValueError(f"masked_m must be int32, but got {masked_m.dtype}")
+
+    # D must be N-major
+    if get_major_type_cd(d) != MajorTypeCD.NMajor:
+        raise ValueError(f"d must be N-major, but got {get_major_type_cd(d)}")
+
+    return True
+
+
+@backend_requirement(
+    {},
+    common_check=_check_m_grouped_fp8_gemm_nt_masked_problem_size,
+)
 def m_grouped_fp8_gemm_nt_masked(
     a_fp8: Tuple[torch.Tensor, torch.Tensor],
     b_fp8: Tuple[torch.Tensor, torch.Tensor],
@@ -1445,20 +1552,6 @@ def m_grouped_fp8_gemm_nt_masked(
     b, sfb = b_fp8
     num_groups, m, k = a.shape
     num_groups_, n, k_ = b.shape
-    num_groups__, m_, n_ = d.shape
-    num_groups___ = masked_m.numel()
-
-    # Type and shape checks
-    assert num_groups == num_groups_ == num_groups__ == num_groups___
-    assert m == m_ and n == n_ and k == k_
-    assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
-    assert a.dtype == torch.float8_e4m3fn
-    assert b.dtype == torch.float8_e4m3fn
-    assert d.dtype == torch.bfloat16
-    assert masked_m.dtype == torch.int32
-
-    # D must be N-major
-    assert get_major_type_cd(d) == MajorTypeCD.NMajor
 
     # Transform SFA and SFB into compute-required layout
     recipe = get_default_recipe(sfa.dtype, sfb.dtype) if recipe is None else recipe
Ah, the compute capability will not actually be checked on the common_check function.
We would either need to:
My preference would go to option 2.
I think option 2 also makes the most sense. In a lot of the APIs there is also no 'backend' arg to be passed in, so we can't only check @supported_compute_capability there. I can change this in a separate PR.
A separate PR is fine, since we wouldn't cause a regression (we didn't have CC checks before the current PR anyway).
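To make the thread concrete: one way to enforce the capability requirement even when the checker only runs as a common_check is to query the device capability explicitly inside the check function. The following is a rough sketch under stated assumptions, not the flashinfer implementation; the helper name and the hard-coded SM list are illustrative only.

import torch

def _require_compute_capability(t: torch.Tensor, supported=(100, 103)) -> None:
    # Illustrative only: encode (major, minor) of the tensor's CUDA device as
    # an SM number (e.g. 10.0 -> 100, 10.3 -> 103) and reject anything else.
    major, minor = torch.cuda.get_device_capability(t.device)
    cc = major * 10 + minor
    if cc not in supported:
        raise ValueError(
            f"Unsupported compute capability sm_{cc}, expected one of {supported}"
        )

A problem-size checker could call this on a_fp8[0] before its shape and dtype validation, so the capability gate is not silently skipped.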