Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support XPU deepseek #9917

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions csrc/xpu/src/adjust_batch.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "xpu/plugin.h"
#include <xft/xdnn_plugin.h>
namespace xftkernel = baidu::xpu::xftkernel;
std::vector<paddle::Tensor> AdjustBatch(const paddle::Tensor& tmp_out, // [token_num, dim_embed]
const paddle::Tensor& cum_offsets, // [bsz, 1]
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::optional<paddle::Tensor>& output_padding_offset,
int max_input_length) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
xpu::ctx_guard RAII_GUARD(xpu_ctx->x_context());
using XPUType = typename XPUTypeTrait<float16>::Type; // only support fp16
typedef paddle::float16 data_t;
const int token_num = tmp_out.dims()[0];
const int dim = tmp_out.dims()[1];
const int bsz = cum_offsets.shape()[0];

std::vector<int> seq_lens_encoder_cpu(bsz, 0);
std::vector<int> seq_lens_decoder_cpu(bsz, 0);
std::vector<int> encoder_batch_idx; // 去除空隙的batch map
std::vector<int> decoder_batch_idx; // 去除空隙的batch map
std::vector<int> encoder_seq_lod;
int r = xpu_memcpy(seq_lens_encoder_cpu.data(),
seq_lens_encoder.data<int>(),
sizeof(int32_t) * bsz,
XPUMemcpyKind::XPU_DEVICE_TO_HOST);
r = xpu_memcpy(seq_lens_decoder_cpu.data(),
seq_lens_decoder.data<int>(),
sizeof(int32_t) * bsz,
XPUMemcpyKind::XPU_DEVICE_TO_HOST);
int enc_batch = 0, dec_batch = 0;
int batch_offset = 0;
encoder_seq_lod.push_back(0);
for(int i = 0; i < bsz; ++i){
if(seq_lens_encoder_cpu[i] > 0){
enc_batch++;
encoder_batch_idx.push_back(i - batch_offset);
encoder_seq_lod.push_back(seq_lens_encoder_cpu[i]);
encoder_seq_lod[enc_batch] += encoder_seq_lod[enc_batch - 1];
}
else if(seq_lens_decoder_cpu[i] > 0){
dec_batch++;
decoder_batch_idx.push_back(i - batch_offset);
}
else{
batch_offset++;
}
}
baidu::xpu::api::VectorParam<int32_t> encoder_seqs_lods_vp =
baidu::xpu::api::VectorParam<int32_t>{encoder_seq_lod.data(), enc_batch + 1, nullptr}
.to_xpu(RAII_GUARD);
baidu::xpu::api::VectorParam<int32_t> encoder_batch_map_vp =
baidu::xpu::api::VectorParam<int32_t>{encoder_batch_idx.data(), enc_batch, nullptr}
.to_xpu(RAII_GUARD);
baidu::xpu::api::VectorParam<int32_t> decoder_batch_map_vp =
baidu::xpu::api::VectorParam<int32_t>{decoder_batch_idx.data(), dec_batch, nullptr}
.to_xpu(RAII_GUARD);
auto out = paddle::full({token_num, dim}, -2, tmp_out.type(), tmp_out.place());

r = xftkernel::xft_eb_adjust_batch<XPUType, XPUType>(
xpu_ctx->x_context(),
reinterpret_cast<const XPUType*>(tmp_out.data<data_t>()),
reinterpret_cast<XPUType*>(out.data<data_t>()),
encoder_seqs_lods_vp,
encoder_batch_map_vp,
decoder_batch_map_vp,
dim);
return {out};
}

std::vector<std::vector<int64_t>> AdjustBatchInferShape(const std::vector<int64_t>& tmp_out_shape,
const std::vector<int64_t>& cum_offsets_shape,
const std::vector<int64_t>& seq_lens_decoder_shape,
const std::vector<int64_t>& seq_lens_encoder_shape,
const paddle::optional<std::vector<int64_t>>& output_padding_offset_shape) {
if (output_padding_offset_shape) {
PD_THROW("speculative decoding is not supported in XPU.");
}
int64_t token_num = tmp_out_shape[0];
int64_t dim_embed = tmp_out_shape[1];
return {{token_num, dim_embed}};
}

