MLA: use tensor core #9952

Draft: wants to merge 9 commits into develop
7 changes: 0 additions & 7 deletions csrc/gpu/append_attn/append_attention_kernel.h
@@ -300,10 +300,3 @@ void CascadeAppendAttentionKernel(
"cache_int4_zp]");
}
}

inline uint32_t get_max_partition_size(int bsz) {
static const char* max_partition_size_env = std::getenv("FLAGS_cascade_attention_max_partition_size");
static const uint32_t max_partition_size =
max_partition_size_env == nullptr ? 0 : std::stoul(std::string(max_partition_size_env));
return (max_partition_size != 0 ? max_partition_size : (bsz == 1 ? 128 : 512));
}
123 changes: 102 additions & 21 deletions csrc/gpu/append_attn/get_block_shape_and_split_kv_block.cu
@@ -14,6 +14,8 @@

#include "helper.h"
#include "paddle/extension.h"
#include "utils.cuh"
#include "cute/tensor.hpp"

template <typename T>
inline __device__ __host__ T div_up(T m, T n) {
@@ -48,6 +50,49 @@ __global__ void split_q_block(const int* __restrict__ seq_lens_q,
}
}

__global__ void split_q_block_mla(const int* __restrict__ seq_lens_q,
                                  const int* __restrict__ seq_lens_encoder,
                                  const int* __restrict__ seq_lens_decoder,
                                  int* __restrict__ batch_ids,
                                  int* __restrict__ tile_ids_per_batch,
                                  int* __restrict__ num_blocks_x,
                                  const int bsz,
                                  const int num_rows_per_block,
                                  const int chunk_size,
                                  const int GROUP_SIZE,
                                  const bool is_encoder) {
  // Single-thread scan over the batch: emit one (batch_id, tile_id) pair
  // per tile and accumulate the total tile count into num_blocks_x.
  if (threadIdx.x == 0) {
    int gridx = 0;
    int index = 0;
    for (int bid = 0; bid < bsz; bid++) {
      int seq_len = seq_lens_q[bid];
      int seq_len_encoder = seq_lens_encoder[bid];
      int seq_len_decoder = seq_lens_decoder[bid] + seq_len;

      if (seq_len == 0) continue;

      int loop_times;
      if (is_encoder) {
        // Prefill: tile over the query rows (seq_len * GROUP_SIZE),
        // skipping batch entries that are already decoding.
        loop_times = cute::ceil_div(seq_len * GROUP_SIZE, num_rows_per_block);
        if (seq_len_decoder > 0) {
          loop_times = 0;
        }
      } else {
        // Decode: tile over the KV length in chunk_size pieces,
        // skipping batch entries that are still in prefill.
        loop_times = cute::ceil_div(seq_len_decoder, chunk_size);
        if (seq_len_encoder > 0) {
          loop_times = 0;
        }
      }
      for (int tile_id = 0; tile_id < loop_times; tile_id++) {
        batch_ids[index] = bid;
        tile_ids_per_batch[index++] = tile_id;
      }
      gridx += loop_times;
    }
    *num_blocks_x = gridx;
  }
}
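A host-side reference of the tile enumeration above makes the grid sizing easier to audit. The sketch below is illustrative and not part of the change; count_mla_tiles is a made-up name, but the loop mirrors split_q_block_mla (note the only launch in this file passes is_encoder = false).

#include <vector>

inline int ceil_div_host(int m, int n) { return (m + n - 1) / n; }

// Returns the total tile count that split_q_block_mla would write to
// num_blocks_x, given the same per-batch sequence lengths.
int count_mla_tiles(const std::vector<int>& seq_lens_q,
                    const std::vector<int>& seq_lens_encoder,
                    const std::vector<int>& seq_lens_decoder,
                    int num_rows_per_block, int chunk_size,
                    int group_size, bool is_encoder) {
  int gridx = 0;
  for (size_t bid = 0; bid < seq_lens_q.size(); ++bid) {
    const int seq_len = seq_lens_q[bid];
    if (seq_len == 0) continue;  // inactive batch entry this step
    const int kv_len = seq_lens_decoder[bid] + seq_len;
    int loop_times;
    if (is_encoder) {
      // Prefill: tile over query rows; zeroed if the entry is decoding.
      loop_times = ceil_div_host(seq_len * group_size, num_rows_per_block);
      if (kv_len > 0) loop_times = 0;
    } else {
      // Decode: tile over the KV length; zeroed if still in prefill.
      loop_times = ceil_div_host(kv_len, chunk_size);
      if (seq_lens_encoder[bid] > 0) loop_times = 0;
    }
    gridx += loop_times;
  }
  return gridx;
}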

__global__ void split_kv_block(const int* __restrict__ seq_lens_decoder,
const int* __restrict__ seq_lens_encoder,
int* __restrict__ batch_ids,
@@ -110,7 +155,7 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
const int decoder_step_token_num) {
paddle::Tensor encoder_batch_ids, encoder_tile_ids_per_batch, encoder_num_blocks_x_cpu,
kv_batch_ids, kv_tile_ids_per_batch, kv_num_blocks_x_cpu, decoder_batch_ids,
decoder_tile_ids_per_batch, decoder_num_blocks_x_cpu;
decoder_tile_ids_per_batch, decoder_num_blocks_x, decoder_num_blocks_x_cpu;
auto stream = seq_lens_this_time.stream();
int bsz = cum_offsets.shape()[0];
const int encoder_block_shape_q = get_encoder_block_shape_q();
@@ -131,33 +176,65 @@
// decoder
int max_dec_len_this_time_data = max_dec_len_this_time.data<int>()[0];
if (max_dec_len_this_time_data > 0) {
const uint32_t decoder_max_tile_size_per_bs_q =
div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
decoder_batch_ids =
GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
paddle::DataType::INT32,
seq_lens_encoder.place());
decoder_tile_ids_per_batch =
GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
paddle::DataType::INT32,
seq_lens_encoder.place());
auto decoder_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_q_block<<<1, 32, 0, stream>>>(seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
decoder_batch_ids.data<int>(),
decoder_tile_ids_per_batch.data<int>(),
decoder_num_blocks_x.data<int>(),
bsz,
decoder_block_shape_q,
group_size);
const bool mla_use_tensorcore = get_mla_use_tensorcore();
if (mla_use_tensorcore) {
const int chunk_size = get_max_partition_size(bsz);
const int decoder_max_tile_size_per_bs = div_up(max_len_kv_cpu.data<int>()[0], chunk_size);
decoder_batch_ids =
GetEmptyTensor({bsz * decoder_max_tile_size_per_bs},
paddle::DataType::INT32,
seq_lens_encoder.place());
decoder_tile_ids_per_batch =
GetEmptyTensor({bsz * decoder_max_tile_size_per_bs},
paddle::DataType::INT32,
seq_lens_encoder.place());
decoder_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_q_block_mla<<<1, 32, 0, stream>>>(
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
decoder_batch_ids.data<int>(),
decoder_tile_ids_per_batch.data<int>(),
decoder_num_blocks_x.data<int>(),
bsz,
block_size,
chunk_size,
group_size,
false // is_encoder
);
} else {
const uint32_t decoder_max_tile_size_per_bs_q =
div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
decoder_batch_ids =
GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
paddle::DataType::INT32,
seq_lens_encoder.place());
decoder_tile_ids_per_batch =
GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
paddle::DataType::INT32,
seq_lens_encoder.place());
decoder_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_q_block<<<1, 32, 0, stream>>>(seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
decoder_batch_ids.data<int>(),
decoder_tile_ids_per_batch.data<int>(),
decoder_num_blocks_x.data<int>(),
bsz,
decoder_block_shape_q,
group_size);
}

decoder_num_blocks_x_cpu =
decoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);
} else {
decoder_batch_ids =
paddle::full({1}, -1, paddle::DataType::INT32, paddle::GPUPlace());
decoder_tile_ids_per_batch =
paddle::full({1}, -1, paddle::DataType::INT32, paddle::GPUPlace());
decoder_num_blocks_x =
paddle::full({1}, -1, paddle::DataType::INT32, paddle::GPUPlace());
decoder_num_blocks_x_cpu =
paddle::full({1}, -1, paddle::DataType::INT32, paddle::CPUPlace());
}
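For concreteness, a worst-case sizing example for the tensor-core decode buffers. The numbers below are hypothetical, chosen only to show the arithmetic; the chunk_size defaults come from get_max_partition_size in env.h.

// Hypothetical sizing for the tensor-core decode path: with
// FLAGS_cascade_attention_max_partition_size unset, chunk_size defaults
// to 512 for bsz > 1 (128 when bsz == 1).
constexpr int bsz = 8;
constexpr int max_len_kv = 8192;
constexpr int chunk_size = 512;
constexpr int tiles_per_bs = (max_len_kv + chunk_size - 1) / chunk_size;  // 16
constexpr int buffer_len = bsz * tiles_per_bs;  // 128 ints per tile buffer
static_assert(buffer_len == 128, "8 batches * ceil(8192 / 512) tiles");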
@@ -230,6 +307,7 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
kv_num_blocks_x_cpu, /*cpu*/
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks_x,
decoder_num_blocks_x_cpu, /*cpu*/
max_len_kv_cpu /*cpu*/};
}
@@ -250,6 +328,7 @@ std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype(
paddle::DataType::INT32,
paddle::DataType::INT32,
paddle::DataType::INT32,
paddle::DataType::INT32,
paddle::DataType::INT32};
}

@@ -271,6 +350,7 @@ std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
dynamic_shape,
dynamic_shape,
{1},
{1},
{1}};
}

@@ -290,6 +370,7 @@ PD_BUILD_OP(get_block_shape_and_split_kv_block)
"decoder_batch_ids",
"decoder_tile_ids_per_batch",
"decoder_num_blocks",
"decoder_num_blocks_cpu",
"max_len_kv"})
.Attrs({"group_size: int",
"block_size: int",
29 changes: 0 additions & 29 deletions csrc/gpu/append_attn/multi_head_latent_attention_kernel.h
@@ -36,32 +36,3 @@ void DecodeMLAAttentionKernel(
bool causal,
cudaStream_t &stream,
paddle::Tensor *out);

inline uint32_t get_max_partition_size(int bsz) {
static const char* max_partition_size_env = std::getenv("FLAGS_cascade_attention_max_partition_size");
static const uint32_t max_partition_size =
max_partition_size_env == nullptr ? 0 : std::stoul(std::string(max_partition_size_env));
return (max_partition_size != 0 ? max_partition_size : (bsz == 1 ? 128 : 512));
}


inline uint32_t get_cascade_attention_deal_each_time() {
static const char* cascade_attention_deal_each_time_env = std::getenv("FLAGS_cascade_attention_deal_each_time");
static const uint32_t cascade_attention_deal_each_time =
cascade_attention_deal_each_time_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_deal_each_time_env));
return (cascade_attention_deal_each_time != 0 ? cascade_attention_deal_each_time : 32);
}

inline uint32_t get_cascade_attention_num_stages() {
static const char* cascade_attention_num_stages_env = std::getenv("FLAGS_cascade_attention_num_stages");
static const uint32_t cascade_attention_num_stages =
cascade_attention_num_stages_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_num_stages_env));
return cascade_attention_num_stages != 0 ? cascade_attention_num_stages : 2;
}

inline uint32_t get_cascade_attention_num_threads() {
static const char* cascade_attention_num_threads_env = std::getenv("FLAGS_cascade_attention_num_threads");
static const uint32_t cascade_attention_num_threads =
cascade_attention_num_threads_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_num_threads_env));
return cascade_attention_num_threads != 0 ? cascade_attention_num_threads : 128;
}
5 changes: 5 additions & 0 deletions csrc/gpu/append_attn/utils.cuh
@@ -279,6 +279,11 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
__VA_ARGS__ \
break; \
} \
case 192: { \
constexpr size_t HEAD_DIM = 192; \
__VA_ARGS__ \
break; \
} \
default: { \
PD_THROW("not support the head_dim: ", head_dim); \
} \
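The new case 192 follows the existing head-dim dispatch pattern: a runtime head_dim selects a branch in which HEAD_DIM is a constexpr, so each kernel launch is fully specialized at compile time. Below is a minimal standalone sketch of the pattern; the macro and the launch_decode helper are illustrative names, not the ones in utils.cuh.

#include <cstdio>

#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...)     \
  switch (head_dim) {                                  \
    case 128: {                                        \
      constexpr size_t HEAD_DIM = 128;                 \
      __VA_ARGS__                                      \
      break;                                           \
    }                                                  \
    case 192: {                                        \
      constexpr size_t HEAD_DIM = 192;                 \
      __VA_ARGS__                                      \
      break;                                           \
    }                                                  \
    default:                                           \
      printf("unsupported head_dim: %d\n", head_dim);  \
  }

template <size_t kHeadDim>
void launch_decode() {  // stand-in for a templated kernel launch
  printf("specialized for head_dim = %zu\n", kHeadDim);
}

int main() {
  int head_dim = 192;  // runtime value, e.g. from the model config
  DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { launch_decode<HEAD_DIM>(); });
  return 0;
}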
71 changes: 71 additions & 0 deletions csrc/gpu/env.h
@@ -0,0 +1,71 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstdint>
#include <cstdlib>
#include <string>

inline uint32_t get_decoder_block_shape_q() {
static const char* decoder_block_shape_q_env = std::getenv("FLAGS_dec_block_shape_q");
static const uint32_t decoder_block_shape_q =
decoder_block_shape_q_env == nullptr ? 16 : std::stoi(std::string(decoder_block_shape_q_env));
return decoder_block_shape_q;
}

inline uint32_t get_encoder_block_shape_q() {
static const char* encoder_block_shape_q_env = std::getenv("FLAGS_enc_block_shape_q");
static const uint32_t encoder_block_shape_q =
encoder_block_shape_q_env == nullptr ? 64 : std::stoi(std::string(encoder_block_shape_q_env));
return encoder_block_shape_q;
}

inline uint32_t get_max_partition_size(int bsz) {
static const char* max_partition_size_env = std::getenv("FLAGS_cascade_attention_max_partition_size");
static const uint32_t max_partition_size =
max_partition_size_env == nullptr ? 0 : std::stoul(std::string(max_partition_size_env));
return (max_partition_size != 0 ? max_partition_size : (bsz == 1 ? 128 : 512));
}

inline uint32_t get_cascade_attention_deal_each_time() {
static const char* cascade_attention_deal_each_time_env = std::getenv("FLAGS_cascade_attention_deal_each_time");
static const uint32_t cascade_attention_deal_each_time =
cascade_attention_deal_each_time_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_deal_each_time_env));
return (cascade_attention_deal_each_time != 0 ? cascade_attention_deal_each_time : 32);
}

inline uint32_t get_cascade_attention_num_stages() {
static const char* cascade_attention_num_stages_env = std::getenv("FLAGS_cascade_attention_num_stages");
static const uint32_t cascade_attention_num_stages =
cascade_attention_num_stages_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_num_stages_env));
return cascade_attention_num_stages != 0 ? cascade_attention_num_stages : 2;
}

inline uint32_t get_cascade_attention_num_threads() {
static const char* cascade_attention_num_threads_env = std::getenv("FLAGS_cascade_attention_num_threads");
static const uint32_t cascade_attention_num_threads =
cascade_attention_num_threads_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_num_threads_env));
return cascade_attention_num_threads != 0 ? cascade_attention_num_threads : 128;
}

inline bool get_mla_use_tensorcore() {
  static const char* mla_use_tensorcore_env = std::getenv("FLAGS_mla_use_tensorcore");
  static const uint32_t mla_use_tensorcore =
      mla_use_tensorcore_env == nullptr ? 1 : std::stoul(std::string(mla_use_tensorcore_env));
  return mla_use_tensorcore != 0;
}

inline bool enable_cuda_core_fp8_gemm() {
static const char* enable_cuda_core_fp8_env = std::getenv("FLAGS_cuda_core_fp8_gemm");
static const bool enable_cuda_core_fp8_gemm =
enable_cuda_core_fp8_env != nullptr && std::string(enable_cuda_core_fp8_env) == "1";
return enable_cuda_core_fp8_gemm;
}
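A small usage sketch for the getters above, assuming a translation unit that includes env.h. Each getter latches the environment variable in a function-local static, so the flag is read once per process and later changes have no effect; setenv is POSIX.

#include <cstdio>
#include <cstdlib>

#include "env.h"  // assumed include path

int main() {
  // Override the partition size before the first call latches it.
  setenv("FLAGS_cascade_attention_max_partition_size", "256", /*overwrite=*/1);
  printf("%u\n", get_max_partition_size(/*bsz=*/1));  // 256 (flag wins)
  printf("%u\n", get_max_partition_size(/*bsz=*/8));  // 256 (value is cached)
  // Unset, the defaults would be 128 for bsz == 1 and 512 otherwise.
  printf("%d\n", (int)get_mla_use_tensorcore());  // 1 unless FLAGS_mla_use_tensorcore=0
  return 0;
}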
7 changes: 0 additions & 7 deletions csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.h
@@ -26,12 +26,5 @@ typedef struct {
cudaStream_t stream;
} GemmParams;

inline bool enable_cuda_core_fp8_gemm() {
static const char* enable_cuda_core_fp8_env = std::getenv("FLAGS_cuda_core_fp8_gemm");
static const bool enable_cuda_core_fp8_gemm =
enable_cuda_core_fp8_env != nullptr && std::string(enable_cuda_core_fp8_env) == "1";
return enable_cuda_core_fp8_gemm;
}

template <typename InputType, typename OutputType>
bool cuda_core_gemm_launcher(GemmParams const& params);
18 changes: 6 additions & 12 deletions csrc/gpu/helper.h
@@ -37,9 +37,11 @@ namespace cub = hipcub;
#include <iostream>
#include <fstream>

#include "env.h"
#include "paddle/extension.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/allocator.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "nlohmann/json.hpp"

using json = nlohmann::json;
@@ -222,16 +224,8 @@ __device__ inline bool is_in_end(const int64_t id, const int64_t *end_ids, int l
return flag;
}

inline uint32_t get_decoder_block_shape_q() {
static const char* decoder_block_shape_q_env = std::getenv("FLAGS_dec_block_shape_q");
static const uint32_t decoder_block_shape_q =
decoder_block_shape_q_env == nullptr ? 16 : std::stoi(std::string(decoder_block_shape_q_env));
return decoder_block_shape_q;
}

inline uint32_t get_encoder_block_shape_q() {
static const char* encoder_block_shape_q_env = std::getenv("FLAGS_enc_block_shape_q");
static const uint32_t encoder_block_shape_q =
encoder_block_shape_q_env == nullptr ? 64 : std::stoi(std::string(encoder_block_shape_q_env));
return encoder_block_shape_q;
inline int GetSMVersion() {
static int sm_version = phi::backends::gpu::GetGPUComputeCapability(
phi::backends::gpu::GetCurrentDeviceId());
return sm_version;
}
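One plausible use of GetSMVersion, sketched as an assumption rather than taken from this diff: gate the MLA tensor-core decode path on both the env flag and a minimum compute capability. The SM80 (Ampere) floor is an illustrative choice.

// Illustrative gating helper, not part of this diff; the >= 80
// threshold is an assumption for the sketch.
inline bool mla_tensorcore_available() {
  return get_mla_use_tensorcore() && GetSMVersion() >= 80;
}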