[CuTe] [Xe] Separate make_block_2d_copy_{C,D} APIs for loads/stores (#572)
Currently `make_block_2d_copy_C` always selects a block 2D store, but we
really need separate APIs for C (loads) and D (stores). The default
value type can also be different, in case of MMA atoms with different
C/D types.
This PR introduces C/D APIs. Some APIs are common between C/D, and are
named `make_block_2d_copy_CD` to avoid duplication.
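The split described above can be pictured with a small host-only mock. This is a sketch under stated assumptions, not the CuTe implementation: every name prefixed `Mock`/`mock_` is hypothetical, and the real `make_block_2d_copy_{C,D}` also take a global-memory tensor and choose block 2D parameters heuristically. It only illustrates the idea that C and D share one helper and differ in the operation kind (load vs. store) and in the default value type.

```cpp
#include <type_traits>

// Hypothetical stand-ins for illustration only; these are NOT CuTe types.
struct MockLoad2D  { static constexpr bool is_store = false; };
struct MockStore2D { static constexpr bool is_store = true;  };

// Mock MMA whose accumulator input (C) and output (D) types may differ.
template <class CType, class DType>
struct MockTiledMMA {
  using ValTypeC = CType;
  using ValTypeD = DType;
};

// Mock result type that records the chosen operation and value type.
template <class CopyOp, class ValType>
struct MockTiledCopy {
  using Op  = CopyOp;
  using Val = ValType;
};

// Shared helper (assumed shape): everything common to C and D lives here.
template <class CopyOp, class ValType>
constexpr MockTiledCopy<CopyOp, ValType> mock_make_copy_CD() {
  return MockTiledCopy<CopyOp, ValType>{};
}

// C variant: selects a load and defaults to the MMA's C type.
template <class MMA>
constexpr MockTiledCopy<MockLoad2D, typename MMA::ValTypeC>
mock_make_copy_C(const MMA&) {
  return mock_make_copy_CD<MockLoad2D, typename MMA::ValTypeC>();
}

// D variant: selects a store and defaults to the MMA's D type.
template <class MMA>
constexpr MockTiledCopy<MockStore2D, typename MMA::ValTypeD>
mock_make_copy_D(const MMA&) {
  return mock_make_copy_CD<MockStore2D, typename MMA::ValTypeD>();
}
```

In this mock, an MMA declared as `MockTiledMMA<float, short>` yields a load-typed copy over `float` from `mock_make_copy_C` and a store-typed copy over `short` from `mock_make_copy_D`, which mirrors the motivation given above for separating the two APIs.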
Since it can be tricky to correctly choose block 2D parameters and set up an appropriate tiling, we introduce several helpers for creating TiledCopy objects.

-The high-level APIs `make_block_2d_copy_{A,B,C}` automatically create TiledCopy objects for use with an existing `TiledMMA`. They choose the copy operation and trait template parameters heuristically.
+The high-level APIs `make_block_2d_copy_{A,B,C,D}` automatically create TiledCopy objects for use with an existing `TiledMMA`. They choose the copy operation and trait template parameters heuristically. Note that `make_block_2d_copy_C` and `make_block_2d_copy_D` only differ in their choice of a load (C) or store (D) operation.

```c++
template <class Engine, class Layout, /*...*/>
@@ -167,6 +167,12 @@ CUTE_DEVICE
TiledCopy<...>
make_block_2d_copy_C(const TiledMMA<...>&,
                     const Tensor<Engine, Layout>& gmem); // (M,N,...)
+
+template <class Engine, class Layout, /*...*/>
+CUTE_DEVICE
+TiledCopy<...>
+make_block_2d_copy_D(const TiledMMA<...>&,
+                     const Tensor<Engine, Layout>& gmem); // (M,N,...)
```

The user may also override the choice of copy operation:
Tensor<GEngine, GLayout> const& gmem); // Global tensor
```

The `make_block_2d_copy_*` family of functions create TiledCopy objects that match the scope of the TiledMMA. That is, the set of threads participating in the TiledMMA will also participate in the TiledCopy.
-For advanced usage, there are additional overloads of `make_block_2d_copy` that allow more general work distributions for copies (see `include/cute/atom/copy_traits_xe_2d.hpp`).
+For advanced usage, there are additional overloads of `make_block_2d_copy` in which multiple subgroups participate (see `include/cute/atom/copy_traits_xe_2d.hpp`).

As the `CUTE_DEVICE` decorators imply, all the APIs above should be called from device code only, as they set up internal state that cannot be transferred from host to device.
@@ -419,7 +433,7 @@ gemm_device(ATensor const& A, // (M,K)
/* Create block 2D TiledCopies */
auto copy_a = make_block_2d_copy_A(mma, A);
auto copy_b = make_block_2d_copy_B(mma, B);
-auto copy_c = make_block_2d_copy_C(mma, C);
+auto copy_c = make_block_2d_copy_D(mma, C);

/* Slice TiledCopy/TiledMMA operations to thread (work-item) level */