std::vector<paddle::DataType> AdjustBatchInferDtype(const paddle::DataType& tmp_out_dtype,
const paddle::DataType& cum_offsets_dtype,
const paddle::DataType& seq_lens_decoder_dtype,
const paddle::DataType& seq_lens_encoder_dtype,
const paddle::optional<paddle::DataType>& output_padding_offset_dtype) {
return {tmp_out_dtype};
}

PD_BUILD_OP(adjust_batch)
.Inputs({"tmp_out", "cum_offsets", "seq_lens_decoder", "seq_lens_encoder", paddle::Optional("output_padding_offset")})
.Outputs({"out"})
.Attrs({"max_input_length: int"})
.SetKernelFn(PD_KERNEL(AdjustBatch))
.SetInferShapeFn(PD_INFER_SHAPE(AdjustBatchInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(AdjustBatchInferDtype));
10 changes: 10 additions & 0 deletions csrc/xpu/src/cmake_build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,24 @@

set -e

rm -rf build

export PATH=/opt/output/work_dir/deps/cmake-3.26.0-linux-x86_64/bin:$PATH
# export XDNN_PATH=Paddle/build/third_party/xpu/src/extern_xpu/xdnn-ubuntu_x86_64/ # <path_to_xdnn>
# export XRE_PATH=Paddle/build/third_party/xpu/src/extern_xpu/xre-ubuntu_x86_64/ # <path_to_xre>
# export CLANG_PATH=xtdk-ubuntu_1604_x86_64 # <path_to_xtdk>
# export HOST_SYSROOT=/opt/compiler/gcc-8.2/bin/gcc # <path_to_gcc>

export XDNN_PATH=/opt/output/work_dir/paddle-deepseek/xpu_libs/xhpc/xdnn
export XRE_PATH=/opt/output/work_dir/paddle-deepseek/xpu_libs/xre
export CLANG_PATH=/opt/output/work_dir/paddle-deepseek/xpu_libs/xdnn_plugin/xtdk_output/xtdk-llvm15-ubuntu2004_x86_64
cd plugin
./cmake_build.sh
cd -

unset XDNN_PATH
unset XRE_PATH
unset CLANG_PATH

python -m pip uninstall paddlenlp_ops -y
python setup.py install
116 changes: 116 additions & 0 deletions csrc/xpu/src/gather_next_token.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "xpu/plugin.h"
#include <xft/xdnn_plugin.h>
namespace xftkernel = baidu::xpu::xftkernel;
std::vector<paddle::Tensor> GatherNextToken(const paddle::Tensor& tmp_out, // [token_num, dim_embed]
const paddle::Tensor& cum_offsets, // [bsz, 1]
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::optional<paddle::Tensor>& output_padding_offset,
int max_input_length) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
xpu::ctx_guard RAII_GUARD(xpu_ctx->x_context());
using XPUType = typename XPUTypeTrait<float16>::Type; // only support fp16
typedef paddle::float16 data_t;
const int dim = tmp_out.dims()[1];
const int bsz = cum_offsets.shape()[0];

std::vector<int> seq_lens_encoder_cpu(bsz, 0);
std::vector<int> seq_lens_decoder_cpu(bsz, 0);
std::vector<int> encoder_batch_idx; // 去除空隙的batch map
std::vector<int> decoder_batch_idx; // 去除空隙的batch map
std::vector<int> encoder_seq_lod;
int r = xpu_memcpy(seq_lens_encoder_cpu.data(),
seq_lens_encoder.data<int>(),
sizeof(int32_t) * bsz,
XPUMemcpyKind::XPU_DEVICE_TO_HOST);
r = xpu_memcpy(seq_lens_decoder_cpu.data(),
seq_lens_decoder.data<int>(),
sizeof(int32_t) * bsz,
XPUMemcpyKind::XPU_DEVICE_TO_HOST);
int enc_batch = 0, dec_batch = 0;
int batch_offset = 0;
encoder_seq_lod.push_back(0);
for(int i = 0; i < bsz; ++i){
if(seq_lens_encoder_cpu[i] > 0){
enc_batch++;
encoder_batch_idx.push_back(i - batch_offset);
encoder_seq_lod.push_back(seq_lens_encoder_cpu[i]);
encoder_seq_lod[enc_batch] += encoder_seq_lod[enc_batch - 1];
}
else if(seq_lens_decoder_cpu[i] > 0){
dec_batch++;
decoder_batch_idx.push_back(i - batch_offset);
}
else{
batch_offset++;
}
}
int total_batch = enc_batch + dec_batch;
baidu::xpu::api::VectorParam<int32_t> encoder_seqs_lods_vp =
baidu::xpu::api::VectorParam<int32_t>{encoder_seq_lod.data(), enc_batch + 1, nullptr}
.to_xpu(RAII_GUARD);
baidu::xpu::api::VectorParam<int32_t> encoder_batch_map_vp =
baidu::xpu::api::VectorParam<int32_t>{encoder_batch_idx.data(), enc_batch, nullptr}
.to_xpu(RAII_GUARD);
baidu::xpu::api::VectorParam<int32_t> decoder_batch_map_vp =
baidu::xpu::api::VectorParam<int32_t>{decoder_batch_idx.data(), dec_batch, nullptr}
.to_xpu(RAII_GUARD);
auto out = paddle::full({total_batch, dim}, -2, tmp_out.type(), tmp_out.place());

r = xftkernel::xft_eb_gather_next_token<XPUType, XPUType>(
xpu_ctx->x_context(),
reinterpret_cast<const XPUType*>(tmp_out.data<data_t>()),
reinterpret_cast<XPUType*>(out.data<data_t>()),
encoder_seqs_lods_vp,
encoder_batch_map_vp,
decoder_batch_map_vp,
dim);
return {out};
}

