218 changes: 158 additions & 60 deletions csrc/src/flash_bwd_launch_template.h
@@ -131,68 +131,108 @@ void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 32;
int device;
cudaGetDevice(&device);

// Get architecture-specific optimization configuration
auto arch_config = get_arch_optimization_config(device, params.seqlen_q, params.seqlen_k, params.b);

int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
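// For reading the calls below: the Flash_bwd_kernel_traits arguments are, in order,
// Headdim, kBlockM, kBlockN, number of warps, the three MMA atom layouts (SdP, dKV, dQ),
// whether V is kept in registers, whether double buffering is disabled, and the element
// type T. This follows the upstream FlashAttention trait order; noted here as an
// assumption, since this change may extend the trait list.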
// 2 * (...) - Double buffering factor
// (3 * kBlockM + 2 * kBlockN) * Headdim - Vector tiles in shared memory
// - 3 * kBlockM * Headdim: Q tile, dQ tile, dOut tile
// - 2 * kBlockN * Headdim: K tile, V tile
// 4 * kBlockM * kBlockN - Matrix tiles in shared memory
// - 2 * kBlockM * kBlockN: S tile, P tile
// - 2 * kBlockM * kBlockN: Mask tile, Bias tile
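// Illustrative check (an addition for clarity, not required by the change): the
// "94 KB" figure referenced below is exactly the estimate above evaluated for the
// 64x128 tiles used in this function.
static_assert(2 * ((3 * 64 + 2 * 128) * Headdim + 4 * 64 * 128) == 94208,
              "hdim32 backward smem estimate: vector tiles (Q, dQ, dOut, K, V) plus matrix tiles (S, P, Mask, Bias)");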

// Architecture-specific dispatch for head dim 32.
// Note: every branch below currently selects the same 64x128 configuration; the
// structure is kept so each architecture can be tuned independently later.
if (supports_sm90_features(device)) { // SM 9.0 (H100/H200)
if (max_smem_per_block >= 2 * ((3 * 64 + 2 * 128) * Headdim + 4 * 64 * 128)) { // 94 KB
// H100 can handle larger blocks with small head dim
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 4, 4, 4, true, false, T>, Is_causal>(params, stream);
} else {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 4, 4, 4, true, false, T>, Is_causal>(params, stream);
}
} else if (supports_sm89_features(device)) { // SM 8.9 (Ada)
if (max_smem_per_block >= 2 * ((3 * 64 + 2 * 128) * Headdim + 4 * 64 * 128)) {
// Ada optimization for small head dimensions
if (params.seqlen_q >= 4096) { // Long sequences benefit from larger blocks
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 4, 4, 4, true, false, T>, Is_causal>(params, stream);
} else {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 4, 4, 4, true, false, T>, Is_causal>(params, stream);
}
} else {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 4, 4, 4, true, false, T>, Is_causal>(params, stream);
}
} else { // SM 8.6 and below
if (max_smem_per_block >= 2 * ((3 * 64 + 2 * 128) * Headdim + 4 * 64 * 128)) { // 94 KB
// We can afford more registers to keep V in registers
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 4, 4, 4, true, false, T>, Is_causal>(params, stream);
} else { // 96 KB
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 4, 4, 4, true, false, T>, Is_causal>(params, stream);
}
}
}
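
// supports_sm90_features / supports_sm89_features are defined elsewhere in this
// change. A minimal sketch of what such checks typically look like, assuming they
// key off the device's compute capability (the helper name below is illustrative,
// not the PR's actual API):
inline bool device_has_cc_at_least(int device, int major, int minor) {
    int cc_major = 0, cc_minor = 0;
    cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, device);
    cudaDeviceGetAttribute(&cc_minor, cudaDevAttrComputeCapabilityMinor, device);
    return cc_major > major || (cc_major == major && cc_minor >= minor);
}
// e.g. supports_sm90_features(device) would reduce to device_has_cc_at_least(device, 9, 0),
// and supports_sm89_features(device) to device_has_cc_at_least(device, 8, 9).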

template<typename T, bool Is_causal>
void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 64;
int device;
cudaGetDevice(&device);

// Get architecture-specific optimization configuration
auto arch_config = get_arch_optimization_config(device, params.seqlen_q, params.seqlen_k, params.b);

int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
// printf("max_smem_per_block = %d\n", max_smem_per_block);
// Changing AtomLayoutMdQ from 2 to 4 takes the same time
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>>(params, stream);
// This is slightly faster. We want to split M more so we need fewer registers to store LSE.
// Previous default dispatch (replaced by the architecture-specific selection below):
// if (max_smem_per_block >= 144 * 1024) {
//     run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_causal>(params, stream);
//     // This has a lot of register spilling
//     // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>>(params, stream);
// } else {
//     // if (params.h == params.h_k) {
//     // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
//     run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
//     // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>>(params, stream);
//     // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>>(params, stream);
//     // } else {
//     // }
// }

const char* optimization_choice = nullptr;

// Architecture-specific optimization selection for head dim 64
if (supports_sm90_features(device)) { // SM 9.0 (H100/H200)
if (max_smem_per_block >= 144 * 1024) {
// Use large block sizes for optimal bandwidth utilization on H100
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_causal>(params, stream);
optimization_choice = "SM90_LargeBlock_128x128";
} else {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
optimization_choice = "SM90_MediumBlock_64x128_VinRegs";
}
} else if (supports_sm89_features(device)) { // SM 8.9 (Ada)
if (max_smem_per_block >= 144 * 1024) {
// Optimized configuration for Ada architecture
if (params.seqlen_q >= 8192) { // Long sequences
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_causal>(params, stream);
optimization_choice = "SM89_VeryLongSeq_128x128";
} else {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
optimization_choice = "SM89_StandardSeq_64x128_VinRegs";
}
} else {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
optimization_choice = "SM89_LowMem_64x128_VinRegs";
}
} else { // SM 8.6 and below (A100, etc.)
if (max_smem_per_block >= 144 * 1024) {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_causal>(params, stream);
optimization_choice = "SM86_Standard_128x128";
} else {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
optimization_choice = "SM86_LowMem_64x128_VinRegs";
}
}
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, 2, 2, 2, true, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 16, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
// M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 2, 2, 2, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);

// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 4, 4, 2, 4, false, false, T>>(params, stream);

// Log optimization choice for profiling
log_backward_optimization_choice("64", device, params.seqlen_q, params.seqlen_k, params.b, optimization_choice);
}
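
// log_backward_optimization_choice is also provided elsewhere in this change.
// A plausible shape for such a debug-only logger, matching the call sites above
// (assumption: the _sketch suffix, env-var gate, and <cstdio>/<cstdlib> usage
// below are illustrative only):
inline void log_backward_optimization_choice_sketch(const char* headdim, int device,
                                                    int seqlen_q, int seqlen_k, int batch,
                                                    const char* choice) {
    if (std::getenv("FLASH_BWD_LOG_OPT") == nullptr) { return; }  // opt-in logging only
    std::fprintf(stderr,
                 "[flash-bwd] hdim=%s device=%d seqlen_q=%d seqlen_k=%d batch=%d choice=%s\n",
                 headdim, device, seqlen_q, seqlen_k, batch, choice ? choice : "unknown");
}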

template<typename T, bool Is_causal>
@@ -220,31 +260,55 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
int device;
cudaGetDevice(&device);

// Get architecture-specific optimization configuration
auto arch_config = get_arch_optimization_config(device, params.seqlen_q, params.seqlen_k, params.b);

int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
// printf("max_smem_per_block = %d\n", max_smem_per_block);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
// This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
// Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
// Previous default dispatch (replaced by the architecture-specific selection below):
// if (max_smem_per_block >= 144 * 1024) {
//     run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_causal>(params, stream);
//     // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>>(params, stream);
//     // run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>>(params, stream);
//     // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream);
//     // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, false, false, T>>(params, stream);
//     // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>>(params, stream);
// } else {
//     // run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>>(params, stream);
//     run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_causal>(params, stream);
// }

const char* optimization_choice = nullptr;

// Architecture-specific optimization selection
if (supports_sm90_features(device)) { // SM 9.0 (H100/H200)
if (max_smem_per_block >= 176 * 1024) {
// Use large block sizes for optimal memory bandwidth on H100
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_causal>(params, stream);
optimization_choice = "SM90_LargeBlock_128x128";
} else {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 4, 4, 2, false, false, T>, Is_causal>(params, stream);
optimization_choice = "SM90_MediumBlock_64x128";
}
} else if (supports_sm89_features(device)) { // SM 8.9 (Ada)
if (max_smem_per_block >= 144 * 1024) {
// Optimized for variable sequence lengths on Ada
if (params.seqlen_q >= 4096) {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_causal>(params, stream);
optimization_choice = "SM89_LongSeq_128x64";
} else {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_causal>(params, stream);
optimization_choice = "SM89_StandardSeq_64x128";
}
} else {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_causal>(params, stream);
optimization_choice = "SM89_LowMem_64x64_VinRegs";
}
} else { // SM 8.6 and below (A100, etc.)
if (max_smem_per_block >= 144 * 1024) {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_causal>(params, stream);
optimization_choice = "SM86_Standard_64x128";
} else {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_causal>(params, stream);
optimization_choice = "SM86_LowMem_64x64_VinRegs";
}
}
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream);

// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream);

// Log optimization choice for profiling
log_backward_optimization_choice("128", device, params.seqlen_q, params.seqlen_k, params.b, optimization_choice);
}
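
// get_arch_optimization_config is fetched in each launcher above but its result is
// not consumed in these hunks. A purely illustrative guess at the kind of
// configuration such a helper might return (hypothetical struct, not taken from
// this change):
struct ArchOptimizationConfigSketch {
    int  block_m;          // preferred M tile size
    int  block_n;          // preferred N tile size
    bool keep_v_in_regs;   // trade registers for shared memory
    bool double_buffer;    // overlap global loads with compute
};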

template<typename T, bool Is_causal>
@@ -270,18 +334,52 @@ void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 256;
int device;
cudaGetDevice(&device);

// Get architecture-specific optimization configuration
auto arch_config = get_arch_optimization_config(device, params.seqlen_q, params.seqlen_k, params.b);

int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
// Previous dispatch (replaced by the architecture-specific selection below):
// if (max_smem_per_block >= 176 * 1024) { // H100
//     run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
// } else if (max_smem_per_block >= 144 * 1024) { // A100, we don't do double buffering to save smem
//     run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_causal>(params, stream);
// } else { // sm86 and sm89, max smem is 99 KB. V in regs and no double buffering.
//     run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, Is_causal>(params, stream);
// }

// Architecture-specific optimization selection for head dim 256
if (supports_sm90_features(device)) { // SM 9.0 (H100/H200)
if (max_smem_per_block >= 176 * 1024) {
// H100 with ample shared memory - use optimal configuration
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
} else if (max_smem_per_block >= 144 * 1024) {
// Moderate shared memory - disable double buffering
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_causal>(params, stream);
} else {
// Memory constrained - use registers and single buffering
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, Is_causal>(params, stream);
}
} else if (supports_sm89_features(device)) { // SM 8.9 (Ada)
if (max_smem_per_block >= 176 * 1024) {
// Ample shared memory (unlikely on current Ada parts, which cap at ~99 KB per block): tune by sequence length
if (params.seqlen_q >= 8192) {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
} else {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_causal>(params, stream);
}
} else if (max_smem_per_block >= 144 * 1024) {
// Ada with moderate memory
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_causal>(params, stream);
} else {
// Memory constrained Ada
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, Is_causal>(params, stream);
}
} else { // SM 8.6 and below (A100, etc.)
if (max_smem_per_block >= 176 * 1024) { // not expected on SM 8.6 and below; kept for safety
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
} else if (max_smem_per_block >= 144 * 1024) { // A100, we don't do double buffering to save smem
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_causal>(params, stream);
} else { // sm86: max smem per block is 99 KB. V in regs and no double buffering.
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, Is_causal>(params, stream);
}
}
}
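
// Usage sketch (illustrative): these per-head-dim launchers are selected by the
// backward entry point based on params.d, along the lines of:
//   if (params.d <= 32)        run_mha_bwd_hdim32<T, Is_causal>(params, stream);
//   else if (params.d <= 64)   run_mha_bwd_hdim64<T, Is_causal>(params, stream);
//   else if (params.d <= 128)  run_mha_bwd_hdim128<T, Is_causal>(params, stream);
//   else                       run_mha_bwd_hdim256<T, Is_causal>(params, stream);
//   (intermediate sizes such as hdim96 have their own launchers in the full file)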
