Commit c92d91f

fix conflict

root committed Feb 28, 2025
2 parents d10b523 + cc725e3
Showing 30 changed files with 4,481 additions and 103 deletions.
7 changes: 0 additions & 7 deletions csrc/gpu/append_attn/append_attention_kernel.h
@@ -300,10 +300,3 @@ void CascadeAppendAttentionKernel(
"cache_int4_zp]");
}
}

inline uint32_t get_max_partition_size(int bsz) {
static const char* max_partition_size_env = std::getenv("FLAGS_cascade_attention_max_partition_size");
static const uint32_t max_partition_size =
max_partition_size_env == nullptr ? 0 : std::stoul(std::string(max_partition_size_env));
return (max_partition_size != 0 ? max_partition_size : (bsz == 1 ? 128 : 512));
}
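
Note: this helper is not dropped; an identical definition lands in the new csrc/gpu/env.h below. Its rule: an explicit FLAGS_cascade_attention_max_partition_size wins, otherwise batch size 1 uses 128 and larger batches use 512, with the first read cached in a function-local static. A standalone sketch of the resolution logic (illustrative only; the cache is omitted here so the POSIX setenv override stays visible):

#include <cstdio>
#include <cstdlib>
#include <string>

// Same resolution rules as get_max_partition_size, without the static cache.
static unsigned max_partition_size(int bsz) {
  const char* env = std::getenv("FLAGS_cascade_attention_max_partition_size");
  const unsigned v = env == nullptr ? 0u : static_cast<unsigned>(std::stoul(env));
  return v != 0 ? v : (bsz == 1 ? 128u : 512u);
}

int main() {
  std::printf("%u\n", max_partition_size(1));  // 128: single-sequence default
  std::printf("%u\n", max_partition_size(8));  // 512: batched default
  setenv("FLAGS_cascade_attention_max_partition_size", "256", /*overwrite=*/1);
  std::printf("%u\n", max_partition_size(4));  // 256: the flag overrides both defaults
}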
123 changes: 102 additions & 21 deletions csrc/gpu/append_attn/get_block_shape_and_split_kv_block.cu
@@ -14,6 +14,8 @@

#include "helper.h"
#include "paddle/extension.h"
#include "utils.cuh"
#include "cute/tensor.hpp"

template <typename T>
inline __device__ __host__ T div_up(T m, T n) {
@@ -48,6 +50,49 @@ __global__ void split_q_block(const int* __restrict__ seq_lens_q,
}
}

__global__ void split_q_block_mla(const int * __restrict__ seq_lens_q,
const int * __restrict__ seq_lens_encoder,
const int * __restrict__ seq_lens_decoder,
int * __restrict__ batch_ids,
int * __restrict__ tile_ids_per_batch,
int * __restrict__ num_blocks_x,
const int bsz,
const int num_rows_per_block,
const int chunk_size,
const int GROUP_SIZE,
const bool is_encoder) {
if (threadIdx.x == 0) {
int gridx = 0;
int index = 0;
for (uint32_t bid = 0; bid < bsz; bid++) {
int seq_len = seq_lens_q[bid];
int seq_len_encoder = seq_lens_encoder[bid];
int seq_len_decoder = seq_lens_decoder[bid] + seq_len;

if (seq_len == 0) continue;

int loop_times;
if (is_encoder) {
loop_times = cute::ceil_div(seq_len * GROUP_SIZE, num_rows_per_block);
if (seq_len_decoder > 0) {
loop_times = 0;
}
} else {
loop_times = cute::ceil_div(seq_len_decoder, chunk_size);
if (seq_len_encoder > 0) {
loop_times = 0;
}
}
for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) {
batch_ids[index] = bid;
tile_ids_per_batch[index++] = tile_id;
}
gridx += loop_times;
}
*num_blocks_x = gridx;
}
}
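
For orientation, a minimal host-side sketch (not part of the commit) of the decoder branch of split_q_block_mla above: every active sequence contributes ceil_div(cached tokens + new tokens, chunk_size) tiles, which the kernel flattens into (batch_id, tile_id) work items and counts into num_blocks_x. The encoder branch and the encoder-in-flight guard (seq_len_encoder > 0 zeroes loop_times) are omitted here:

#include <cstdio>
#include <vector>

int main() {
  // Toy decode step: three sequences; seq_lens_q counts this step's tokens,
  // seq_lens_decoder counts tokens already in the KV cache; chunk_size = 4.
  const int bsz = 3, chunk_size = 4;
  const int seq_lens_q[]       = {1, 0, 1};
  const int seq_lens_decoder[] = {5, 9, 2};
  std::vector<int> batch_ids, tile_ids;
  for (int bid = 0; bid < bsz; ++bid) {
    if (seq_lens_q[bid] == 0) continue;                       // inactive sequence
    const int total = seq_lens_decoder[bid] + seq_lens_q[bid];
    const int tiles = (total + chunk_size - 1) / chunk_size;  // cute::ceil_div
    for (int t = 0; t < tiles; ++t) {
      batch_ids.push_back(bid);
      tile_ids.push_back(t);
    }
  }
  // Expected: (0,0) (0,1) (2,0) -- i.e. *num_blocks_x == 3.
  for (std::size_t i = 0; i < batch_ids.size(); ++i)
    std::printf("(bid=%d, tile=%d)\n", batch_ids[i], tile_ids[i]);
}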

__global__ void split_kv_block(const int* __restrict__ seq_lens_decoder,
const int* __restrict__ seq_lens_encoder,
int* __restrict__ batch_ids,
@@ -110,7 +155,7 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
const int decoder_step_token_num) {
paddle::Tensor encoder_batch_ids, encoder_tile_ids_per_batch, encoder_num_blocks_x_cpu,
kv_batch_ids, kv_tile_ids_per_batch, kv_num_blocks_x_cpu, decoder_batch_ids,
decoder_tile_ids_per_batch, decoder_num_blocks_x_cpu;
decoder_tile_ids_per_batch, decoder_num_blocks_x, decoder_num_blocks_x_cpu;
auto stream = seq_lens_this_time.stream();
int bsz = cum_offsets.shape()[0];
const int encoder_block_shape_q = get_encoder_block_shape_q();
@@ -131,33 +176,65 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
// decoder
int max_dec_len_this_time_data = max_dec_len_this_time.data<int>()[0];
if (max_dec_len_this_time_data > 0) {
const uint32_t decoder_max_tile_size_per_bs_q =
div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
decoder_batch_ids =
GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
paddle::DataType::INT32,
seq_lens_encoder.place());
decoder_tile_ids_per_batch =
GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
paddle::DataType::INT32,
seq_lens_encoder.place());
auto decoder_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_q_block<<<1, 32, 0, stream>>>(seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
decoder_batch_ids.data<int>(),
decoder_tile_ids_per_batch.data<int>(),
decoder_num_blocks_x.data<int>(),
bsz,
decoder_block_shape_q,
group_size);
const bool mla_use_tensorcore = get_mla_use_tensorcore();
if (mla_use_tensorcore) {
const int chunk_size = get_max_partition_size(bsz);
const int decoder_max_tile_size_per_bs = div_up(max_len_kv_cpu.data<int>()[0], chunk_size);
decoder_batch_ids =
GetEmptyTensor({bsz * decoder_max_tile_size_per_bs},
paddle::DataType::INT32,
seq_lens_encoder.place());
decoder_tile_ids_per_batch =
GetEmptyTensor({bsz * decoder_max_tile_size_per_bs},
paddle::DataType::INT32,
seq_lens_encoder.place());
decoder_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_q_block_mla<<<1, 32, 0, stream>>>(
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
decoder_batch_ids.data<int>(),
decoder_tile_ids_per_batch.data<int>(),
decoder_num_blocks_x.data<int>(),
bsz,
block_size,
chunk_size,
group_size,
false // is_encoder
);
} else {
const uint32_t decoder_max_tile_size_per_bs_q =
div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
decoder_batch_ids =
GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
paddle::DataType::INT32,
seq_lens_encoder.place());
decoder_tile_ids_per_batch =
GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
paddle::DataType::INT32,
seq_lens_encoder.place());
decoder_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_q_block<<<1, 32, 0, stream>>>(seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
decoder_batch_ids.data<int>(),
decoder_tile_ids_per_batch.data<int>(),
decoder_num_blocks_x.data<int>(),
bsz,
decoder_block_shape_q,
group_size);
}

decoder_num_blocks_x_cpu =
decoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);
} else {
decoder_batch_ids =
paddle::full({1}, -1, paddle::DataType::INT32, paddle::GPUPlace());
decoder_tile_ids_per_batch =
paddle::full({1}, -1, paddle::DataType::INT32, paddle::GPUPlace());
decoder_num_blocks_x =
paddle::full({1}, -1, paddle::DataType::INT32, paddle::GPUPlace());
decoder_num_blocks_x_cpu =
paddle::full({1}, -1, paddle::DataType::INT32, paddle::CPUPlace());
}
@@ -230,6 +307,7 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
kv_num_blocks_x_cpu, /*cpu*/
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks_x,
decoder_num_blocks_x_cpu, /*cpu*/
max_len_kv_cpu /*cpu*/};
}
@@ -250,6 +328,7 @@ std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype(
paddle::DataType::INT32,
paddle::DataType::INT32,
paddle::DataType::INT32,
paddle::DataType::INT32,
paddle::DataType::INT32};
}

@@ -271,6 +350,7 @@ std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
dynamic_shape,
dynamic_shape,
{1},
{1},
{1}};
}

@@ -290,6 +370,7 @@ PD_BUILD_OP(get_block_shape_and_split_kv_block)
"decoder_batch_ids",
"decoder_tile_ids_per_batch",
"decoder_num_blocks",
"decoder_num_blocks_cpu",
"max_len_kv"})
.Attrs({"group_size: int",
"block_size: int",
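The two sizing rules in the decoder section above deserve a worked example: the tensor-core MLA branch allocates one scratch slot per (batch, KV chunk), i.e. bsz * div_up(max_len_kv, chunk_size), while the fallback branch allocates one per (batch, q tile), i.e. bsz * div_up(decoder_step_token_num * group_size, decoder_block_shape_q). A standalone sketch with illustrative numbers (div_up mirrors the helper defined earlier in this file):

#include <cstdio>

static int div_up(int m, int n) { return (m + n - 1) / n; }

int main() {
  // Illustrative numbers only: batch 4, longest live KV sequence 1000 tokens,
  // chunk_size 512 (the bsz > 1 default), 1 token per decode step,
  // group_size 8, decoder_block_shape_q 16 (the FLAGS_dec_block_shape_q default).
  const int bsz = 4, max_len_kv = 1000, chunk_size = 512;
  const int step_tokens = 1, group_size = 8, dec_block_q = 16;

  // Tensor-core MLA branch: one work item per (batch, KV chunk).
  const int mla_capacity = bsz * div_up(max_len_kv, chunk_size);    // 4 * 2 = 8
  // Fallback branch: one work item per (batch, q-row block).
  const int fallback_capacity =
      bsz * div_up(step_tokens * group_size, dec_block_q);          // 4 * 1 = 4

  std::printf("mla=%d fallback=%d\n", mla_capacity, fallback_capacity);
}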
29 changes: 0 additions & 29 deletions csrc/gpu/append_attn/multi_head_latent_attention_kernel.h
@@ -36,32 +36,3 @@ void DecodeMLAAttentionKernel(
bool causal,
cudaStream_t &stream,
paddle::Tensor *out);

inline uint32_t get_max_partition_size(int bsz) {
static const char* max_partition_size_env = std::getenv("FLAGS_cascade_attention_max_partition_size");
static const uint32_t max_partition_size =
max_partition_size_env == nullptr ? 0 : std::stoul(std::string(max_partition_size_env));
return (max_partition_size != 0 ? max_partition_size : (bsz == 1 ? 128 : 512));
}


inline uint32_t get_cascade_attention_deal_each_time() {
static const char* cascade_attention_deal_each_time_env = std::getenv("FLAGS_cascade_attention_deal_each_time");
static const uint32_t cascade_attention_deal_each_time =
cascade_attention_deal_each_time_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_deal_each_time_env));
return (cascade_attention_deal_each_time != 0 ? cascade_attention_deal_each_time : 32);
}

inline uint32_t get_cascade_attention_num_stages() {
static const char* cascade_attention_num_stages_env = std::getenv("FLAGS_cascade_attention_num_stages");
static const uint32_t cascade_attention_num_stages =
cascade_attention_num_stages_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_num_stages_env));
return cascade_attention_num_stages != 0 ? cascade_attention_num_stages : 2;
}

inline uint32_t get_cascade_attention_num_threads() {
static const char* cascade_attention_num_threads_env = std::getenv("FLAGS_cascade_attention_num_threads");
static const uint32_t cascade_attention_num_threads =
cascade_attention_num_threads_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_num_threads_env));
return cascade_attention_num_threads != 0 ? cascade_attention_num_threads : 128;
}
5 changes: 5 additions & 0 deletions csrc/gpu/append_attn/utils.cuh
@@ -279,6 +279,11 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
__VA_ARGS__ \
break; \
} \
case 192: { \
constexpr size_t HEAD_DIM = 192; \
__VA_ARGS__ \
break; \
} \
default: { \
PD_THROW("not support the head_dim: ", head_dim); \
} \
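The new case extends the head-dim dispatch switch in utils.cuh so head_dim == 192 (the MLA head width) gets its own compile-time instantiation. The enclosing macro's name and opening lie outside the visible hunk, so the reconstruction below is hypothetical (DISPATCH_HEAD_DIM is an assumed name); it only shows the pattern the new case slots into:

// Hypothetical reconstruction: map a runtime head_dim onto a compile-time
// constant so kernel templates can be instantiated per supported width.
#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...)      \
  switch (head_dim) {                                   \
    case 128: {                                         \
      constexpr size_t HEAD_DIM = 128;                  \
      __VA_ARGS__                                       \
      break;                                            \
    }                                                   \
    case 192: { /* added by this commit for MLA */      \
      constexpr size_t HEAD_DIM = 192;                  \
      __VA_ARGS__                                       \
      break;                                            \
    }                                                   \
    default: {                                          \
      PD_THROW("not support the head_dim: ", head_dim); \
    }                                                   \
  }

// Usage sketch: DISPATCH_HEAD_DIM(head_dim, HEAD_DIM,
//                                 { launch_attn_kernel<HEAD_DIM>(params); });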
71 changes: 71 additions & 0 deletions csrc/gpu/env.h
@@ -0,0 +1,71 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

inline uint32_t get_decoder_block_shape_q() {
static const char* decoder_block_shape_q_env = std::getenv("FLAGS_dec_block_shape_q");
static const uint32_t decoder_block_shape_q =
decoder_block_shape_q_env == nullptr ? 16 : std::stoi(std::string(decoder_block_shape_q_env));
return decoder_block_shape_q;
}

inline uint32_t get_encoder_block_shape_q() {
static const char* encoder_block_shape_q_env = std::getenv("FLAGS_enc_block_shape_q");
static const uint32_t encoder_block_shape_q =
encoder_block_shape_q_env == nullptr ? 64 : std::stoi(std::string(encoder_block_shape_q_env));
return encoder_block_shape_q;
}

inline uint32_t get_max_partition_size(int bsz) {
static const char* max_partition_size_env = std::getenv("FLAGS_cascade_attention_max_partition_size");
static const uint32_t max_partition_size =
max_partition_size_env == nullptr ? 0 : std::stoul(std::string(max_partition_size_env));
return (max_partition_size != 0 ? max_partition_size : (bsz == 1 ? 128 : 512));
}

inline uint32_t get_cascade_attention_deal_each_time() {
static const char* cascade_attention_deal_each_time_env = std::getenv("FLAGS_cascade_attention_deal_each_time");
static const uint32_t cascade_attention_deal_each_time =
cascade_attention_deal_each_time_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_deal_each_time_env));
return (cascade_attention_deal_each_time != 0 ? cascade_attention_deal_each_time : 32);
}

inline uint32_t get_cascade_attention_num_stages() {
static const char* cascade_attention_num_stages_env = std::getenv("FLAGS_cascade_attention_num_stages");
static const uint32_t cascade_attention_num_stages =
cascade_attention_num_stages_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_num_stages_env));
return cascade_attention_num_stages != 0 ? cascade_attention_num_stages : 2;
}

inline uint32_t get_cascade_attention_num_threads() {
static const char* cascade_attention_num_threads_env = std::getenv("FLAGS_cascade_attention_num_threads");
static const uint32_t cascade_attention_num_threads =
cascade_attention_num_threads_env == nullptr ? 0 : std::stoul(std::string(cascade_attention_num_threads_env));
return cascade_attention_num_threads != 0 ? cascade_attention_num_threads : 128;
}

inline bool get_mla_use_tensorcore() {
static const char* mla_use_tensorcore_env = std::getenv("FLAGS_mla_use_tensorcore");
static const uint32_t mla_use_tensorcore =
mla_use_tensorcore_env == nullptr ? 1 : std::stoul(std::string(mla_use_tensorcore_env));
return mla_use_tensorcore != 0 ? true : false;
}

inline bool enable_cuda_core_fp8_gemm() {
static const char* enable_cuda_core_fp8_env = std::getenv("FLAGS_cuda_core_fp8_gemm");
static const bool enable_cuda_core_fp8_gemm =
enable_cuda_core_fp8_env != nullptr && std::string(enable_cuda_core_fp8_env) == "1";
return enable_cuda_core_fp8_gemm;
}
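
env.h now centralizes every FLAGS_* toggle the attention kernels read. Note the header itself includes nothing, though it uses std::getenv, std::stoi, and std::stoul, so consumers must pull in <cstdlib>, <string>, and <cstdint> first (in-tree it is included from helper.h). A small illustrative harness, not part of the commit, printing the effective values when no flags are set:

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <string>
#include "env.h"  // the new header above

int main() {
  std::printf("dec_block_shape_q     = %u\n", get_decoder_block_shape_q());            // 16
  std::printf("enc_block_shape_q     = %u\n", get_encoder_block_shape_q());            // 64
  std::printf("max_partition_size(1) = %u\n", get_max_partition_size(1));              // 128
  std::printf("max_partition_size(8) = %u\n", get_max_partition_size(8));              // 512
  std::printf("deal_each_time        = %u\n", get_cascade_attention_deal_each_time()); // 32
  std::printf("num_stages            = %u\n", get_cascade_attention_num_stages());     // 2
  std::printf("num_threads           = %u\n", get_cascade_attention_num_threads());    // 128
  std::printf("mla_use_tensorcore    = %d\n", get_mla_use_tensorcore());               // 1
  std::printf("cuda_core_fp8_gemm    = %d\n", enable_cuda_core_fp8_gemm());            // 0
}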
7 changes: 0 additions & 7 deletions csrc/gpu/fp8_gemm_with_cutlass/fp8_fp8_half_cuda_core_gemm.h
@@ -26,12 +26,5 @@ typedef struct {
cudaStream_t stream;
} GemmParams;

inline bool enable_cuda_core_fp8_gemm() {
static const char* enable_cuda_core_fp8_env = std::getenv("FLAGS_cuda_core_fp8_gemm");
static const bool enable_cuda_core_fp8_gemm =
enable_cuda_core_fp8_env != nullptr && std::string(enable_cuda_core_fp8_env) == "1";
return enable_cuda_core_fp8_gemm;
}

template <typename InputType, typename OutputType>
bool cuda_core_gemm_launcher(GemmParams const& params);
18 changes: 6 additions & 12 deletions csrc/gpu/helper.h
@@ -37,9 +37,11 @@ namespace cub = hipcub;
#include <iostream>
#include <fstream>

#include "env.h"
#include "paddle/extension.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/allocator.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "nlohmann/json.hpp"

using json = nlohmann::json;
@@ -222,16 +224,8 @@ __device__ inline bool is_in_end(const int64_t id, const int64_t *end_ids, int l
return flag;
}

inline uint32_t get_decoder_block_shape_q() {
static const char* decoder_block_shape_q_env = std::getenv("FLAGS_dec_block_shape_q");
static const uint32_t decoder_block_shape_q =
decoder_block_shape_q_env == nullptr ? 16 : std::stoi(std::string(decoder_block_shape_q_env));
return decoder_block_shape_q;
}

inline uint32_t get_encoder_block_shape_q() {
static const char* encoder_block_shape_q_env = std::getenv("FLAGS_enc_block_shape_q");
static const uint32_t encoder_block_shape_q =
encoder_block_shape_q_env == nullptr ? 64 : std::stoi(std::string(encoder_block_shape_q_env));
return encoder_block_shape_q;
inline int GetSMVersion() {
static int sm_version = phi::backends::gpu::GetGPUComputeCapability(
phi::backends::gpu::GetCurrentDeviceId());
return sm_version;
}
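
The two block-shape getters move to env.h, and helper.h gains GetSMVersion, which caches the current device's compute capability once per process (Paddle reports it as major * 10 + minor, so 80 is an Ampere A100 and 90 a Hopper H100). A hedged call-site sketch, not from this commit:

// Illustrative gate for an architecture-specific path;
// GetSMVersion() is the helper defined above.
if (GetSMVersion() >= 90) {
  // Hopper and newer: e.g. take an FP8 tensor-core route
} else {
  // older SMs: generic fallback
}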
[Diffs for the remaining 23 changed files are not shown.]
