
Commit d8e6b25

[INTEL HPU] add fused block atten (#1706)
1 parent 3db38f4 commit d8e6b25

File tree

10 files changed: +2639 additions, -329 deletions

backends/intel_hpu/custom_ops/llama_infer/fused_block_attention.cc

Lines changed: 1818 additions & 0 deletions
Large diffs are not rendered by default.

Lines changed: 284 additions & 0 deletions
@@ -0,0 +1,284 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "habanalabs/perf_lib_layer_params.h"
#include "kernels/funcs.h"
#include "kernels/hpu_funcs.h"
#include "kernels/hpu_operator.h"
#include "paddle/extension.h"
#include "utils/utils.h"

namespace custom_kernel {

struct FusedRmsMlpResParams {
  ns_LayerNormKernel::Params rmsnorm_params;
  synSplitParams split_params;
};

class FusedRmsMlpRes : public HpuFusedOperator {
 public:
  explicit FusedRmsMlpRes(synDataType dtype)
      : HpuFusedOperator("fused_rms_mlp_res_fwd_", false), dtype_(dtype) {}
  template <typename T>
  void AddNode(ConvertTensors& ct, FusedRmsMlpResParams params) {
    auto ins = ct.GetTensors();
    auto outs = ct.GetTensors(false);

    synGEMMParams gemm_params;
    gemm_params.transpose_a = false;
    gemm_params.transpose_b = false;

    synSectionHandle section = createSection();
    auto hidden_states = createTensorFromCT(&ct, 0);
    auto residual_input = createTensorFromCT(&ct, 4, true, section);
    auto residual_out = createTensorFromCT(&ct, 1, false, section);

    std::vector<synTensor> add_residual_in;
    add_residual_in.push_back(hidden_states);
    add_residual_in.push_back(residual_input);

    std::vector<synTensor> add_residual_out;
    add_residual_out.push_back(residual_out);

    AddNodeAdd<T>(add_residual_in, add_residual_out, guid_ + "add_residual");

    auto ln_scales = createTensorFromCT(&ct, 1);
    std::vector<synTensor> rmsnorm_inputs;
    rmsnorm_inputs.push_back(residual_out);
    rmsnorm_inputs.push_back(ln_scales);

    auto tmp_dims = ins[0].dims;
    tmp_dims[2] = 1;
    auto norm_out = createTensorNoPresist("norm_out", ins[0].type, ins[0].dims);
    auto norm_var = createTensorNoPresist("norm_var", ins[0].type, tmp_dims);
    std::vector<synTensor> rmsnorm_outputs;
    rmsnorm_outputs.push_back(norm_out);
    rmsnorm_outputs.push_back(norm_var);

    AddNodeRmsNorm<T>(rmsnorm_inputs,
                      rmsnorm_outputs,
                      params.rmsnorm_params,
                      guid_ + "rmsnorm");

    auto proj_weight = createTensorFromCT(&ct, 2);
    std::vector<int64_t> proj_dims = {
        ins[0].dims[0], ins[0].dims[1], ins[2].dims[1]};
    auto proj_out = createTensorNoPresist("proj_out", ins[0].type, proj_dims);

    std::vector<synTensor> proj_inputs;
    proj_inputs.push_back(norm_out);
    proj_inputs.push_back(proj_weight);
    std::vector<synTensor> proj_outputs;
    proj_outputs.push_back(proj_out);

    AddNodeGemm(proj_inputs, proj_outputs, gemm_params, guid_ + "gemm_up_proj");

    std::vector<int64_t> split_out_dims = {
        proj_dims[0], proj_dims[1], proj_dims[2] / 2};
    auto gate_out =
        createTensorNoPresist("gate_out", ins[0].type, split_out_dims);
    auto up_out = createTensorNoPresist("up_out", ins[0].type, split_out_dims);
    auto down_weight = createTensorFromCT(&ct, 3);

    std::vector<synTensor> split_inputs;
    split_inputs.push_back(proj_out);
    std::vector<synTensor> split_outputs;
    split_outputs.push_back(gate_out);
    split_outputs.push_back(up_out);

    AddNodeSplit(
        split_inputs, split_outputs, params.split_params, guid_ + "split");

    auto silu_out =
        createTensorNoPresist("silu_out", ins[0].type, split_out_dims);
    std::vector<synTensor> silu_inputs;
    silu_inputs.push_back(gate_out);
    std::vector<synTensor> silu_outputs;
    silu_outputs.push_back(silu_out);

    AddNodeSilu<T>(silu_inputs, silu_outputs, guid_ + "silu");

    auto multi_out =
        createTensorNoPresist("multi_out", ins[0].type, split_out_dims);
    std::vector<synTensor> multi_inputs;
    multi_inputs.push_back(silu_out);
    multi_inputs.push_back(up_out);
    std::vector<synTensor> multi_outputs;
    multi_outputs.push_back(multi_out);

    AddNodeMultiply<T>(multi_inputs, multi_outputs, guid_ + "_multi");

    auto mlp_out = createTensorFromCT(&ct, 0, false);
    std::vector<synTensor> down_inputs;
    down_inputs.push_back(multi_out);
    down_inputs.push_back(down_weight);
    std::vector<synTensor> down_outputs;
    down_outputs.push_back(mlp_out);

    AddNodeGemm(
        down_inputs, down_outputs, gemm_params, guid_ + "gemm_down_proj");
  }

 protected:
  synDataType dtype_;
};

template <typename T, typename Context>
void FusedRmsMlpResKernel(const Context& dev_ctx,
                          const phi::DenseTensor& x,
                          const phi::DenseTensor& residual,
                          const phi::DenseTensor& ln_scales,
                          const phi::DenseTensor& proj_weight,
                          const phi::DenseTensor& down_weight,
                          const phi::Scalar& epsilon,
                          phi::DenseTensor* out) {
  // allocate memory on device.
  dev_ctx.template Alloc<T>(out);
  if (out->numel() == 0) {
    return;
  }

  std::vector<int64_t> ln_scales_dims =
      phi::vectorize<int64_t>(ln_scales.dims());

  const phi::Scalar axis_scalar = proj_weight.dims().size() - 1;
  int64_t axis = axis_scalar.to<int64_t>();
  if (axis < 0) {
    axis = proj_weight.dims().size() + axis;
  }
  FusedRmsMlpResParams params;
  memset(reinterpret_cast<void*>(&params), 0x00, sizeof(FusedRmsMlpResParams));
  params.rmsnorm_params.epsValid = true;
  params.rmsnorm_params.eps = epsilon.to<float>();

  params.split_params = {{0}};
  params.split_params.axis = proj_weight.dims().size() - 1 - axis;

  ConvertTensors ct;
  ct.Add(x);
  ct.Add(ln_scales);
  ct.Add(proj_weight);
  ct.Add(down_weight);
  ct.Add(residual);
  ct.Add(*out, false);
  ct.Add(residual, false);
  std::vector<DIMS> inputs_dims = ct.GetDims();

  OpCacheOperator op_info;
  op_info.prepareOpInfo<T, FusedRmsMlpResParams>(
      "FusedRmsMlpResKernel", inputs_dims, &params);
  auto recipe = op_info.GetRecipe();

  if (recipe == nullptr) {
    FusedRmsMlpRes op(op_info.datatype_);
    op.AddNode<T>(ct, params);
    op.Compile();
    op_info.setOp(op);

    recipe = op_info.GetRecipe();
  }

  std::map<std::string, uint64_t> tensors = ct.GetDeviceAddr();
  RecipeRunner runner(recipe);
  runner.Run(reinterpret_cast<C_Stream>(dev_ctx.stream()), tensors);
}

}  // namespace custom_kernel

template <typename Context>
void CallFusedRmsMlpResKernel(const Context& dev_ctx,
                              const phi::DenseTensor& x,
                              const phi::DenseTensor& residual,
                              const phi::DenseTensor& ln_scales,
                              const phi::DenseTensor& proj_weight,
                              const phi::DenseTensor& down_weight,
                              const phi::Scalar& epsilon,
                              phi::DenseTensor* out) {
  if (x.dtype() == phi::DataType::BFLOAT16) {
    custom_kernel::FusedRmsMlpResKernel<phi::dtype::bfloat16>(dev_ctx,
                                                              x,
                                                              residual,
                                                              ln_scales,
                                                              proj_weight,
                                                              down_weight,
                                                              epsilon,
                                                              out);
  } else {
    throw std::runtime_error("Unsupported data type for FusedRmsMlpResKernel");
  }
}

std::vector<paddle::Tensor> FusedRmsMlpResForward(
    const paddle::Tensor& x,
    const paddle::Tensor& ln_scales,
    const paddle::Tensor& proj_weight,
    const paddle::Tensor& down_weight,
    const paddle::Tensor& residual,
    const float epsilon) {
  auto dev_ctx = static_cast<const phi::CustomContext*>(
      paddle::experimental::DeviceContextPool::Instance().Get(x.place()));

  auto x_tensor = static_cast<const phi::DenseTensor*>(x.impl().get());
  auto residual_tensor =
      static_cast<const phi::DenseTensor*>(residual.impl().get());

  auto ln_scales_tensor =
      static_cast<const phi::DenseTensor*>(ln_scales.impl().get());
  auto down_tensor =
      static_cast<const phi::DenseTensor*>(down_weight.impl().get());
  auto proj_tensor =
      static_cast<const phi::DenseTensor*>(proj_weight.impl().get());

  auto out_tensor = std::make_shared<phi::DenseTensor>();
  out_tensor->Resize(x_tensor->dims());

  CallFusedRmsMlpResKernel(*dev_ctx,
                           *x_tensor,
                           *residual_tensor,
                           *ln_scales_tensor,
                           *proj_tensor,
                           *down_tensor,
                           phi::Scalar(epsilon),
                           out_tensor.get());

  paddle::Tensor out(out_tensor);

  return {out};
}

std::vector<std::vector<int64_t>> FusedRmsMlpResInferShape(
    const std::vector<int64_t>& x_shape,
    const std::vector<int64_t>& ln_scales_shape,
    const std::vector<int64_t>& proj_weight_shape,
    const std::vector<int64_t>& down_weight_shape,
    const std::vector<int64_t>& residual_shape) {
  return {x_shape, residual_shape};
}

std::vector<paddle::DataType> FusedRmsMlpResInferDtype(
    const paddle::DataType& x_dtype,
    const paddle::DataType& ln_scales_dtype,
    const paddle::DataType& proj_weight_dtype,
    const paddle::DataType& down_weight_dtype,
    const paddle::DataType& residual_dtype) {
  return {x_dtype, residual_dtype};
}

PD_BUILD_OP(fused_rms_mlp_res)
    .Inputs({"x", "ln_scales", "proj_weight", "down_weight", "residual_in"})
    .Outputs({"out"})
    .Attrs({"epsilon: float"})
    .SetKernelFn(PD_KERNEL(FusedRmsMlpResForward))
    .SetInferShapeFn(PD_INFER_SHAPE(FusedRmsMlpResInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(FusedRmsMlpResInferDtype));
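
For reference, the graph assembled in FusedRmsMlpRes::AddNode (residual add, RMSNorm, fused gate/up GEMM, split, SiLU, elementwise multiply, down GEMM) corresponds to the unfused computation sketched below. This is an editorial sketch in the Paddle Python API, not part of the commit; the function name and the float32 accumulation of the RMS statistic are illustrative assumptions.

import paddle
import paddle.nn.functional as F

def fused_rms_mlp_res_reference(x, ln_scales, proj_weight, down_weight,
                                residual_in, epsilon):
    # Residual add; the fused op also writes this value back into residual_in
    # through the shared device section created for residual_input/residual_out.
    residual_out = x + residual_in
    # RMSNorm over the hidden dimension, scaled by ln_scales.
    variance = residual_out.astype("float32").pow(2).mean(axis=-1, keepdim=True)
    norm_out = residual_out.astype("float32") * paddle.rsqrt(variance + epsilon)
    norm_out = norm_out.astype(residual_out.dtype) * ln_scales
    # Fused gate/up projection, then an even split along the last axis.
    proj = paddle.matmul(norm_out, proj_weight)
    gate, up = paddle.split(proj, 2, axis=-1)
    # SwiGLU-style activation followed by the down projection.
    out = paddle.matmul(F.silu(gate) * up, down_weight)
    return out, residual_out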

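Once the custom-op library is built, the operator registered by PD_BUILD_OP is exposed to Python under the name fused_rms_mlp_res. A hypothetical call site is sketched below; the module name paddlenlp_ops is an assumption for illustration and does not come from this commit.

import paddlenlp_ops  # hypothetical module name for the built extension

out = paddlenlp_ops.fused_rms_mlp_res(
    x, ln_scales, proj_weight, down_weight, residual_in, epsilon)
# residual_in is additionally updated in place with x + residual_in, since the
# graph maps the residual output onto the same device section as the input.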