@@ -72,7 +72,7 @@ template <
 CUTE_DEVICE void moe_gemm(ATensor const &A, // (M,K)
                           BTensor const &B, // (N,K)
                           DTensor &D,       // (M,N)
-                          Coord<int, int, cute::Underscore, int> blk_coord,
+                          Coord<int, int, cute::Underscore, int> &blk_coord,
                           TiledMMA const &mma) {
   auto item = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
   auto local_id = item.get_local_linear_id();
@@ -100,10 +100,10 @@ CUTE_DEVICE void moe_gemm(ATensor const &A, // (M,K)
   auto thr_copy_b = tiled_copy_b.get_slice(local_id);
   auto thr_copy_d = tiled_copy_d.get_slice(local_id);
 
-  auto tCrA = thr_mma.partition_sg_fragment_A(gA(_, _, 0));
-  auto tCrB = thr_mma.partition_sg_fragment_B(gB(_, _, 0));
-  auto tCrD = thr_mma.partition_sg_fragment_C(gD);
-  auto tCrD_final = thr_copy_d.partition_sg_fragment_S(gD);
+  auto tDrA = thr_mma.partition_sg_fragment_A(gA(_, _, 0));
+  auto tDrB = thr_mma.partition_sg_fragment_B(gB(_, _, 0));
+  auto tDrD = thr_mma.partition_sg_fragment_C(gD);
+  auto tDrD_final = thr_copy_d.partition_sg_fragment_S(gD);
 
   auto tArA = thr_copy_a.partition_sg_fragment_D(gA(_, _, 0));
   auto tBrB = thr_copy_b.partition_sg_fragment_D(gB(_, _, 0));
@@ -145,14 +145,227 @@ CUTE_DEVICE void moe_gemm(ATensor const &A, // (M,K)
       prefetch(prefetch_b, pBgB(_, _, _, prefetch_k));
     }
 
-    reorder(tArA, tCrA);
-    reorder(tBrB, tCrB);
+    reorder(tArA, tDrA);
+    reorder(tBrB, tDrB);
 
-    cute::gemm(mma, tCrA, tCrB, tCrD);
+    cute::gemm(mma, tDrA, tDrB, tDrD);
     barrier_wait(barrier_scope);
   }
-  reorder(tCrD, tCrD_final);
-  copy(tiled_copy_d, tCrD_final, tCgD);
+  reorder(tDrD, tDrD_final);
+  copy(tiled_copy_d, tDrD_final, tCgD);
+}
+
+template <class GmemTiledCopyA, class GmemTiledCopyB, class GmemTiledCopyD,
+          int SG_N, int WG_N, int q_group_size, class ATensor, class BTensor,
+          class STensor, class DTensor, class TiledMMA,
+          class = std::enable_if_t<
+              !cute::is_void_v<typename STensor::element_type> &&
+              is_any_of_v<typename BTensor::element_type, float_e2m1_t,
+                          float_e4m3_t, float_e5m2_t, int4_t> &&
+              is_any_of_v<typename STensor::element_type, float_ue8m0_t, half_t,
+                          bfloat16_t> &&
+              is_any_of_v<typename ATensor::element_type, bfloat16_t, half_t>>>
+CUTE_DEVICE void moe_gemm(ATensor const &A, // (M,K)
+                          BTensor const &B, // (N,K)
+                          STensor const &S, // (K/q_group_size, N)
+                          DTensor &D,       // (M,N)
+                          Coord<int, int, cute::Underscore, int> &blk_coord,
+                          TiledMMA const &mma) {
+  auto item = sycl::ext::oneapi::this_work_item::get_nd_item<2>();
+  auto wg_m = int(item.get_group(1));
+  auto wg_n = int(item.get_group(0));
+  auto local_id = int(item.get_local_id(0));
+  auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
+  uint32_t sg_id = sg.get_group_linear_id();
+  uint32_t lane = sg.get_local_linear_id();
+
+  auto total_N = get<0>(B.shape());
+
+  Tensor cA = make_identity_tensor(A.shape()); // (M,K)
+  Tensor cB = make_identity_tensor(B.shape()); // (N,K)
+  Tensor cD = make_identity_tensor(D.shape()); // (M,N)
+  Tensor cS = make_identity_tensor(S.shape()); // (K/q_group_size,N)
+  Tensor cScales_per_sg =
+      make_identity_tensor(make_shape(Int<1>{}, Int<SG_N>{}));
+  auto wg_tile = mma.tile_mnk();
+  auto wg_coord = make_coord(wg_m, wg_n, 0);
+
+  Tensor gA = local_tile(cA, select<0, 2>(wg_tile),
+                         make_coord(wg_m, _)); // (BLK_M,BLK_K,k)
+  Tensor gB = local_tile(cB, select<1, 2>(wg_tile),
+                         make_coord(wg_n, _)); // (BLK_N,BLK_K,k)
+  Tensor gD =
+      local_tile(cD, wg_tile, wg_coord, Step<_1, _1, X>{}); // (BLK_M,BLK_N)
+
+  constexpr int num_N_SG_tiles = WG_N / SG_N;
+  constexpr int num_scales_per_col = (SG_N == 32) ? 4 : 2;
+
+  // When we use E8M0, the compiler behaves differently and loads more data
+  // than needed; the rest is discarded.
+  // The scales might be FP16 or BF16 in the case of int4 weights.
+  using scaleLoadType =
+      conditional_t<is_same_v<typename STensor::element_type, float_ue8m0_t>,
+                    int8_t, int16_t>;
+
+  auto S_tile = coalesce(local_tile(S, make_shape(Int<1>{}, get<1>(wg_tile)),
+                                    make_coord(_, wg_n)));
+
+  auto copy_a = get_block_2d_copy_A<GmemTiledCopyA>(mma, A);
+  auto copy_b = get_block_2d_copy_B<GmemTiledCopyB>(mma, B);
+  auto copy_d = get_block_2d_copy_D<GmemTiledCopyD>(mma, D);
+
+  auto thr_mma = mma.get_slice(local_id);
+  auto thr_copy_a = copy_a.get_slice(local_id);
+  auto thr_copy_b = copy_b.get_slice(local_id);
+  auto thr_copy_d = copy_d.get_slice(local_id);
+
+  auto tDrA = thr_mma.partition_sg_fragment_A(gA(_, _, 0));
+  auto tDrB = thr_mma.partition_sg_fragment_B(gB(_, _, 0));
+
+  auto tArA = thr_copy_a.partition_sg_fragment_D(gA(_, _, 0));
+  auto tBrB = thr_copy_b.partition_sg_fragment_D(gB(_, _, 0));
+  auto tDrD = thr_mma.partition_sg_fragment_C(gD);
+
+  Tensor tAgA = thr_copy_a.partition_S(gA);
+  Tensor tBgB = thr_copy_b.partition_S(gB);
+  Tensor tDgD = thr_copy_d.partition_D(gD);
+
+  auto prefetch_a = make_block_2d_prefetch(copy_a);
+  auto prefetch_b = make_block_2d_prefetch(copy_b);
+
+  auto thr_prefetch_A = prefetch_a.get_slice(local_id);
+  auto thr_prefetch_B = prefetch_b.get_slice(local_id);
+
+  auto pAgA = thr_prefetch_A.partition_S(gA);
+  auto pBgB = thr_prefetch_B.partition_S(gB);
+
+  const int prefetch_dist = 3;
+
+  constexpr int barrier_scope = 2;
+
+  int k_tile_count = ceil_div(shape<1>(A), get<2>(wg_tile));
+  int k_tile_prefetch = 0;
+  constexpr int num_threads_per_sg = 16;
+
+  typename STensor::element_type
+      frag[num_scales_per_col / 2]; // per-thread registers (compiler
+                                    // will keep in regs)
+  float frag_fp32[num_scales_per_col];
+  // assuming SG_K = WG_K
+  constexpr int frequency_scale_change = q_group_size / get<2>(wg_tile);
+  Tensor scales_e8m0 =
+      make_tensor(make_rmem_ptr(frag),
+                  make_layout(make_shape(Int<num_scales_per_col / 2>{})));
+  Tensor scales_float =
+      make_tensor(make_rmem_ptr(frag_fp32),
+                  make_layout(make_shape(Int<num_scales_per_col>{})));
+
+  auto srcTVLayout = make_layout(
+      make_shape(Int<num_threads_per_sg>{}, Int<num_scales_per_col / 2>{}),
+      make_stride(Int<1>{}, Int<num_threads_per_sg>{}));
+  auto dstTVLayout = make_layout(
+      make_shape(make_shape(Int<2>{}, Int<num_threads_per_sg / 2>{}),
+                 make_shape(Int<num_scales_per_col / 2>{})),
+      make_stride(make_stride(Int<0>{}, Int<1>{}), make_stride(Int<8>{})));
+  auto scales_e8m0_sg_tensor = make_subgroup_tensor(scales_e8m0, srcTVLayout);
+  auto scales_float_sg_tensor = make_subgroup_tensor(scales_float, dstTVLayout);
+
+  /* Warm up loops with prefetch to L1 */
+  CUTE_UNROLL
+  for (; k_tile_prefetch < prefetch_dist; k_tile_prefetch++) {
+    prefetch(prefetch_a, pAgA(_, _, _, k_tile_prefetch));
+    prefetch(prefetch_b, pBgB(_, _, _, k_tile_prefetch));
+  }
+  /* Main loop */
+  for (int k_tile = 0; k_tile < k_tile_count; k_tile++, k_tile_prefetch++) {
+    barrier_arrive(barrier_scope);
+    copy(copy_b, tBgB(_, _, _, k_tile), tBrB);
+    prefetch(prefetch_b, pBgB(_, _, _, k_tile_prefetch));
+    reorder(tBrB, tDrB);
+
+    if (k_tile % frequency_scale_change == 0) {
+      auto scales_tensor = make_tensor(
+          make_gmem_ptr(reinterpret_cast<scaleLoadType *>(
+              static_cast<void *>(cute::raw_pointer_cast(
+                  S_tile.data() + (SG_N * (sg_id % num_N_SG_tiles)) +
+                  (k_tile / frequency_scale_change) * total_N)))),
+          make_layout(make_shape(Int<1>{}, Int<SG_N>{})));
+      auto copy_scales = make_block_2d_copy(
+          XE_LOAD_2D<sizeof_bits_v<typename STensor::element_type>, 1, SG_N,
+                     SG_N>{},
+          scales_tensor);
+      auto thr_copy_scales = copy_scales.get_slice(lane);
+      auto scales_per_thread = thr_copy_scales.partition_S(cScales_per_sg);
+      copy(copy_scales, scales_per_thread(_, 0, 0), scales_e8m0);
+      reorder(scales_e8m0_sg_tensor, scales_float_sg_tensor);
+      if (k_tile != (k_tile_count - frequency_scale_change)) {
+        auto next_scales_tensor = make_tensor(
+            make_gmem_ptr(reinterpret_cast<scaleLoadType *>(
+                static_cast<void *>(cute::raw_pointer_cast(
+                    S_tile.data() + (SG_N * (sg_id % num_N_SG_tiles)) +
+                    ((k_tile / frequency_scale_change) + 1) * total_N)))),
+            make_layout(make_shape(Int<1>{}, Int<SG_N>{})));
+        auto prefetch_scales = make_block_2d_prefetch<1>(
+            make_shape(Int<1>{}, Int<SG_N>{}), next_scales_tensor);
+        auto thr_prefetch_scales = prefetch_scales.get_slice(lane);
+        auto pSgS = thr_prefetch_scales.partition_S(cScales_per_sg);
+        prefetch(prefetch_scales, pSgS(_, 0, 0));
+      }
+    }
+    copy(copy_a, tAgA(_, _, _, k_tile), tArA);
+    prefetch(prefetch_a, pAgA(_, _, _, k_tile_prefetch));
+    reorder(tArA, tDrA);
+    // Instead of hardcoding, figure out CuTe-algebra-based transformations
+    // that can lead to generic code.
+    auto scale0 = scales_float_sg_tensor[0];
+    auto scale1 = scales_float_sg_tensor[1];
+    if (num_scales_per_col == 4) {
+      auto scale2 = scales_float_sg_tensor[2];
+      auto scale3 = scales_float_sg_tensor[3];
+      CUTE_UNROLL
+      for (int i = 0; i < 16; i += 2) {
+        tDrB[i] = static_cast<typename ATensor::element_type>(
+            scale0 * static_cast<float>(tDrB[i]));
+        tDrB[i + 1] = static_cast<typename ATensor::element_type>(
+            scale1 * static_cast<float>(tDrB[i + 1]));
+      }
+      CUTE_UNROLL
+      for (int i = 16; i < 32; i += 2) {
+        tDrB[i] = static_cast<typename ATensor::element_type>(
+            scale2 * static_cast<float>(tDrB[i]));
+        tDrB[i + 1] = static_cast<typename ATensor::element_type>(
+            scale3 * static_cast<float>(tDrB[i + 1]));
+      }
+      CUTE_UNROLL
+      for (int i = 32; i < 48; i += 2) {
+        tDrB[i] = static_cast<typename ATensor::element_type>(
+            scale0 * static_cast<float>(tDrB[i]));
+        tDrB[i + 1] = static_cast<typename ATensor::element_type>(
+            scale1 * static_cast<float>(tDrB[i + 1]));
+      }
+      CUTE_UNROLL
+      for (int i = 48; i < 64; i += 2) {
+        tDrB[i] = static_cast<typename ATensor::element_type>(
+            scale2 * static_cast<float>(tDrB[i]));
+        tDrB[i + 1] = static_cast<typename ATensor::element_type>(
+            scale3 * static_cast<float>(tDrB[i + 1]));
+      }
+    } else {
+      CUTE_UNROLL
+      for (int i = 0; i < 32; i += 2) {
+        tDrB[i] = static_cast<typename ATensor::element_type>(
+            scale0 * static_cast<float>(tDrB[i]));
+        tDrB[i + 1] = static_cast<typename ATensor::element_type>(
+            scale1 * static_cast<float>(tDrB[i + 1]));
+      }
+    }
+
+    gemm(mma, tDrA, tDrB, tDrD);
+    barrier_wait(barrier_scope);
+  }
+  auto tDrD_final = thr_copy_d.partition_sg_fragment_S(gD);
+  reorder(tDrD, tDrD_final);
+  copy(copy_d, tDrD_final, tDgD);
 }
 
 } // namespace MoE
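
Note on the hardcoded dequantization loops in the new overload: each scale group covers q_group_size elements along K and each main-loop iteration advances BLK_K = get<2>(wg_tile), so scales are reloaded every frequency_scale_change = q_group_size / BLK_K iterations (for example, q_group_size = 128 with BLK_K = 32 would reload every 4 k-tiles; these numbers are illustrative, not taken from the diff). The unrolled loops apply scale0/scale1 to fragment elements 0-15 and 32-47 and scale2/scale3 to elements 16-31 and 48-63. The sketch below is a hypothetical generic form of that pattern, inferred from the unrolled loops; it assumes the per-thread B fragment holds 16 * num_scales_per_col values and that the scale-selection pattern repeats every 32 elements, which should be checked against the actual fragment layout (and register-indexing cost) before replacing the hardcoded version.

    // Hypothetical sketch only: a generic replacement for the unrolled scaling
    // loops, assuming element i of the per-thread B fragment uses scale index
    // (2 * ((i / 16) % 2) + (i % 2)) % num_scales_per_col.
    CUTE_UNROLL
    for (int i = 0; i < 16 * num_scales_per_col; i++) {
      int const s = (2 * ((i / 16) % 2) + (i % 2)) % num_scales_per_col;
      tDrB[i] = static_cast<typename ATensor::element_type>(
          scales_float_sg_tensor[s] * static_cast<float>(tDrB[i]));
    }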