155 changes: 124 additions & 31 deletions flashinfer/deep_gemm.py
@@ -45,7 +45,12 @@
from .cuda_utils import checkCudaErrors
from .jit.cubin_loader import get_cubin
from .jit.env import FLASHINFER_CUBIN_DIR
from .utils import ceil_div, round_up
from .utils import (
    ceil_div,
    round_up,
    supported_compute_capability,
    backend_requirement,
)


class GemmType(enum.Enum):
@@ -1358,24 +1363,27 @@ def m_grouped_fp8_gemm_nt_masked_sm10x(
    runtime(**all_kwargs)


def m_grouped_fp8_gemm_nt_contiguous(
@supported_compute_capability([100, 103])
nvmbreughe (Contributor) commented on Oct 29, 2025:
Ah, the compute capability will not actually be checked on the common_check function. We would need to either:

  1. Redesign the decorator so that we can just add @supported_compute_capability to m_grouped_fp8_gemm_nt_contiguous directly; OR
  2. Change the decorator's implementations of wrapper and is_compute_capability_supported so they also check this on the common_check; OR
  3. Add a backend parameter to m_grouped_fp8_gemm_nt_contiguous, though this is DeepGEMM and I don't think we want to call it a backend for now.

My preference would go to option 2.

Contributor Author replied:
I think option 2 also makes the most sense. Many of the APIs also take no 'backend' arg, so we can't rely on checking only in @supported_compute_capability there. I can change this in a separate PR.

Contributor replied:
A separate PR is fine, since we wouldn't cause a regression (we didn't have CC checks before the current PR anyway).
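To make the thread concrete, here is a minimal sketch of what option 2 could look like, assuming supported_compute_capability simply records the allowed capability list as an attribute on the decorated check function and backend_requirement's wrapper reads it before invoking common_check. The attribute name, helper signatures, and error handling below are illustrative assumptions, not flashinfer's actual internals.

    # Hypothetical sketch of option 2; attribute and helper names are assumptions,
    # not flashinfer's real implementation.
    import functools
    from typing import Callable, Dict, Optional, Sequence

    import torch

    _SUPPORTED_CC_ATTR = "_supported_compute_capabilities"


    def supported_compute_capability(ccs: Sequence[int]) -> Callable:
        """Record the supported compute capabilities (e.g. [100, 103]) on the decorated check."""

        def decorator(fn: Callable) -> Callable:
            setattr(fn, _SUPPORTED_CC_ATTR, tuple(ccs))
            return fn

        return decorator


    def backend_requirement(
        backend_checks: Dict[str, Callable],
        common_check: Optional[Callable] = None,
    ) -> Callable:
        """Run common_check before the wrapped API, gating on any CC list it carries."""

        def decorator(fn: Callable) -> Callable:
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                if common_check is not None:
                    ccs = getattr(common_check, _SUPPORTED_CC_ATTR, None)
                    if ccs is not None:
                        major, minor = torch.cuda.get_device_capability()
                        cc = major * 10 + minor
                        if cc not in ccs:
                            raise RuntimeError(
                                f"Compute capability {cc} is not supported; expected one of {ccs}"
                            )
                    common_check(*args, **kwargs)
                return fn(*args, **kwargs)

            return wrapper

        return decorator

With a shape like this, the existing @supported_compute_capability([100, 103]) on the check functions would be honored without adding a backend argument to the public DeepGEMM APIs.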

def _check_group_deepgemm_fp8_nt_contiguous_problem_size(
    a_fp8: Tuple[torch.Tensor, torch.Tensor],
    b_fp8: Tuple[torch.Tensor, torch.Tensor],
    d: torch.Tensor,
    m_indices: torch.Tensor,
    recipe: Optional[Tuple[int, int, int]] = None,
    compiled_dims: str = "nk",
) -> None:
    # Compiled dims can be upper cases
    compiled_dims = compiled_dims.lower()

) -> bool:
    # NOTES: shape must be `[M, K] @ [G, N, K].mT`
    major_a = get_major_type_ab(a_fp8[0])
    major_b = get_major_type_ab(b_fp8[0])
    assert major_a == MajorTypeAB.KMajor
    if must_be_k_major():
        assert major_b == MajorTypeAB.KMajor
    assert m_indices.is_contiguous()
    if major_a != MajorTypeAB.KMajor:
        raise ValueError(f"major_a must be KMajor, but got {major_a}")
    if must_be_k_major() and (major_b != MajorTypeAB.KMajor):
        raise ValueError(f"major_b must be KMajor, but got {major_b}")

    if not m_indices.is_contiguous():
        raise ValueError(
            f"m_indices must be contiguous, but got {m_indices.is_contiguous()}"
        )

    a, sfa = a_fp8
    b, sfb = b_fp8
@@ -1385,15 +1393,48 @@ def m_grouped_fp8_gemm_nt_contiguous(
    m__ = m_indices.numel()

Contributor commented (severity: high):

The check for positive dimensions n, k, and num_groups is missing. The original code had assert n > 0 and k > 0 and num_groups > 0. This check is important to prevent unexpected behavior with empty or invalid dimensions.

Suggested change:

    if n <= 0 or k <= 0 or num_groups <= 0:
        raise ValueError(
            f"n, k, and num_groups must be positive, but got n={n}, k={k}, num_groups={num_groups}"
        )

    # Type and shape checks
    assert m == m_ == m__ and n == n_ and k == k_
    assert n > 0 and k > 0 and num_groups > 0
    assert a.dtype == torch.float8_e4m3fn
    assert b.dtype == torch.float8_e4m3fn
    assert d.dtype == torch.bfloat16
    assert m_indices.dtype == torch.int32
    if m != m_ or k != k_ or n != n_ or m__ != m_:
        raise ValueError(
            f"Shape mismatch. m = {m}, m_ = {m_}, k = {k}, k_ = {k_}, n = {n}, n_ = {n_}, m__ = {m__}"
        )
    if a.dtype != torch.float8_e4m3fn:
        raise ValueError(f"a must be float8_e4m3fn, but got {a.dtype}")
    if b.dtype != torch.float8_e4m3fn:
        raise ValueError(f"b must be float8_e4m3fn, but got {b.dtype}")
    if d.dtype != torch.bfloat16:
        raise ValueError(f"d must be bfloat16, but got {d.dtype}")
    if m_indices.dtype != torch.int32:
        raise ValueError(f"m_indices must be int32, but got {m_indices.dtype}")

    # D must be N-major
    assert get_major_type_cd(d) == MajorTypeCD.NMajor
    if get_major_type_cd(d) != MajorTypeCD.NMajor:
        raise ValueError(f"d must be N-major, but got {get_major_type_cd(d)}")

    return True


@backend_requirement(
    {},
    common_check=_check_group_deepgemm_fp8_nt_contiguous_problem_size,
)
def m_grouped_fp8_gemm_nt_contiguous(
    a_fp8: Tuple[torch.Tensor, torch.Tensor],
    b_fp8: Tuple[torch.Tensor, torch.Tensor],
    d: torch.Tensor,
    m_indices: torch.Tensor,
    recipe: Optional[Tuple[int, int, int]] = None,
    compiled_dims: str = "nk",
) -> None:
    # Compiled dims can be upper cases
    compiled_dims = compiled_dims.lower()

    major_a = get_major_type_ab(a_fp8[0])
    major_b = get_major_type_ab(b_fp8[0])

    a, sfa = a_fp8
    b, sfb = b_fp8
    m, k = a.shape
    num_groups, n, k_ = b.shape

    # Do nothing if the problem is empty
    if m == 0:
@@ -1423,6 +1464,72 @@ def m_grouped_fp8_gemm_nt_contiguous(
    impl(a, sfa, b, sfb, d, m_indices)


@supported_compute_capability([100, 103])
Contributor commented:
Same here: the compute capabilities will be ignored.

def _check_m_grouped_fp8_gemm_nt_masked_problem_size(
    a_fp8: Tuple[torch.Tensor, torch.Tensor],
    b_fp8: Tuple[torch.Tensor, torch.Tensor],
    d: torch.Tensor,
    masked_m: torch.Tensor,
    expected_m: int,
    recipe: Optional[Tuple[int, int, int]] = None,
    compiled_dims: str = "nk",
) -> bool:
    major_a = get_major_type_ab(a_fp8[0])
    major_b = get_major_type_ab(b_fp8[0])
    if major_a != MajorTypeAB.KMajor:
        raise ValueError(f"major_a must be KMajor, but got {major_a}")
    if major_b != MajorTypeAB.KMajor:
        raise ValueError(f"major_b must be KMajor, but got {major_b}")

    if not masked_m.is_contiguous():
        raise ValueError(
            f"masked_m must be contiguous, but got {masked_m.is_contiguous()}"
        )

    a, sfa = a_fp8
    b, sfb = b_fp8
    num_groups, m, k = a.shape
    num_groups_, n, k_ = b.shape
    num_groups__, m_, n_ = d.shape
    num_groups___ = masked_m.numel()

    # Type and shape checks
    if (
        num_groups != num_groups_
        or num_groups != num_groups__
        or num_groups != num_groups___
    ):
        raise ValueError(
            f"num_groups mismatch. num_groups = {num_groups}, num_groups_ = {num_groups_}, num_groups__ = {num_groups__}, num_groups___ = {num_groups___}"
        )
    if m != m_ or n != n_ or k != k_:
        raise ValueError(
            f"m, n, k mismatch. m = {m}, m_ = {m_}, n = {n}, n_ = {n_}, k = {k}, k_ = {k_}"
        )
    if expected_m <= 0 or m <= 0 or n <= 0 or k <= 0 or num_groups <= 0:
        raise ValueError(
            f"expected_m, m, n, k, num_groups must be greater than 0, but got expected_m = {expected_m}, m = {m}, n = {n}, k = {k}, num_groups = {num_groups}"
        )
    if a.dtype != torch.float8_e4m3fn:
        raise ValueError(f"a must be float8_e4m3fn, but got {a.dtype}")
    if b.dtype != torch.float8_e4m3fn:
        raise ValueError(f"b must be float8_e4m3fn, but got {b.dtype}")
    if d.dtype != torch.bfloat16:
        raise ValueError(f"d must be bfloat16, but got {d.dtype}")
    if masked_m.dtype != torch.int32:
        raise ValueError(f"masked_m must be int32, but got {masked_m.dtype}")

    # D must be N-major
    if get_major_type_cd(d) != MajorTypeCD.NMajor:
        raise ValueError(f"d must be N-major, but got {get_major_type_cd(d)}")

    return True


@backend_requirement(
    {},
    common_check=_check_m_grouped_fp8_gemm_nt_masked_problem_size,
)
def m_grouped_fp8_gemm_nt_masked(
    a_fp8: Tuple[torch.Tensor, torch.Tensor],
    b_fp8: Tuple[torch.Tensor, torch.Tensor],
@@ -1445,20 +1552,6 @@ def m_grouped_fp8_gemm_nt_masked(
    b, sfb = b_fp8
    num_groups, m, k = a.shape
    num_groups_, n, k_ = b.shape
    num_groups__, m_, n_ = d.shape
    num_groups___ = masked_m.numel()

    # Type and shape checks
    assert num_groups == num_groups_ == num_groups__ == num_groups___
    assert m == m_ and n == n_ and k == k_
    assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
    assert a.dtype == torch.float8_e4m3fn
    assert b.dtype == torch.float8_e4m3fn
    assert d.dtype == torch.bfloat16
    assert masked_m.dtype == torch.int32

    # D must be N-major
    assert get_major_type_cd(d) == MajorTypeCD.NMajor

    # Transform SFA and SFB into compute-required layout
    recipe = get_default_recipe(sfa.dtype, sfb.dtype) if recipe is None else recipe
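For reference, a small sketch of inputs that satisfy the masked-variant checks added in this diff (FP8 e4m3 operands, bf16 N-major d, contiguous int32 masked_m with one entry per group). The scale-factor shapes follow the common 128-block FP8 recipe and are an assumption here; the visible checks do not constrain sfa/sfb.

    # Sketch only: sfa/sfb layouts are assumed (per-128-channel FP32 scales),
    # not derived from the checks shown above.
    import torch

    num_groups, m, n, k = 4, 256, 512, 1024
    expected_m = 128

    a = torch.randn(num_groups, m, k).to(torch.float8_e4m3fn)              # FP8, K-major
    sfa = torch.ones(num_groups, m, k // 128, dtype=torch.float32)         # assumed layout
    b = torch.randn(num_groups, n, k).to(torch.float8_e4m3fn)              # FP8, K-major
    sfb = torch.ones(num_groups, n // 128, k // 128, dtype=torch.float32)  # assumed layout
    d = torch.empty(num_groups, m, n, dtype=torch.bfloat16)                # bf16, N-major
    masked_m = torch.full((num_groups,), expected_m, dtype=torch.int32)    # contiguous int32

    # On a supported SM100-class GPU, after moving these tensors to CUDA:
    # m_grouped_fp8_gemm_nt_masked((a, sfa), (b, sfb), d, masked_m, expected_m)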