Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 43 additions & 8 deletions include/cutlass/detail/collective/mixed_input_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,35 @@ struct MixedInputUtils {
static constexpr auto KernelConversionMode = Collective::KernelConversionMode;
static constexpr auto ModeHasScales = Collective::ModeHasScales;
static constexpr auto UseScaleLookupTable = Collective::UseScaleLookupTable;
static constexpr bool UseNvfp4Block16ScaleBroadcast =
cute::is_same_v<RealSwappedElementA, cutlass::float_e2m1_t> &&
cute::is_same_v<ElementScale, cutlass::float_e4m3_t> &&
(int(size<1>(SmemLayoutScale{})) > 1);

static constexpr auto
get_mma_smem_layout_scale() {
if constexpr (UseNvfp4Block16ScaleBroadcast) {
auto compact_layout = SmemLayoutScale{};
constexpr int ScaleK = int(size<1>(SmemLayoutScale{}));
static_assert(int(size<0>(SmemLayoutScale{})) % 16 == 0,
"NVFP4 scale broadcast assumes 16-row scale atoms.");
auto compact_k_stride =
compact_layout(_0{}, _1{}, _0{}) - compact_layout(_0{}, _0{}, _0{});
auto broadcast_layout = make_layout(
make_shape(shape<0>(compact_layout),
make_shape(Int<16>{}, Int<ScaleK>{}),
shape<2>(compact_layout)),
make_stride(stride<0>(compact_layout),
make_stride(Int<0>{}, compact_k_stride),
stride<2>(compact_layout)));
static_assert(cute::cosize_v<decltype(broadcast_layout)> ==
cute::cosize_v<SmemLayoutScale>);
return broadcast_layout;
}
else {
return SmemLayoutScale{};
}
}

public:
static constexpr auto
Expand Down Expand Up @@ -664,8 +693,14 @@ struct MixedInputUtils {

copy(smem_tiled_copy_A, tCsA(_,_,k_block,read_stage), tCrA_copy_view(_,_,k_block));

if (k_block == 0) {
// We are starting a new k-tile so copy the scale
bool copy_extra_inputs = k_block == 0;
if constexpr (size<1>(SmemLayoutScale{}) != 1) {
copy_extra_inputs = true;
}

if (copy_extra_inputs) {
// One-scale-per-tile kernels only refresh at the first GMMA k-block.
// NVFP4 block-16 kernels use a broadcast MMA view over compact scale columns.
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
// nothing to do
}
Expand Down Expand Up @@ -862,7 +897,7 @@ struct MixedInputUtils {
}
else if constexpr (UseScaleLookupTable) {
constexpr int num_elements = decltype(size(src))::value;
static_assert(is_same_v<RealSwappedElementA, cutlass::int4b_t> || is_same_v<RealSwappedElementA, cutlass::float_e2m1_t>,
static_assert(is_same_v<RealSwappedElementA, cutlass::int4b_t> || is_same_v<RealSwappedElementA, cutlass::float_e2m1_t>,
"Lookup table supports int4b_t (Two's Complement) and float_e2m1_t (E2M1/FP4) quant types.");
static_assert(sizeof_bits_v<ElementScale> == 64, "Lookup table only supports 8 8bit scale values now.");
static_assert(num_elements % 4 == 0 && num_elements >= 4, "Lookup table requires a vector size of 4x when converting.");
Expand All @@ -886,7 +921,7 @@ struct MixedInputUtils {
{
auto&& scale_neg_ = reinterpret_cast<cutlass::Array<uint32_t, 2> const&>(scales_neg_vm_(i));
auto&& scale_pos_ = reinterpret_cast<cutlass::Array<uint32_t, 2> &>(scales_pos_vm_(i));

// Accept CUTLASS pseudo-FP as well
if constexpr (cutlass::platform::is_floating_point<RealSwappedElementA>::value ||
cute::is_same_v<RealSwappedElementA, cutlass::float_e2m1_t>) {
Expand Down Expand Up @@ -1022,7 +1057,7 @@ struct MixedInputUtils {
Tensor dst_vm = cute::group_modes<1,-1>(cute::zipped_divide(dst, pack));

cute::transform(src_arr, dst_arr, Converter::convert);

if constexpr (ModeHasScales) {

auto const& scales = cute::get<1>(partitioned_extra_info)(_,_,_,k_block);
Expand Down Expand Up @@ -1154,7 +1189,7 @@ struct MixedInputUtils {
return cute::make_tuple();
}
else if constexpr (UseScaleLookupTable) {
Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE)
Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), get_mma_smem_layout_scale());// (BLK_M,BLK_SCALE_K,PIPE)
Tensor tCsS = mma_thread_slice.partition_A(sS);
Tensor tCrS_neg = make_tensor<ElementScale>(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout());
Tensor tCrS_pos = make_tensor<ElementScale>(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout());
Expand All @@ -1164,15 +1199,15 @@ struct MixedInputUtils {
}
}
else if constexpr (ModeHasScales) {
Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE)
Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), get_mma_smem_layout_scale());// (BLK_M,BLK_SCALE_K,PIPE)
Tensor tCsS = mma_thread_slice.partition_A(sS);
Tensor tCrS = make_tensor<ElementScale>(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout());

if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
return cute::make_tuple(tCsS, tCrS);
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE)
Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), get_mma_smem_layout_scale());// (BLK_M,BLK_SCALE_K,PIPE)
Tensor tCsZ = mma_thread_slice.partition_A(sZ);
Tensor tCrZ = make_tensor<ElementZero>(mma_thread_slice.partition_fragment_A(sZ(_,_,Int<0>{})).layout());
return cute::make_tuple(tCsS, tCrS, tCsZ, tCrZ);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,17 @@ struct CollectiveMma<

static constexpr int NumProducerThreadEvents = 1;

using SmemLayoutAtomScale = Layout<Shape<decltype(cute::shape<0>(SwappedSmemLayoutAtomA{})), cute::Int<1>>>;
static constexpr bool UseNvfp4Block16Scales =
cute::is_same_v<RealSwappedElementA, cutlass::float_e2m1_t> &&
cute::is_same_v<NonVoidElementScale, cutlass::float_e4m3_t> &&
((int(size<2>(TileShape{})) % 16) == 0);
using ScaleAtomM =
cute::conditional_t<UseNvfp4Block16Scales, cute::Int<16>,
decltype(cute::shape<0>(SwappedSmemLayoutAtomA{}))>;
static constexpr int ScaleAtomK =
UseNvfp4Block16Scales ? int(size<2>(TileShape{})) / 16 : 1;
using SmemLayoutAtomScale =
Layout<Shape<ScaleAtomM, cute::Int<ScaleAtomK>>>;
using ScaleTileShape = decltype(make_shape(shape<0>(TileShape{}), shape<1>(SmemLayoutAtomScale{})));

static_assert(cute::rank(SwappedSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
Expand Down Expand Up @@ -234,9 +244,8 @@ struct CollectiveMma<
static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");

// To relax them, we need to handle loading more than 1 row of scales for every main loop iteration.
// We must also handle updating the pipeline transaction bytes on the fly.
static_assert(size<1>(SmemLayoutAtomScale{}) == 1, "size<1>(SmemLayoutAtomScale) must be 1.");
static_assert(size<1>(SmemLayoutAtomScale{}) == 1 || UseNvfp4Block16Scales,
"Only NVFP4 block-16 scales may use multiple scale columns per K tile.");

private:
static constexpr ConversionMode
Expand Down Expand Up @@ -378,6 +387,11 @@ struct CollectiveMma<
init_M = get<1>(init_shape);
init_N = get<0>(init_shape);
}
if constexpr (IsGroupedGemmKernel) {
init_M = cute::max(init_M, int(size<0>(TileShape{})));
init_N = cute::max(init_N, int(size<1>(TileShape{})));
init_K = cute::max(init_K, int(size<2>(TileShape{})));
}
// Batches/Groups are managed by using appropriate pointers to input matrices
const uint32_t mock_L = 1;
SwappedElementA const* ptr_A_first_batch;
Expand Down Expand Up @@ -491,7 +505,9 @@ struct CollectiveMma<
else if constexpr (ModeHasScales) {
auto scale_k = ceil_div(init_K, args.chunk_size);
ElementScale const* ptr_S = reinterpret_cast<ElementScale const*>(args.ptr_S);
StrideScale dS{};
StrideScale dS =
make_stride(Int<1>{}, static_cast<int64_t>(init_M),
static_cast<int64_t>(init_M) * scale_k);
Tensor tensor_scale = make_tensor(detail::get_logical_ptr(ptr_S), make_layout(make_shape(init_M,scale_k,mock_L), dS));
tma_load_scale = make_tma_copy<TmaElementScale>(
GmemTiledCopyScale{},
Expand Down Expand Up @@ -596,8 +612,14 @@ struct CollectiveMma<
const int scale_k = ceil_div(K, args.chunk_size);
constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits<ElementScale>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_scale>(cute::make_shape(scale_mn,scale_k,L), StrideScale{});
implementable = implementable && (args.chunk_size == K || ((args.chunk_size % size<2>(TileShape{})) == 0));
implementable = implementable && args.chunk_size != 0;
if (args.chunk_size != 0) {
implementable = implementable &&
(args.chunk_size == K ||
((args.chunk_size % size<2>(TileShape{})) == 0) ||
(UseNvfp4Block16Scales &&
((int(size<2>(TileShape{})) % args.chunk_size) == 0)));
}
implementable = implementable && (args.ptr_S != nullptr);
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
implementable = implementable && (args.ptr_Z == nullptr);
Expand Down
Loading