Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions applications/flash_attention_v2/kernel/tile_scheduler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@

namespace cutlass::flash_attention {

// One unit of prefill work: a single Q row tile belonging to one
// (batch, head) pair. A host-side array of these is handed to
// XeFlashIndividualValidOnlyTileScheduler so that only non-empty
// tiles are ever launched.
struct XeFlashRowTile {
  int bh; // flattened batch/head index: = b * num_heads_q + h
  int m_tile; // row tile index along Q
};

namespace kernel {

struct XeFlashIndividualTileScheduler {
Expand Down Expand Up @@ -92,6 +97,62 @@ struct XeFlashIndividualTileScheduler {
}
};

// Only schedule valid (non-empty) work groups.
struct XeFlashIndividualValidOnlyTileScheduler {
struct Params {
dim3 grid;
FastDivmod divmod_num_heads;
const XeFlashRowTile* tiles = nullptr;
int total_tiles = 0;
};

bool valid_ = true;
Params params;

CUTLASS_DEVICE
XeFlashIndividualValidOnlyTileScheduler(Params const& params) : params(params) {}

template<class ProblemSize, class TileShape>
static Params to_underlying_arguments(
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
TileShape const& tile_shape,
const XeFlashRowTile* tiles_dev, int total_tiles) {

using namespace cute;
// problem_size = [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo]
dim3 grid(size(ceil_div(shape<6>(problem_size), shape<1>(tile_shape))), total_tiles, 1);

return Params{grid, {shape<1>(problem_size)}, tiles_dev, total_tiles};
}

template <int Num_SGs>
static dim3 get_grid_shape(Params const& params) {
return params.grid;
}

CUTLASS_DEVICE
bool is_valid() {
return valid_;
}

CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;

int t = BlockIdxY();
XeFlashRowTile tile = params.tiles[t];
int bidb = 0, bidh = 0;
params.divmod_num_heads(bidb, bidh, tile.bh);
return make_coord(BlockIdxX(), tile.m_tile, bidb, bidh);
}

CUTLASS_DEVICE
XeFlashIndividualValidOnlyTileScheduler& operator++() {
valid_ = false;
return *this;
}
};

struct XeFlashDecodeIndividualTileScheduler {

struct Params {
Expand Down Expand Up @@ -230,6 +291,7 @@ struct XeFlashPersistentTileScheduler {
} // namespace kernel

// Tag types consumed by TileSchedulerSelector to pick a kernel tile scheduler.
struct IndividualScheduler{};
struct IndividualValidOnlyScheduler{}; // schedules only non-empty work groups
struct PersistentScheduler{};
struct FlashDecodeIndividualScheduler{};

Expand Down Expand Up @@ -267,6 +329,15 @@ struct XeFlashPersistentTileScheduler {
using Scheduler = kernel::XeFlashIndividualTileScheduler;
};

// Selector specialization: on Intel Xe, the IndividualValidOnlyScheduler tag
// maps to the valid-only individual tile scheduler.
template <class ArchTag>
struct TileSchedulerSelector<
IndividualValidOnlyScheduler,
ArchTag,
cute::enable_if_t<cute::is_same_v<ArchTag, cutlass::arch::IntelXe>>>
{
using Scheduler = kernel::XeFlashIndividualValidOnlyTileScheduler;
};

template <class ArchTag>
struct TileSchedulerSelector<
PersistentScheduler,
Expand Down
23 changes: 18 additions & 5 deletions applications/flash_attention_v2/kernel/xe_flash_attn_prefill.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,12 @@ class FMHAPrefill {
using SoftmaxParams = typename CollectiveSoftmaxEpilogue::Params;

static_assert(cute::is_void_v<TileScheduler_> or cute::is_same_v<TileScheduler_, PersistentScheduler> or
cute::is_same_v<TileScheduler_, IndividualValidOnlyScheduler> or
cute::is_same_v<TileScheduler_, IndividualScheduler>, "Unsupported TileScheduler for Intel Xe.");
using TileSchedulerTag = TileScheduler_;
using TileScheduler = typename detail::TileSchedulerSelector<TileScheduler_, ArchTag>::Scheduler;
using TileSchedulerParams = typename TileScheduler::Params;
using XeFlashRowTile = cutlass::flash_attention::XeFlashRowTile;

// Epilogue derived types
using CollectiveEpilogue = CollectiveEpilogue_;
Expand Down Expand Up @@ -147,6 +149,8 @@ class FMHAPrefill {
SoftmaxArguments softmax{};
EpilogueArguments epilogue{};
KernelHardwareInfo hw_info{};
const XeFlashRowTile* tiles = nullptr;
int total_tiles = 0;
};

// Kernel entry point API
Expand All @@ -166,11 +170,19 @@ class FMHAPrefill {
// Convert user-facing Arguments into kernel Params. The mainloop, softmax,
// and epilogue conversions are identical for every scheduler; only the tile
// scheduler's arguments differ, so compute those first and share the rest.
static Params to_underlying_arguments(Arguments const &args, void *workspace) {
(void)workspace;
// The valid-only scheduler additionally consumes the host-precomputed list
// of non-empty row tiles; every other scheduler uses the default signature.
TileSchedulerParams scheduler_params = [&] {
  if constexpr (cute::is_same_v<TileScheduler_, IndividualValidOnlyScheduler>) {
    return TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, TileShapeOutput{},
                                                  args.tiles, args.total_tiles);
  } else {
    return TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, TileShapeOutput{});
  }
}();
return {args.mode, args.problem_shape,
        CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace),
        CollectiveSoftmaxEpilogue::to_underlying_arguments(args.softmax),
        CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace),
        scheduler_params};
}

static bool can_implement(Arguments const &args) {
Expand Down Expand Up @@ -224,6 +236,7 @@ class FMHAPrefill {

TileScheduler tile_scheduler{params.scheduler};


CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord(); // head_size_blk_idx, seq_len_blk_idx, batch_blk_idx, num_heads_blk_idx
Expand Down
Loading
Loading