std::vector<std::vector<int64_t>> GatherNextTokenInferShape(const std::vector<int64_t>& tmp_out_shape,
const std::vector<int64_t>& cum_offsets_shape,
const std::vector<int64_t>& seq_lens_decoder_shape,
const std::vector<int64_t>& seq_lens_encoder_shape,
const paddle::optional<std::vector<int64_t>>& output_padding_offset_shape) {
if (output_padding_offset_shape) {
PD_THROW("speculative decoding is not supported in XPU.");
}
int64_t bsz = cum_offsets_shape[0];
int64_t dim_embed = tmp_out_shape[1];
return {{bsz, dim_embed}};
}

std::vector<paddle::DataType> GatherNextTokenInferDtype(const paddle::DataType& tmp_out_dtype,
const paddle::DataType& cum_offsets_dtype,
const paddle::DataType& seq_lens_decoder_dtype,
const paddle::DataType& seq_lens_encoder_dtype,
const paddle::optional<paddle::DataType>& output_padding_offset_dtype) {
return {tmp_out_dtype};
}

PD_BUILD_OP(gather_next_token)
.Inputs({"tmp_out", "cum_offsets", "seq_lens_decoder", "seq_lens_encoder", paddle::Optional("output_padding_offset")})
.Outputs({"out"})
.Attrs({"max_input_length: int"})
.SetKernelFn(PD_KERNEL(GatherNextToken))
.SetInferShapeFn(PD_INFER_SHAPE(GatherNextTokenInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(GatherNextTokenInferDtype));
41 changes: 21 additions & 20 deletions csrc/xpu/src/get_padding_offset_v2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,15 @@
#include "paddle/extension.h"
#include "xpu/plugin.h"

std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor& input_ids,
std::vector<paddle::Tensor> GetPaddingOffsetV2(const paddle::Tensor& input_ids,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& token_num,
const paddle::Tensor& seq_len) {
const paddle::Tensor& seq_len,
const paddle::optional<paddle::Tensor>& draft_tokens,
const paddle::optional<paddle::Tensor>& seq_lens_encoder) {
if (draft_tokens) {
PD_THROW("speculative decoding is not supported in XPU.");
}
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
Expand Down Expand Up @@ -60,35 +65,31 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor& input_ids,
cu_seqlens_k};
}

