Commit 6ef1dda

rename

Signed-off-by: Enwei Zhu <[email protected]>

1 parent 4cd1dfa · commit 6ef1dda

4 files changed, +35 -36 lines changed


tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py

Lines changed: 10 additions & 10 deletions
@@ -22,12 +22,12 @@
 import cutlass
 import cutlass.cute as cute

+from ..cute_dsl_kernels.blackwell.blockscaled_contiguous_grouped_gemm import \
+    Sm100BlockScaledContiguousGroupedGemmKernel
+from ..cute_dsl_kernels.blackwell.blockscaled_contiguous_grouped_gemm_finalize_fusion import \
+    Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel
 from ..cute_dsl_kernels.blackwell.dense_blockscaled_gemm_persistent import \
     Sm100BlockScaledPersistentDenseGemmKernel
-from ..cute_dsl_kernels.blackwell.grouped_blockscaled_gemm_finalize_fusion import \
-    Sm100BlockScaledPersistentGroupedGemmFinalizeFusionKernel
-from ..cute_dsl_kernels.blackwell.grouped_blockscaled_gemm_persistent import \
-    Sm100BlockScaledPersistentGroupedGemmKernel
 from ..cute_dsl_kernels.blackwell.utils import make_ptr

 class CuteDSLNVFP4BlackwellRunner(TunableRunner):

@@ -499,8 +499,8 @@ def inputs_pre_hook_finalize_fusion(
             device=num_non_exiting_tiles.device)
     return a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales

-class Sm100BlockScaledPersistentGroupedGemmRunner(TunableRunner):
-    kernel_class = Sm100BlockScaledPersistentGroupedGemmKernel
+class Sm100BlockScaledContiguousGroupedGemmRunner(TunableRunner):
+    kernel_class = Sm100BlockScaledContiguousGroupedGemmKernel
     kernel_cache = dict()
     tuning_config_cache = dict()

@@ -730,7 +730,7 @@ def cute_dsl_nvfp4_grouped_gemm_blackwell(
 ) -> torch.Tensor:
     tuner = AutoTuner.get()

-    runner = Sm100BlockScaledPersistentGroupedGemmRunner(
+    runner = Sm100BlockScaledContiguousGroupedGemmRunner(
         num_experts, top_k, num_local_experts, local_expert_offset,
         tile_size, output_dtype, scaling_vector_size)
     inputs = [

@@ -769,9 +769,9 @@ def _(
     n = weight.size(1)
     return torch.empty(m, n, dtype=output_dtype, device=input.device)

-class Sm100BlockScaledPersistentGroupedGemmFinalizeFusionRunner(
+class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionRunner(
         TunableRunner):
-    kernel_class = Sm100BlockScaledPersistentGroupedGemmFinalizeFusionKernel
+    kernel_class = Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel
     kernel_cache = dict()
     tuning_config_cache = dict()

@@ -1038,7 +1038,7 @@ def cute_dsl_nvfp4_grouped_gemm_finalize_blackwell(
 ) -> torch.Tensor:
     tuner = AutoTuner.get()

-    runner = Sm100BlockScaledPersistentGroupedGemmFinalizeFusionRunner(
+    runner = Sm100BlockScaledContiguousGroupedGemmFinalizeFusionRunner(
         num_experts, top_k, num_local_experts, local_expert_offset,
         tile_size, output_dtype, scaling_vector_size)
     inputs = [
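
Note on why only the import block and the two runner classes change in this file: each runner binds its kernel through a kernel_class class attribute plus per-class caches, so the call sites (cute_dsl_nvfp4_grouped_gemm_blackwell and cute_dsl_nvfp4_grouped_gemm_finalize_blackwell) only ever name the runner. A minimal, self-contained sketch of that binding pattern follows; TunableRunner, the tuning machinery, and the cache key are replaced by illustrative stand-ins, not the real TensorRT-LLM interfaces.

class StubKernel:
    """Illustrative stand-in for a CUTE DSL kernel class (not the real one)."""

    def __call__(self, *tensors):
        return "kernel executed"


class StubRunner:
    # Same shape as the runners in the diff: the kernel is bound as a class
    # attribute, and constructed kernels are cached per runner class.
    kernel_class = StubKernel
    kernel_cache = dict()

    def __call__(self, *tensors):
        key = self.kernel_class.__name__  # simplified cache key
        if key not in self.kernel_cache:
            self.kernel_cache[key] = self.kernel_class()
        return self.kernel_cache[key](*tensors)


runner = StubRunner()
print(runner())  # -> "kernel executed"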

tensorrt_llm/_torch/cute_dsl_kernels/blackwell/grouped_blockscaled_gemm_persistent.py renamed to tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm.py

Lines changed: 8 additions & 9 deletions
@@ -53,7 +53,7 @@
 from cutlass.cute.nvgpu import cpasync, tcgen05


-class Sm100BlockScaledPersistentGroupedGemmKernel:
+class Sm100BlockScaledContiguousGroupedGemmKernel:
     """This class implements batched matrix multiplication (C = A x SFA x B x SFB) with support for various data types
     and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization.


@@ -88,7 +88,7 @@ class Sm100BlockScaledPersistentGroupedGemmKernel:
     - Also, Cluster shape M/N must be <= 4 for scale factor multicasts due to limited size of scale factors

     Example:
-    >>> gemm = Sm100BlockScaledPersistentGroupedGemmKernel(
+    >>> gemm = Sm100BlockScaledContiguousGroupedGemmKernel(
     ...     sf_vec_size=16, mma_tiler_mn=(256, 128), cluster_shape_mn=(2, 1)
     ... )
     >>> gemm(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, max_active_clusters, stream)

@@ -2138,8 +2138,9 @@ def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape):
                 is_valid = False
         return is_valid

-    @staticmethod
+    @classmethod
     def can_implement(
+        cls,
         ab_dtype: Type[cutlass.Numeric],
         sf_dtype: Type[cutlass.Numeric],
         sf_vec_size: int,

@@ -2198,24 +2199,22 @@ def can_implement(
         """
         can_implement = True
         # Skip unsupported types
-        if not Sm100BlockScaledPersistentGroupedGemmKernel.is_valid_dtypes_and_scale_factor_vec_size(
+        if not cls.is_valid_dtypes_and_scale_factor_vec_size(
             ab_dtype, sf_dtype, sf_vec_size, acc_dtype, c_dtype
         ):
             can_implement = False

         # Skip unsupported layouts
-        if not Sm100BlockScaledPersistentGroupedGemmKernel.is_valid_layouts(
-            ab_dtype, c_dtype, a_major, b_major, c_major
-        ):
+        if not cls.is_valid_layouts(ab_dtype, c_dtype, a_major, b_major, c_major):
             can_implement = False

         # Skip invalid mma tile shape and cluster shape
-        if not Sm100BlockScaledPersistentGroupedGemmKernel.is_valid_mma_tiler_and_cluster_shape(
+        if not cls.is_valid_mma_tiler_and_cluster_shape(
             use_2cta_instrs, mma_tiler_mn, cluster_shape_mn, m_aligned
         ):
             can_implement = False
         # Skip illegal problem shape for load/store alignment
-        if not Sm100BlockScaledPersistentGroupedGemmKernel.is_valid_tensor_alignment(
+        if not cls.is_valid_tensor_alignment(
             m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major
         ):
             can_implement = False
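
The @staticmethod to @classmethod switch above (repeated in the other two kernel files) is what lets the validation helpers be reached through cls instead of the spelled-out class name, so a rename like this one no longer touches the body of can_implement. A minimal sketch of the pattern, with hypothetical validator names and checks standing in for the kernels' real ones:

class BaseGemmKernel:
    """Illustrative base class; the real kernels define richer checks."""

    @classmethod
    def is_valid_dtypes(cls, ab_dtype: str, c_dtype: str) -> bool:
        # Placeholder check, not the kernel's real dtype/scale-factor logic.
        return ab_dtype == "float4_e2m1fn" and c_dtype in ("bfloat16", "float16")

    @classmethod
    def can_implement(cls, ab_dtype: str, c_dtype: str) -> bool:
        can_implement = True
        # Dispatch through cls: renaming or subclassing the kernel class
        # no longer requires editing this method body.
        if not cls.is_valid_dtypes(ab_dtype, c_dtype):
            can_implement = False
        return can_implement


class RenamedGemmKernel(BaseGemmKernel):
    pass  # a rename (or subclass) inherits can_implement unchanged


assert RenamedGemmKernel.can_implement("float4_e2m1fn", "bfloat16")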

tensorrt_llm/_torch/cute_dsl_kernels/blackwell/grouped_blockscaled_gemm_finalize_fusion.py renamed to tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py

Lines changed: 8 additions & 9 deletions
@@ -210,7 +210,7 @@ def atomic_add_func(rOut_epi_packed, scatter_out_offset, loc=None, ip=None):
     )


-class Sm100BlockScaledPersistentGroupedGemmFinalizeFusionKernel:
+class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel:
     """This class implements batched matrix multiplication (C = A x SFA x B x SFB) with support for various data types
     and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization.


@@ -245,7 +245,7 @@ class Sm100BlockScaledPersistentGroupedGemmFinalizeFusionKernel:
     - Also, Cluster shape M/N must be <= 4 for scale factor multicasts due to limited size of scale factors

     Example:
-    >>> gemm = Sm100BlockScaledPersistentGroupedGemmFinalizeFusionKernel(
+    >>> gemm = Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel(
     ...     sf_vec_size=16, mma_tiler_mn=(256, 128), cluster_shape_mn=(2, 1)
     ... )
     >>> gemm(

@@ -2147,8 +2147,9 @@ def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape):
                 is_valid = False
         return is_valid

-    @staticmethod
+    @classmethod
     def can_implement(
+        cls,
         ab_dtype: Type[cutlass.Numeric],
         sf_dtype: Type[cutlass.Numeric],
         sf_vec_size: int,

@@ -2207,24 +2208,22 @@ def can_implement(
         """
         can_implement = True
         # Skip unsupported types
-        if not Sm100BlockScaledPersistentGroupedGemmFinalizeFusionKernel.is_valid_dtypes_and_scale_factor_vec_size(
+        if not cls.is_valid_dtypes_and_scale_factor_vec_size(
             ab_dtype, sf_dtype, sf_vec_size, acc_dtype, out_dtype
         ):
             can_implement = False

         # Skip unsupported layouts
-        if not Sm100BlockScaledPersistentGroupedGemmFinalizeFusionKernel.is_valid_layouts(
-            ab_dtype, out_dtype, a_major, b_major, c_major
-        ):
+        if not cls.is_valid_layouts(ab_dtype, out_dtype, a_major, b_major, c_major):
             can_implement = False

         # Skip invalid mma tile shape and cluster shape
-        if not Sm100BlockScaledPersistentGroupedGemmFinalizeFusionKernel.is_valid_mma_tiler_and_cluster_shape(
+        if not cls.is_valid_mma_tiler_and_cluster_shape(
             use_2cta_instrs, mma_tiler_mn, cluster_shape_mn, m_aligned
         ):
             can_implement = False
         # Skip illegal problem shape for load/store alignment
-        if not Sm100BlockScaledPersistentGroupedGemmFinalizeFusionKernel.is_valid_tensor_alignment(
+        if not cls.is_valid_tensor_alignment(
             m, n, k, l, ab_dtype, out_dtype, a_major, b_major, c_major
         ):
             can_implement = False

tensorrt_llm/_torch/cute_dsl_kernels/blackwell/dense_blockscaled_gemm_persistent.py

Lines changed: 9 additions & 8 deletions
@@ -1817,8 +1817,9 @@ def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape):
                 is_valid = False
         return is_valid

-    @staticmethod
+    @classmethod
     def can_implement(
+            cls,
             ab_dtype: Type[cutlass.Numeric],
             sf_dtype: Type[cutlass.Numeric],
             sf_vec_size: int,

@@ -1859,20 +1860,20 @@ def can_implement(
         """
         can_implement = True
         # Skip unsupported types
-        if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_dtypes_and_scale_factor_vec_size(
+        if not cls.is_valid_dtypes_and_scale_factor_vec_size(
                 ab_dtype, sf_dtype, sf_vec_size, c_dtype):
             can_implement = False
         # Skip unsupported layouts
-        if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_layouts(
-                ab_dtype, c_dtype, a_major, b_major, c_major):
+        if not cls.is_valid_layouts(ab_dtype, c_dtype, a_major, b_major,
+                                    c_major):
             can_implement = False
         # Skip invalid mma tile shape and cluster shape
-        if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_mma_tiler_and_cluster_shape(
-                mma_tiler_mn, cluster_shape_mn):
+        if not cls.is_valid_mma_tiler_and_cluster_shape(mma_tiler_mn,
+                                                        cluster_shape_mn):
             can_implement = False
         # Skip illegal problem shape for load/store alignment
-        if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_tensor_alignment(
-                m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major):
+        if not cls.is_valid_tensor_alignment(m, n, k, l, ab_dtype, c_dtype,
+                                             a_major, b_major, c_major):
             can_implement = False
         return can_implement
