Skip to content

Commit 33cae04

Browse files
authored
[CuTe] [Xe] Separate make_block_2d_copy_{C,D} APIs for loads/stores (#572)
Currently `make_block_2d_copy_C` always selects a block 2D store, but we really need separate APIs for C (loads) and D (stores). The default value type can also be different, in case of MMA atoms with different C/D types. This PR introduces C/D APIs. Some APIs are common between C/D, and are named `make_block_2d_copy_CD` to avoid duplication.
1 parent 7feb377 commit 33cae04

File tree

3 files changed

+59
-21
lines changed

3 files changed

+59
-21
lines changed

examples/cute/tutorial/xe_gemm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ gemm_device(ATensor const& A, // (M,K)
8686
/* Create block 2D TiledCopies */
8787
auto copy_a = make_block_2d_copy_A(mma, A);
8888
auto copy_b = make_block_2d_copy_B(mma, B);
89-
auto copy_c = make_block_2d_copy_C(mma, C);
89+
auto copy_c = make_block_2d_copy_D(mma, C);
9090

9191
/* Slice TiledCopy/TiledMMA operations to thread (work-item) level */
9292
auto thr_mma = mma.get_slice(local_id);

include/cute/atom/copy_traits_xe_2d.hpp

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -924,16 +924,26 @@ make_block_2d_copy_C(TiledMMA const& mma, // TiledMMA instance
924924
return make_block_2d_copy_C<ValType>(mma, gmem.stride()).with(gmem);
925925
}
926926

927-
template <class TiledMMA, class CopyOp, class GEngine, class GLayout>
927+
template <class TiledMMA, class GEngine, class GLayout>
928928
CUTE_HOST_DEVICE
929929
auto
930-
make_block_2d_copy_C(CopyOp const& op, // Copy operation
931-
TiledMMA const& mma, // TiledMMA instance
930+
make_block_2d_copy_D(TiledMMA const& mma, // TiledMMA instance
932931
Tensor<GEngine, GLayout> const& gmem) // Global tensor
933932
{
934933
static_assert(is_xe_block_2d_atom_v<CopyOp>, "Expected a block 2D atom");
935934
using ValType = typename GEngine::value_type;
936-
return make_block_2d_copy_C<ValType>(op, mma, gmem.stride()).with(gmem);
935+
return make_block_2d_copy_D<ValType>(mma, gmem.stride()).with(gmem);
936+
}
937+
938+
template <class TiledMMA, class CopyOp, class GEngine, class GLayout>
939+
CUTE_HOST_DEVICE
940+
auto
941+
make_block_2d_copy_CD(CopyOp const& op, // Copy operation
942+
TiledMMA const& mma, // TiledMMA instance
943+
Tensor<GEngine, GLayout> const& gmem) // Global tensor
944+
{
945+
using ValType = typename GEngine::value_type;
946+
return make_block_2d_copy_CD<ValType>(op, mma, gmem.stride()).with(gmem);
937947
}
938948

939949
template <class ValType, class TiledMMA, class... Strides>
@@ -942,32 +952,46 @@ auto
942952
make_block_2d_copy_C(TiledMMA const& mma, // TiledMMA instance
943953
Stride<Strides...> const& gstride) // Global memory strides
944954
{
945-
using MMAType = typename TiledMMA::ValTypeA;
955+
using MMAType = typename TiledMMA::ValTypeC;
946956
auto cC = make_identity_tensor(select<0,1>(mma.tile_mnk()));
947-
auto op = block_2d_selector<ValType, MMAType, true>(
957+
auto op = block_2d_selector<ValType, MMAType>(
948958
mma.get_slice(0).atom_partition_C(cC).layout(), gstride
949959
);
950-
return make_block_2d_copy_C<ValType>(op, mma, gstride);
960+
return make_block_2d_copy_CD<ValType>(op, mma, gstride);
951961
}
952962

953-
template <class ValType, class TiledMMA, class CopyOp, class... Strides>
963+
template <class ValType, class TiledMMA, class... Strides>
954964
CUTE_HOST_DEVICE
955965
auto
956-
make_block_2d_copy_C(CopyOp const& op, // Copy operation
957-
TiledMMA const& mma, // TiledMMA instance
966+
make_block_2d_copy_D(TiledMMA const& mma, // TiledMMA instance
958967
Stride<Strides...> const& gstride) // Global memory strides
959968
{
960-
return make_block_2d_copy_C<ValType>(op, mma, gstride, find_x_mode(gstride), find_y_mode(gstride));
969+
using MMAType = typename TiledMMA::ValTypeD;
970+
auto cD = make_identity_tensor(select<0,1>(mma.tile_mnk()));
971+
auto op = block_2d_selector<ValType, MMAType, true>(
972+
mma.get_slice(0).atom_partition_C(cD).layout(), gstride
973+
);
974+
return make_block_2d_copy_CD<ValType>(op, mma, gstride);
975+
}
976+
977+
template <class ValType, class TiledMMA, class CopyOp, class... Strides>
978+
CUTE_HOST_DEVICE
979+
auto
980+
make_block_2d_copy_CD(CopyOp const& op, // Copy operation
981+
TiledMMA const& mma, // TiledMMA instance
982+
Stride<Strides...> const& gstride) // Global memory strides
983+
{
984+
return make_block_2d_copy_CD<ValType>(op, mma, gstride, find_x_mode(gstride), find_y_mode(gstride));
961985
}
962986

963987
template <class ValType, class TiledMMA, class CopyOp, class... Strides, class XMode, class YMode>
964988
CUTE_HOST_DEVICE
965989
auto
966-
make_block_2d_copy_C(CopyOp const& op, // Copy operation
967-
TiledMMA const& mma, // TiledMMA instance
968-
Stride<Strides...> const& gstride, // Global memory strides
969-
XMode const& x_mode, // x, y modes
970-
YMode const& y_mode)
990+
make_block_2d_copy_CD(CopyOp const& op, // Copy operation
991+
TiledMMA const& mma, // TiledMMA instance
992+
Stride<Strides...> const& gstride, // Global memory strides
993+
XMode const& x_mode, // x, y modes
994+
YMode const& y_mode)
971995
{
972996
// Retrieve MMA atom's (subgroup, value) -> (M,N) layout
973997
auto tile_mn = select<0,1>(mma.tile_mnk());

media/docs/cpp/xe_rearchitecture.md

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ struct Copy_Traits</* Op */, XMode, YMode, ValType, TiledStrides>;
147147
148148
Since it can be a tricky to correctly choose block 2D parameters and set up an appropriate tiling, we introduce several helpers for creating TiledCopy objects.
149149
150-
The high-level APIs `make_block_2d_copy_{A,B,C}` automatically create TiledCopy objects for use with an existing `TiledMMA`. They choose the copy operation and trait template parameters heuristically.
150+
The high-level APIs `make_block_2d_copy_{A,B,C,D}` automatically create TiledCopy objects for use with an existing `TiledMMA`. They choose the copy operation and trait template parameters heuristically. Note that `make_block_2d_copy_C` and `make_block_2d_copy_D` only differ in their choice of a load (C) or store (D) operation.
151151
152152
```c++
153153
template <class Engine, class Layout, /*...*/>
@@ -167,6 +167,12 @@ CUTE_DEVICE
167167
TiledCopy<...>
168168
make_block_2d_copy_C(const TiledMMA<...>&,
169169
const Tensor<Engine, Layout>& gmem); // (M,N,...)
170+
171+
template <class Engine, class Layout, /*...*/>
172+
CUTE_DEVICE
173+
TiledCopy<...>
174+
make_block_2d_copy_D(const TiledMMA<...>&,
175+
const Tensor<Engine, Layout>& gmem); // (M,N,...)
170176
```
171177

172178
The user may also override the choice of copy operation:
@@ -179,7 +185,15 @@ make_block_2d_copy_A(CopyOp const& op, // Copy operation
179185
TiledMMA const& mma, // TiledMMA instance
180186
Tensor<GEngine, GLayout> const& gmem); // Global tensor
181187

182-
/* Similarly for B/C */
188+
/* Similarly for B */
189+
190+
/* Single routine for both C/D */
191+
template <class TiledMMA, class CopyOp, class GEngine, class GLayout>
192+
CUTE_HOST_DEVICE
193+
auto
194+
make_block_2d_copy_CD(CopyOp const& op, // Copy operation
195+
TiledMMA const& mma, // TiledMMA instance
196+
Tensor<GEngine, GLayout> const& gmem); // Global tensor
183197
```
184198
185199
The `make_block_2d_copy_*` family of functions create TiledCopy objects that match the scope of the TiledMMA. That is, the set of threads participating in the TiledMMA will also participate in the TiledCopy.
@@ -194,7 +208,7 @@ TiledCopy
194208
make_block_2d_copy(const CopyOp& op, const Tensor<Engine, Layout>& gmem);
195209
```
196210

197-
For advanced usage, there are additional overloads of `make_block_2d_copy` that allow more general work distributions for copies (see `include/cute/atom/copy_traits_xe_2d.hpp`).
211+
For advanced usage, there are additional overloads of `make_block_2d_copy` in which multiple subgroups participate (see `include/cute/atom/copy_traits_xe_2d.hpp`).
198212

199213
As the `CUTE_DEVICE` decorators imply, all the APIs above should be called from device code only, as they set up internal state that cannot be transferred from host to device.
200214

@@ -419,7 +433,7 @@ gemm_device(ATensor const& A, // (M,K)
419433
/* Create block 2D TiledCopies */
420434
auto copy_a = make_block_2d_copy_A(mma, A);
421435
auto copy_b = make_block_2d_copy_B(mma, B);
422-
auto copy_c = make_block_2d_copy_C(mma, C);
436+
auto copy_c = make_block_2d_copy_D(mma, C);
423437

424438
/* Slice TiledCopy/TiledMMA operations to thread (work-item) level */
425439
auto thr_mma = mma.get_slice(local_id);

0 commit comments

Comments
 (0)