std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape(
std::vector<std::vector<int64_t>> GetPaddingOffsetV2InferShape(
const std::vector<int64_t>& input_ids_shape,
const std::vector<int64_t>& cum_offsets_shape,
const std::vector<int64_t>& token_num_shape,
const std::vector<int64_t>& seq_len_shape) {
const std::vector<int64_t>& seq_len_shape,
const std::vector<int64_t>& draft_tokens_shape,
const std::vector<int64_t>& seq_lens_encoder_shape) {
int64_t bsz = seq_len_shape[0];
int64_t seq_len = input_ids_shape[1];
return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}};
}

std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
std::vector<paddle::DataType> GetPaddingOffsetV2InferDtype(
const paddle::DataType& input_ids_dtype,
const paddle::DataType& cum_offsets_dtype,
const paddle::DataType& token_num_dtype,
const paddle::DataType& seq_len_dtype) {
return {input_ids_dtype,
seq_len_dtype,
seq_len_dtype,
seq_len_dtype,
seq_len_dtype};
const paddle::DataType& seq_len_dtype,
const paddle::DataType& draft_tokens_dtype,
const paddle::DataType& seq_lens_encoder_dtype) {
return {input_ids_dtype, seq_len_dtype, seq_len_dtype, seq_len_dtype, seq_len_dtype};
}

PD_BUILD_OP(get_padding_offset_v2)
.Inputs({"input_ids", "token_num", "cum_offsets", "seq_len"})
.Outputs({"x_remove_padding",
"cum_offsets_out",
"padding_offset",
"cu_seqlens_q",
"cu_seqlens_k"})
.SetKernelFn(PD_KERNEL(GetPaddingOffset))
.SetInferShapeFn(PD_INFER_SHAPE(GetPaddingOffsetInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(GetPaddingOffsetInferDtype));
.Inputs({"input_ids", "cum_offsets", "token_num", "seq_len", paddle::Optional("draft_tokens"), paddle::Optional("seq_lens_encoder"),})
.Outputs({"x_remove_padding", "cum_offsets_out", "padding_offset", "cu_seqlens_q", "cu_seqlens_k"})
.SetKernelFn(PD_KERNEL(GetPaddingOffsetV2))
.SetInferShapeFn(PD_INFER_SHAPE(GetPaddingOffsetV2InferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(GetPaddingOffsetV2InferDtype));
53 changes: 53 additions & 0 deletions csrc/xpu/src/get_position_ids.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "paddle/phi/core/enforce.h"
#include "xpu/plugin.h"
#include <xft/xdnn_plugin.h>
#include "xblas_legacy_api.h"

namespace xftkernel = baidu::xpu::xftkernel;
namespace api = baidu::xpu::api;
// namespace xblas = baidu::xpu::xblas;

void GetPositionIdsKernel(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& position_ids) {

phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);

const int bs = seq_lens_encoder.shape()[0];

int ret = baidu::xpu::api::plugin::get_position_ids(
xpu_ctx->x_context(),
seq_lens_encoder.data<int32_t>(),
seq_lens_decoder.data<int32_t>(),
seq_lens_this_time.data<int32_t>(),
const_cast<int32_t*>(position_ids.data<int32_t>()),
bs
);
PD_CHECK(ret == 0, "api::plugin::get_position_ids failed");
}

PD_BUILD_OP(get_position_ids)
.Inputs({"seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time", "position_ids"})
.Outputs({"position_ids_out"})
.SetInplaceMap({{"position_ids", "position_ids_out"}})
.SetKernelFn(PD_KERNEL(GetPositionIdsKernel));
Loading
Loading