
Commit f69ee6a

More MoE GEMMs
1 parent a6b0b5f commit f69ee6a

3 files changed: +289 -41 lines changed

examples/12_bmg_moe_gemm_cute_interface/12_bmg_moe_gemm_cute_interface.cpp

Lines changed: 55 additions & 21 deletions
@@ -193,23 +193,43 @@ struct VerificationHelper {
 };
 ///////////////////////////////////////////////////////////////////////////////////////////////////
 
-template <class TA, class TB> auto choose_tiled_mma(TA *A, TB *B) {
+template <typename TA, typename TB, typename TD> auto choose_mma_op() {
+  if constexpr (is_any_of_v<TA, bfloat16_t, half_t> &&
+                is_any_of_v<TB, bfloat16_t, half_t>) {
+    return XE_DPAS_TT<8, float, TA, TB>{};
+  } else if constexpr (is_complete_v<XE_DPAS_TT<8, TD, TA, TB>>) {
+    return XE_DPAS_TT<8, TD, TA, TB>{};
+  } else if constexpr (is_same_v<TA, cute::bfloat16_t>) {
+    return XE_DPAS_TT<8, float, cute::bfloat16_t>{};
+  } else { /* Use f16 by default as upconversion sequences are typically faster
+            */
+    return XE_DPAS_TT<8, float, cute::half_t>{};
+  }
+}
+
+template <typename TA, typename TB, typename TD>
+auto choose_tiled_mma(TA *const &A, TB *const &B, TD *D) {
   using TA_non_CV = cutlass::platform::remove_cv_t<TA>;
   using TB_non_CV = cutlass::platform::remove_cv_t<TB>;
-  auto op = XE_DPAS_TT<8, float, TA_non_CV, TB_non_CV>{};
+  auto op = choose_mma_op<TA_non_CV, TB_non_CV, TD>();
 
-  using WGTile = Shape<_256, _128, _32>; // 256x128 WG tile size
-  using SGLayout =
+  constexpr bool use_4x8_sg = (sizeof_bits_v<TB> < sizeof_bits_v<TA>);
+  using WGTileShape =
+      conditional_t<use_4x8_sg, Shape<_256, _256, _32>, Shape<_256, _128, _32>>;
+  using SGLayout8x2 =
       Layout<Shape<_8, _2, _1>, Stride<_2, _1, _0>>; // 8x2 SG tiling, n-major
-
-  using MMA = typename TiledMMAHelper<MMA_Atom<decltype(op)>, Layout<WGTile>,
-                                      SGLayout>::TiledMMA;
-
-  return MMA{};
+  using SGLayout4x8 =
+      Layout<Shape<_4, _8, _1>, Stride<_8, _1, _0>>; // 4x8 SG tiling, n-major
+  using SGLayout = conditional_t<use_4x8_sg, SGLayout4x8, SGLayout8x2>;
+
+  using MMA = typename TiledMMAHelper<MMA_Atom<decltype(op)>,
+                                      Layout<WGTileShape>, SGLayout>::TiledMMA;
+  auto mma = MMA{};
+  return mma;
 }
 
 // type tag to define a unique sycl kernel name
-template <typename, typename, typename, char, char> class GemmCuteName;
+template <typename, typename, typename, char, char, int> class GemmCuteName;
 
 template <char layoutA, char layoutB, typename ElementA, typename ElementB,
           typename ElementS, typename ElementD>
@@ -236,20 +256,34 @@ void MoEGEMMLauncher(const ElementA *activations, const ElementB *weights,
   auto dummy_group_problem_shape =
       cutlass::gemm::GroupProblemShape<Shape<int, int, int>>{
           1, &dummy_problem_shape, nullptr};
-  using TileShape = Shape<_256, _128, _32>;
+  constexpr bool use_4x8_sg =
+      (sizeof_bits_v<ElementB> < sizeof_bits_v<ElementA>);
+  using WGTileShape =
+      conditional_t<use_4x8_sg, Shape<_256, _256, _32>, Shape<_256, _128, _32>>;
   using ClusterShape = Shape<_1, _1, _1>;
   auto scheduler_params =
       PersistentTileSchedulerXeMoE<ProblemShape>::to_underlying_arguments(
-          dummy_group_problem_shape, TileShape{}, ClusterShape{}, hw_info,
+          dummy_group_problem_shape, WGTileShape{}, ClusterShape{}, hw_info,
           PersistentTileSchedulerXeMoE<ProblemShape>::Arguments{
               1, RasterOrderOptions::AlongN});
   auto group_distribution =
       PersistentTileSchedulerXeMoE<ProblemShape>::get_grid_shape(
-          scheduler_params, dummy_group_problem_shape, TileShape{},
+          scheduler_params, dummy_group_problem_shape, WGTileShape{},
          ClusterShape{}, hw_info,
          PersistentTileSchedulerXeMoE<ProblemShape>::Arguments{
              1, RasterOrderOptions::AlongN});
-  auto mma = choose_tiled_mma(activations, weights);
+
+  using SGLayout8x2 =
+      Layout<Shape<_8, _2, _1>, Stride<_2, _1, _0>>; // 8x2 SG tiling, n-major
+  using SGLayout4x8 =
+      Layout<Shape<_4, _8, _1>, Stride<_8, _1, _0>>; // 4x8 SG tiling, n-major
+  using SGLayout = conditional_t<use_4x8_sg, SGLayout4x8, SGLayout8x2>;
+
+  auto mma = choose_tiled_mma(activations, weights, outputs);
+  constexpr auto wg_n = get<1>(mma.tile_mnk());
+  constexpr auto sg_n = wg_n / get<1>(SGLayout{}.shape());
+  constexpr auto q_group_size = 32;
+
   auto MaxThreadsPerWorkgroup = size(mma);
   dim3 local_range{MaxThreadsPerWorkgroup, 1, 1};
 
@@ -268,16 +302,17 @@ void MoEGEMMLauncher(const ElementA *activations, const ElementB *weights,
 
   GPU_Clock timer;
   timer.start();
-  auto event = Q.parallel_for<
-      GemmCuteName<ElementA, ElementB, ElementD, layoutA, layoutB>>(
+  auto event = Q.parallel_for<GemmCuteName<ElementA, ElementB, ElementD,
+                                           layoutA, layoutB, q_group_size>>(
       sycl::nd_range<3>(global, local), kernel_props, [=](auto) {
        // Can also use void for copy atoms.
        // In that case, they will be chosen automatically.
        MoE::MoEGEMM<XE_LOAD_2D<16, 32, 32, 16>,
                     XE_LOAD_2D_VNNI<16, 32, 16, 16>, XE_STORE_2D<16, 8, 32>,
-                     'R', 'R', 'R'>(activations, weights, scales, outputs, mma,
-                                    num_rows_per_expert_device, num_experts,
-                                    gemm_n, gemm_k, scheduler_params);
+                     'R', 'R', 'R', sg_n, wg_n, q_group_size>(
+            activations, weights, scales, outputs, mma,
+            num_rows_per_expert_device, num_experts, gemm_n, gemm_k,
+            scheduler_params);
      });
   EventManager::getInstance().addEvent(event);
   Q.wait_and_throw();
@@ -413,8 +448,7 @@ int main(int argc, const char **argv) {
      {6, 13, 123, 28, 197, 0, 202, 69, 0, 6, 0, 21, 1434, 1582, 11, 0, 6,
       0, 7, 190, 4, 1700, 6, 434, 1886, 0, 14, 28, 8, 30, 25, 18},
      {5, 27, 1442, 18, 0, 6, 0, 73, 6, 781, 0, 1915, 291, 649, 98, 4,
-       33, 77, 6, 22, 73, 9, 8, 587, 1486, 32, 10, 244, 37, 0, 100, 9}
-  };
+       33, 77, 6, 22, 73, 9, 8, 587, 1486, 32, 10, 244, 37, 0, 100, 9}};
 
   for (int i = 0; i < num_layers; i++) {
     launcher(total_rows_for_each_expert[i], 5760, 2880, num_experts);
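
A note on the selection logic introduced in this file: choose_mma_op picks a DPAS atom at compile time from the A/B/D element types, falling back to an f16-input, f32-accumulating atom when no complete atom matches, and choose_tiled_mma widens the workgroup tile to 256x256 with a 4x8 subgroup layout whenever the weight type is narrower than the activation type (the mixed-precision case), keeping 256x128 with an 8x2 layout otherwise. A minimal stand-alone sketch of the same compile-time selection pattern, using placeholder element types and a plain int in place of the real XE_DPAS_TT atoms and CuTe Shape<> types, could look like this:

// Sketch only: bf16_t and fp4_t are placeholder element types, and the
// returned int stands in for the WG tile's N extent; the commit selects
// between Shape<_256,_256,_32> and Shape<_256,_128,_32> in the same way.
#include <cstdint>
#include <cstdio>

struct bf16_t { uint16_t storage; }; // placeholder 16-bit activation type
struct fp4_t  { uint8_t  storage; }; // placeholder narrow weight type

template <typename TA, typename TB> constexpr int choose_wg_n() {
  // Mirrors use_4x8_sg = (sizeof_bits_v<TB> < sizeof_bits_v<TA>): a narrower
  // B type gets the wider 256x256 WG tile with a 4x8 SG layout, otherwise
  // 256x128 with an 8x2 SG layout.
  constexpr bool use_4x8_sg = sizeof(TB) < sizeof(TA);
  if constexpr (use_4x8_sg)
    return 256; // WG_N for the 4x8 SG layout
  else
    return 128; // WG_N for the 8x2 SG layout
}

int main() {
  printf("bf16 activations, bf16 weights -> WG_N = %d\n",
         choose_wg_n<bf16_t, bf16_t>());
  printf("bf16 activations, fp4 weights  -> WG_N = %d\n",
         choose_wg_n<bf16_t, fp4_t>());
  return 0;
}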

examples/12_bmg_moe_gemm_cute_interface/moe_gemms.hpp

Lines changed: 223 additions & 10 deletions
@@ -72,7 +72,7 @@ template <
 CUTE_DEVICE void moe_gemm(ATensor const &A, // (M,K)
                           BTensor const &B, // (N,K)
                           DTensor &D,       // (M,N)
-                          Coord<int, int, cute::Underscore, int> blk_coord,
+                          Coord<int, int, cute::Underscore, int> &blk_coord,
                           TiledMMA const &mma) {
   auto item = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
   auto local_id = item.get_local_linear_id();
@@ -100,10 +100,10 @@ CUTE_DEVICE void moe_gemm(ATensor const &A, // (M,K)
   auto thr_copy_b = tiled_copy_b.get_slice(local_id);
   auto thr_copy_d = tiled_copy_d.get_slice(local_id);
 
-  auto tCrA = thr_mma.partition_sg_fragment_A(gA(_, _, 0));
-  auto tCrB = thr_mma.partition_sg_fragment_B(gB(_, _, 0));
-  auto tCrD = thr_mma.partition_sg_fragment_C(gD);
-  auto tCrD_final = thr_copy_d.partition_sg_fragment_S(gD);
+  auto tDrA = thr_mma.partition_sg_fragment_A(gA(_, _, 0));
+  auto tDrB = thr_mma.partition_sg_fragment_B(gB(_, _, 0));
+  auto tDrD = thr_mma.partition_sg_fragment_C(gD);
+  auto tDrD_final = thr_copy_d.partition_sg_fragment_S(gD);
 
   auto tArA = thr_copy_a.partition_sg_fragment_D(gA(_, _, 0));
   auto tBrB = thr_copy_b.partition_sg_fragment_D(gB(_, _, 0));
@@ -145,14 +145,227 @@ CUTE_DEVICE void moe_gemm(ATensor const &A, // (M,K)
      prefetch(prefetch_b, pBgB(_, _, _, prefetch_k));
    }
 
-    reorder(tArA, tCrA);
-    reorder(tBrB, tCrB);
+    reorder(tArA, tDrA);
+    reorder(tBrB, tDrB);
 
-    cute::gemm(mma, tCrA, tCrB, tCrD);
+    cute::gemm(mma, tDrA, tDrB, tDrD);
     barrier_wait(barrier_scope);
   }
-  reorder(tCrD, tCrD_final);
-  copy(tiled_copy_d, tCrD_final, tCgD);
+  reorder(tDrD, tDrD_final);
+  copy(tiled_copy_d, tDrD_final, tCgD);
+}
+
+template <class GmemTiledCopyA, class GmemTiledCopyB, class GmemTiledCopyD,
+          int SG_N, int WG_N, int q_group_size, class ATensor, class BTensor,
+          class STensor, class DTensor, class TiledMMA,
+          class = std::enable_if_t<
+              !cute::is_void_v<typename STensor::element_type> &&
+              is_any_of_v<typename BTensor::element_type, float_e2m1_t,
+                          float_e4m3_t, float_e5m2_t, int4_t> &&
+              is_any_of_v<typename STensor::element_type, float_ue8m0_t, half_t,
+                          bfloat16_t> &&
+              is_any_of_v<typename ATensor::element_type, bfloat16_t, half_t>>>
+CUTE_DEVICE void moe_gemm(ATensor const &A, // (M,K)
+                          BTensor const &B, // (N,K)
+                          STensor const &S, // (K/q_group_size, N)
+                          DTensor &D,       // (M,N)
+                          Coord<int, int, cute::Underscore, int> &blk_coord,
+                          TiledMMA const &mma) {
+  auto item = sycl::ext::oneapi::this_work_item::get_nd_item<2>();
+  auto wg_m = int(item.get_group(1));
+  auto wg_n = int(item.get_group(0));
+  auto local_id = int(item.get_local_id(0));
+  auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
+  uint32_t sg_id = sg.get_group_linear_id();
+  uint32_t lane = sg.get_local_linear_id();
+
+  auto total_N = get<0>(B.shape());
+
+  Tensor cA = make_identity_tensor(A.shape()); // (M,K)
+  Tensor cB = make_identity_tensor(B.shape()); // (N,K)
+  Tensor cD = make_identity_tensor(D.shape()); // (M,N)
+  Tensor cS = make_identity_tensor(S.shape()); // (K/q_group_size,N)
+  Tensor cScales_per_sg =
+      make_identity_tensor(make_shape(Int<1>{}, Int<SG_N>{}));
+  auto wg_tile = mma.tile_mnk();
+  auto wg_coord = make_coord(wg_m, wg_n, 0);
+
+  Tensor gA = local_tile(cA, select<0, 2>(wg_tile),
+                         make_coord(wg_m, _)); // (BLK_M,BLK_K,k)
+  Tensor gB = local_tile(cB, select<1, 2>(wg_tile),
+                         make_coord(wg_n, _)); // (BLK_N,BLK_K,k)
+  Tensor gD =
+      local_tile(cD, wg_tile, wg_coord, Step<_1, _1, X>{}); // (BLK_M,BLK_N)
+
+  constexpr int num_N_SG_tiles = WG_N / SG_N;
+  constexpr int num_scales_per_col = (SG_N == 32) ? 4 : 2;
+
+  // When we use E8M0, the compiler behaves differently & loads more data than
+  // needed. The rest is discarded.
+  // The scales might be FP16 or BF16 in case of int4 weights
+  using scaleLoadType =
+      conditional_t<is_same_v<typename STensor::element_type, float_ue8m0_t>,
+                    int8_t, int16_t>;
+
+  auto S_tile = coalesce(local_tile(S, make_shape(Int<1>{}, get<1>(wg_tile)),
+                                    make_coord(_, wg_n)));
+
+  auto copy_a = get_block_2d_copy_A<GmemTiledCopyA>(mma, A);
+  auto copy_b = get_block_2d_copy_B<GmemTiledCopyB>(mma, B);
+  auto copy_d = get_block_2d_copy_D<GmemTiledCopyD>(mma, D);
+
+  auto thr_mma = mma.get_slice(local_id);
+  auto thr_copy_a = copy_a.get_slice(local_id);
+  auto thr_copy_b = copy_b.get_slice(local_id);
+  auto thr_copy_d = copy_d.get_slice(local_id);
+
+  auto tDrA = thr_mma.partition_sg_fragment_A(gA(_, _, 0));
+  auto tDrB = thr_mma.partition_sg_fragment_B(gB(_, _, 0));
+
+  auto tArA = thr_copy_a.partition_sg_fragment_D(gA(_, _, 0));
+  auto tBrB = thr_copy_b.partition_sg_fragment_D(gB(_, _, 0));
+  auto tDrD = thr_mma.partition_sg_fragment_C(gD);
+
+  Tensor tAgA = thr_copy_a.partition_S(gA);
+  Tensor tBgB = thr_copy_b.partition_S(gB);
+  Tensor tDgD = thr_copy_d.partition_D(gD);
+
+  auto prefetch_a = make_block_2d_prefetch(copy_a);
+  auto prefetch_b = make_block_2d_prefetch(copy_b);
+
+  auto thr_prefetch_A = prefetch_a.get_slice(local_id);
+  auto thr_prefetch_B = prefetch_b.get_slice(local_id);
+
+  auto pAgA = thr_prefetch_A.partition_S(gA);
+  auto pBgB = thr_prefetch_B.partition_S(gB);
+
+  const int prefetch_dist = 3;
+
+  constexpr int barrier_scope = 2;
+
+  int k_tile_count = ceil_div(shape<1>(A), get<2>(wg_tile));
+  int k_tile_prefetch = 0;
+  constexpr int num_threads_per_sg = 16;
+
+  typename STensor::element_type
+      frag[num_scales_per_col / 2]; // per-thread registers (compiler
+                                    // will keep in regs)
+  float frag_fp32[num_scales_per_col];
+  // assuming SG_K = WG_K
+  constexpr int frequency_scale_change = q_group_size / get<2>(wg_tile);
+  Tensor scales_e8m0 =
+      make_tensor(make_rmem_ptr(frag),
+                  make_layout(make_shape(Int<num_scales_per_col / 2>{})));
+  Tensor scales_float =
+      make_tensor(make_rmem_ptr(frag_fp32),
+                  make_layout(make_shape(Int<num_scales_per_col>{})));
+
+  auto srcTVLayout = make_layout(
+      make_shape(Int<num_threads_per_sg>{}, Int<num_scales_per_col / 2>{}),
+      make_stride(Int<1>{}, Int<num_threads_per_sg>{}));
+  auto dstTVLayout = make_layout(
+      make_shape(make_shape(Int<2>{}, Int<num_threads_per_sg / 2>{}),
+                 make_shape(Int<num_scales_per_col / 2>{})),
+      make_stride(make_stride(Int<0>{}, Int<1>{}), make_stride(Int<8>{})));
+  auto scales_e8m0_sg_tensor = make_subgroup_tensor(scales_e8m0, srcTVLayout);
+  auto scales_float_sg_tensor = make_subgroup_tensor(scales_float, dstTVLayout);
+
+  /* Warm up loops with prefetch to L1 */
+  CUTE_UNROLL
+  for (; k_tile_prefetch < prefetch_dist; k_tile_prefetch++) {
+    prefetch(prefetch_a, pAgA(_, _, _, k_tile_prefetch));
+    prefetch(prefetch_b, pBgB(_, _, _, k_tile_prefetch));
+  }
+  /* Main loop */
+  for (int k_tile = 0; k_tile < k_tile_count; k_tile++, k_tile_prefetch++) {
+    barrier_arrive(barrier_scope);
+    copy(copy_b, tBgB(_, _, _, k_tile), tBrB);
+    prefetch(prefetch_b, pBgB(_, _, _, k_tile_prefetch));
+    reorder(tBrB, tDrB);
+
+    if (k_tile % frequency_scale_change == 0) {
+      auto scales_tensor = make_tensor(
+          make_gmem_ptr(reinterpret_cast<scaleLoadType *>(
+              static_cast<void *>(cute::raw_pointer_cast(
+                  S_tile.data() + (SG_N * (sg_id % num_N_SG_tiles)) +
+                  (k_tile / frequency_scale_change) * total_N)))),
+          make_layout(make_shape(Int<1>{}, Int<SG_N>{})));
+      auto copy_scales = make_block_2d_copy(
+          XE_LOAD_2D<sizeof_bits_v<typename STensor::element_type>, 1, SG_N,
+                     SG_N>{},
+          scales_tensor);
+      auto thr_copy_scales = copy_scales.get_slice(lane);
+      auto scales_per_thread = thr_copy_scales.partition_S(cScales_per_sg);
+      copy(copy_scales, scales_per_thread(_, 0, 0), scales_e8m0);
+      reorder(scales_e8m0_sg_tensor, scales_float_sg_tensor);
+      if (k_tile != (k_tile_count - frequency_scale_change)) {
+        auto next_scales_tensor = make_tensor(
+            make_gmem_ptr(reinterpret_cast<scaleLoadType *>(
+                static_cast<void *>(cute::raw_pointer_cast(
+                    S_tile.data() + (SG_N * (sg_id % num_N_SG_tiles)) +
+                    ((k_tile / frequency_scale_change) + 1) * total_N)))),
+            make_layout(make_shape(Int<1>{}, Int<SG_N>{})));
+        auto prefetch_scales = make_block_2d_prefetch<1>(
+            make_shape(Int<1>{}, Int<SG_N>{}), next_scales_tensor);
+        auto thr_prefetch_scales = prefetch_scales.get_slice(lane);
+        auto pSgS = thr_prefetch_scales.partition_S(cScales_per_sg);
+        prefetch(prefetch_scales, pSgS(_, 0, 0));
+      }
+    }
+    copy(copy_a, tAgA(_, _, _, k_tile), tArA);
+    prefetch(prefetch_a, pAgA(_, _, _, k_tile_prefetch));
+    reorder(tArA, tDrA);
+    // Instead of hardcoding, figure out CuTe algebra based
+    // transformations that can lead to generic code.
+    auto scale0 = scales_float_sg_tensor[0];
+    auto scale1 = scales_float_sg_tensor[1];
+    if (num_scales_per_col == 4) {
+      auto scale2 = scales_float_sg_tensor[2];
+      auto scale3 = scales_float_sg_tensor[3];
+      CUTE_UNROLL
+      for (int i = 0; i < 16; i += 2) {
+        tDrB[i] = static_cast<typename ATensor::element_type>(
+            scale0 * static_cast<float>(tDrB[i]));
+        tDrB[i + 1] = static_cast<typename ATensor::element_type>(
+            scale1 * static_cast<float>(tDrB[i + 1]));
+      }
+      CUTE_UNROLL
+      for (int i = 16; i < 32; i += 2) {
+        tDrB[i] = static_cast<typename ATensor::element_type>(
+            scale2 * static_cast<float>(tDrB[i]));
+        tDrB[i + 1] = static_cast<typename ATensor::element_type>(
+            scale3 * static_cast<float>(tDrB[i + 1]));
+      }
+      CUTE_UNROLL
+      for (int i = 32; i < 48; i += 2) {
+        tDrB[i] = static_cast<typename ATensor::element_type>(
+            scale0 * static_cast<float>(tDrB[i]));
+        tDrB[i + 1] = static_cast<typename ATensor::element_type>(
+            scale1 * static_cast<float>(tDrB[i + 1]));
+      }
+      CUTE_UNROLL
+      for (int i = 48; i < 64; i += 2) {
+        tDrB[i] = static_cast<typename ATensor::element_type>(
+            scale2 * static_cast<float>(tDrB[i]));
+        tDrB[i + 1] = static_cast<typename ATensor::element_type>(
+            scale3 * static_cast<float>(tDrB[i + 1]));
+      }
+    } else {
+      CUTE_UNROLL
+      for (int i = 0; i < 32; i += 2) {
+        tDrB[i] = static_cast<typename ATensor::element_type>(
+            scale0 * static_cast<float>(tDrB[i]));
+        tDrB[i + 1] = static_cast<typename ATensor::element_type>(
+            scale1 * static_cast<float>(tDrB[i + 1]));
      }
    }
+
+    gemm(mma, tDrA, tDrB, tDrD);
+    barrier_wait(barrier_scope);
+  }
+  auto tDrD_final = thr_copy_d.partition_sg_fragment_S(gD);
+  reorder(tDrD, tDrD_final);
+  copy(copy_d, tDrD_final, tDgD);
 }
 
 } // namespace MoE
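
For intuition on the scale handling in the new quantized moe_gemm overload above: the weights carry one scale per q_group_size elements along K for each output column, so fresh per-column scales are loaded every frequency_scale_change = q_group_size / K_tile k-iterations (with q_group_size = 32 and a 32-deep K tile that is every iteration), reordered from E8M0/FP16/BF16 into float, and applied to the B fragment right before the gemm call so the DPAS still operates on the activation type. A scalar sketch of the same group-wise dequantize-then-accumulate idea, with illustrative sizes and none of the CuTe or subgroup register layout, could look like this:

// Scalar sketch of group-wise dequantization: quantized weights share one
// scale per q_group_size elements along K and are dequantized to float just
// before the multiply-accumulate. Sizes and types here are illustrative.
#include <cstdint>
#include <cstdio>
#include <vector>

float group_dequant_dot(const std::vector<float> &a,      // activations, length K
                        const std::vector<int8_t> &b_q,   // quantized weights, length K
                        const std::vector<float> &scales, // K / q_group_size scales
                        int q_group_size) {
  float acc = 0.0f;
  for (std::size_t k = 0; k < a.size(); ++k) {
    float scale = scales[k / q_group_size];       // one scale per K-group
    float b = scale * static_cast<float>(b_q[k]); // dequantize B on the fly
    acc += a[k] * b;                              // multiply-accumulate
  }
  return acc;
}

int main() {
  const int K = 64, q_group_size = 32;
  std::vector<float> a(K, 1.0f);
  std::vector<int8_t> b_q(K, 2);
  std::vector<float> scales = {0.5f, 0.25f}; // one scale per 32-element K-group
  // First group contributes 32 * (1 * 0.5 * 2) = 32, second 32 * (1 * 0.25 * 2) = 16.
  printf("dot = %.1f (expected 48.0)\n",
         group_dequant_dot(a, b_q, scales, q_group_size));
  return 0;
}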
