
[INTEL HPU] add fused block atten #1706


Merged · 1 commit · May 26, 2025

1,818 changes: 1,818 additions & 0 deletions backends/intel_hpu/custom_ops/llama_infer/fused_block_attention.cc

Large diffs are not rendered by default.
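The fused_block_attention.cc diff itself is not rendered, but the reference path in the new test_fused_block_attention.py further down shows roughly what the op fuses. The sketch below only restates that reference path in Python; it is not the HPU kernel, and the argument order follows the decode test case.

# Reference composition of fused_block_attention, restated from the test below.
import paddlenlp_ops


def block_attention_reference(src, residual, rope, k_cache, v_cache,
                              block_groups, block_list, block_mapping, block_bias,
                              block_indices, block_offsets,
                              ln_scales, qkv_weights, linear_weights,
                              epsilon, head_dim, num_head):
    # Fused RMSNorm + QKV projection + RoPE; residual is also updated in place.
    query_states, key_value_states = paddlenlp_ops.fused_rms_qkv_rope_t(
        src, ln_scales, qkv_weights, rope, residual, epsilon, head_dim, num_head
    )
    # Scatter the new K/V rows into the paged caches at (block index, block offset).
    k_cache.index_put_((block_indices, block_offsets), key_value_states[0].squeeze(1))
    v_cache.index_put_((block_indices, block_offsets), key_value_states[1].squeeze(1))
    # Block-wise ("flat PA") attention followed by the output projection.
    return paddlenlp_ops.fused_flatpa_proj(
        query_states, k_cache, v_cache,
        block_groups, block_list, block_mapping, block_bias,
        linear_weights, scaling_factor=head_dim**-0.5,
    )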

284 changes: 284 additions & 0 deletions backends/intel_hpu/custom_ops/llama_infer/fused_rms_mlp_add.cc
@@ -0,0 +1,284 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "habanalabs/perf_lib_layer_params.h"
#include "kernels/funcs.h"
#include "kernels/hpu_funcs.h"
#include "kernels/hpu_operator.h"
#include "paddle/extension.h"
#include "utils/utils.h"

namespace custom_kernel {

struct FusedRmsMlpResParams {
ns_LayerNormKernel::Params rmsnorm_params;
synSplitParams split_params;
};

class FusedRmsMlpRes : public HpuFusedOperator {
public:
explicit FusedRmsMlpRes(synDataType dtype)
: HpuFusedOperator("fused_rms_mlp_res_fwd_", false), dtype_(dtype) {}
template <typename T>
void AddNode(ConvertTensors& ct, FusedRmsMlpResParams params) {
auto ins = ct.GetTensors();
auto outs = ct.GetTensors(false);

synGEMMParams gemm_params;
gemm_params.transpose_a = false;
gemm_params.transpose_b = false;

synSectionHandle section = createSection();
auto hidden_states = createTensorFromCT(&ct, 0);
auto residual_input = createTensorFromCT(&ct, 4, true, section);
auto residual_out = createTensorFromCT(&ct, 1, false, section);

std::vector<synTensor> add_residual_in;
add_residual_in.push_back(hidden_states);
add_residual_in.push_back(residual_input);

std::vector<synTensor> add_residual_out;
add_residual_out.push_back(residual_out);

AddNodeAdd<T>(add_residual_in, add_residual_out, guid_ + "add_residual");

auto ln_scales = createTensorFromCT(&ct, 1);
std::vector<synTensor> rmsnorm_inputs;
rmsnorm_inputs.push_back(residual_out);
rmsnorm_inputs.push_back(ln_scales);

auto tmp_dims = ins[0].dims;
tmp_dims[2] = 1;
auto norm_out = createTensorNoPresist("norm_out", ins[0].type, ins[0].dims);
auto norm_var = createTensorNoPresist("norm_var", ins[0].type, tmp_dims);
std::vector<synTensor> rmsnorm_outputs;
rmsnorm_outputs.push_back(norm_out);
rmsnorm_outputs.push_back(norm_var);

AddNodeRmsNorm<T>(rmsnorm_inputs,
rmsnorm_outputs,
params.rmsnorm_params,
guid_ + "rmsnorm");

auto proj_weight = createTensorFromCT(&ct, 2);
std::vector<int64_t> proj_dims = {
ins[0].dims[0], ins[0].dims[1], ins[2].dims[1]};
auto proj_out = createTensorNoPresist("proj_out", ins[0].type, proj_dims);

std::vector<synTensor> proj_inputs;
proj_inputs.push_back(norm_out);
proj_inputs.push_back(proj_weight);
std::vector<synTensor> proj_outputs;
proj_outputs.push_back(proj_out);

AddNodeGemm(proj_inputs, proj_outputs, gemm_params, guid_ + "gemm_up_proj");

std::vector<int64_t> split_out_dims = {
proj_dims[0], proj_dims[1], proj_dims[2] / 2};
auto gate_out =
createTensorNoPresist("gate_out", ins[0].type, split_out_dims);
auto up_out = createTensorNoPresist("up_out", ins[0].type, split_out_dims);
auto down_weight = createTensorFromCT(&ct, 3);

std::vector<synTensor> split_inputs;
split_inputs.push_back(proj_out);
std::vector<synTensor> split_outputs;
split_outputs.push_back(gate_out);
split_outputs.push_back(up_out);

AddNodeSplit(
split_inputs, split_outputs, params.split_params, guid_ + "split");

auto silu_out =
createTensorNoPresist("silu_out", ins[0].type, split_out_dims);
std::vector<synTensor> silu_inputs;
silu_inputs.push_back(gate_out);
std::vector<synTensor> silu_outputs;
silu_outputs.push_back(silu_out);

AddNodeSilu<T>(silu_inputs, silu_outputs, guid_ + "silu");

auto multi_out =
createTensorNoPresist("multi_out", ins[0].type, split_out_dims);
std::vector<synTensor> multi_inputs;
multi_inputs.push_back(silu_out);
multi_inputs.push_back(up_out);
std::vector<synTensor> multi_outputs;
multi_outputs.push_back(multi_out);

AddNodeMultiply<T>(multi_inputs, multi_outputs, guid_ + "_multi");

auto mlp_out = createTensorFromCT(&ct, 0, false);
std::vector<synTensor> down_inputs;
down_inputs.push_back(multi_out);
down_inputs.push_back(down_weight);
std::vector<synTensor> down_outputs;
down_outputs.push_back(mlp_out);

AddNodeGemm(
down_inputs, down_outputs, gemm_params, guid_ + "gemm_down_proj");
}

protected:
synDataType dtype_;
};

template <typename T, typename Context>
void FusedRmsMlpResKernel(const Context& dev_ctx,
const phi::DenseTensor& x,
const phi::DenseTensor& residual,
const phi::DenseTensor& ln_scales,
const phi::DenseTensor& proj_weight,
const phi::DenseTensor& down_weight,
const phi::Scalar& epsilon,
phi::DenseTensor* out) {
// allocate memory on device.
dev_ctx.template Alloc<T>(out);
if (out->numel() == 0) {
return;
}

std::vector<int64_t> ln_scales_dims =
phi::vectorize<int64_t>(ln_scales.dims());

const phi::Scalar axis_scalar = proj_weight.dims().size() - 1;
int64_t axis = axis_scalar.to<int64_t>();
if (axis < 0) {
axis = proj_weight.dims().size() + axis;
}
FusedRmsMlpResParams params;
memset(reinterpret_cast<void*>(&params), 0x00, sizeof(FusedRmsMlpResParams));
params.rmsnorm_params.epsValid = true;
params.rmsnorm_params.eps = epsilon.to<float>();

params.split_params = {{0}};
params.split_params.axis = proj_weight.dims().size() - 1 - axis;

ConvertTensors ct;
ct.Add(x);
ct.Add(ln_scales);
ct.Add(proj_weight);
ct.Add(down_weight);
ct.Add(residual);
ct.Add(*out, false);
ct.Add(residual, false);
std::vector<DIMS> inputs_dims = ct.GetDims();

OpCacheOperator op_info;
op_info.prepareOpInfo<T, FusedRmsMlpResParams>(
"FusedRmsMlpResKernel", inputs_dims, &params);
auto recipe = op_info.GetRecipe();

if (recipe == nullptr) {
FusedRmsMlpRes op(op_info.datatype_);
op.AddNode<T>(ct, params);
op.Compile();
op_info.setOp(op);

recipe = op_info.GetRecipe();
}

std::map<std::string, uint64_t> tensors = ct.GetDeviceAddr();
RecipeRunner runner(recipe);
runner.Run(reinterpret_cast<C_Stream>(dev_ctx.stream()), tensors);
}

} // namespace custom_kernel

template <typename Context>
void CallFusedRmsMlpResKernel(const Context& dev_ctx,
const phi::DenseTensor& x,
const phi::DenseTensor& residual,
const phi::DenseTensor& ln_scales,
const phi::DenseTensor& proj_weight,
const phi::DenseTensor& down_weight,
const phi::Scalar& epsilon,
phi::DenseTensor* out) {
if (x.dtype() == phi::DataType::BFLOAT16) {
custom_kernel::FusedRmsMlpResKernel<phi::dtype::bfloat16>(dev_ctx,
x,
residual,
ln_scales,
proj_weight,
down_weight,
epsilon,
out);
} else {
throw std::runtime_error("Unsupported data type for FusedRmsMlpResKernel");
}
}

std::vector<paddle::Tensor> FusedRmsMlpResForward(
const paddle::Tensor& x,
const paddle::Tensor& ln_scales,
const paddle::Tensor& proj_weight,
const paddle::Tensor& down_weight,
const paddle::Tensor& residual,
const float epsilon) {
auto dev_ctx = static_cast<const phi::CustomContext*>(
paddle::experimental::DeviceContextPool::Instance().Get(x.place()));

auto x_tensor = static_cast<const phi::DenseTensor*>(x.impl().get());
auto residual_tensor =
static_cast<const phi::DenseTensor*>(residual.impl().get());

auto ln_scales_tensor =
static_cast<const phi::DenseTensor*>(ln_scales.impl().get());
auto down_tensor =
static_cast<const phi::DenseTensor*>(down_weight.impl().get());
auto proj_tensor =
static_cast<const phi::DenseTensor*>(proj_weight.impl().get());

auto out_tensor = std::make_shared<phi::DenseTensor>();
out_tensor->Resize(x_tensor->dims());

CallFusedRmsMlpResKernel(*dev_ctx,
*x_tensor,
*residual_tensor,
*ln_scales_tensor,
*proj_tensor,
*down_tensor,
phi::Scalar(epsilon),
out_tensor.get());

paddle::Tensor out(out_tensor);

return {out};
}

std::vector<std::vector<int64_t>> FusedRmsMlpResInferShape(
const std::vector<int64_t>& x_shape,
const std::vector<int64_t>& ln_scales_shape,
const std::vector<int64_t>& proj_weight_shape,
const std::vector<int64_t>& down_weight_shape,
const std::vector<int64_t>& residual_shape) {
return {x_shape, residual_shape};
}

std::vector<paddle::DataType> FusedRmsMlpResInferDtype(
const paddle::DataType& x_dtype,
const paddle::DataType& ln_scales_dtype,
const paddle::DataType& proj_weight_dtype,
const paddle::DataType& down_weight_dtype,
const paddle::DataType& residual_dtype) {
return {x_dtype, residual_dtype};
}

PD_BUILD_OP(fused_rms_mlp_res)
.Inputs({"x", "ln_scales", "proj_weight", "down_weight", "residual_in"})
.Outputs({"out"})
.Attrs({"epsilon: float"})
.SetKernelFn(PD_KERNEL(FusedRmsMlpResForward))
.SetInferShapeFn(PD_INFER_SHAPE(FusedRmsMlpResInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(FusedRmsMlpResInferDtype));
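For orientation, the graph assembled in FusedRmsMlpRes::AddNode above amounts to: residual add, RMSNorm, a fused gate/up GEMM, a split into halves, SiLU(gate) * up, and a down GEMM. A rough Paddle restatement is sketched below; the weight layouts ([hidden, 2*ffn] and [ffn, hidden]) and the gate/up split order are assumptions, and in the fused op the residual sum is written back into residual_in through the shared section rather than returned.

import paddle
import paddle.nn.functional as F


def fused_rms_mlp_res_reference(x, residual, ln_scales, proj_weight, down_weight, epsilon):
    # Residual add; the fused op stores this sum back into residual_in.
    residual_out = x + residual
    # RMSNorm over the hidden dimension, scaled by ln_scales.
    variance = residual_out.astype("float32").pow(2).mean(-1, keepdim=True)
    norm_out = residual_out * paddle.rsqrt(variance + epsilon).astype(residual_out.dtype)
    norm_out = norm_out * ln_scales
    # Fused gate/up projection (proj_weight assumed [hidden, 2 * ffn]), split in half.
    gate, up = paddle.split(paddle.matmul(norm_out, proj_weight), 2, axis=-1)
    # SwiGLU followed by the down projection (down_weight assumed [ffn, hidden]).
    return paddle.matmul(F.silu(gate) * up, down_weight), residual_out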
427 changes: 110 additions & 317 deletions backends/intel_hpu/custom_ops/llama_infer/fused_rms_qkv_rope_t.cc

Large diffs are not rendered by default.

@@ -89,6 +89,14 @@ void pad_fill(const T* input_p,
}
}

template <typename T>
void pad_fill(const T* input_p, T* padded, std::vector<int> valid_batches) {
#pragma omp parallel for num_threads(OMP_THREAD_NUM)
for (int i = 0; i < static_cast<int>(valid_batches.size()); ++i) {
padded[i] = input_p[valid_batches[i]];
}
}

// in: seq_lens_decoder, block_tables
// out: block_indices, block_offset
// return last_block_pos, seq_lens
@@ -151,7 +159,6 @@ std::vector<paddle::Tensor> PrepareBlockMetadata(
auto hpu_place = rope_emb.place();
auto dev_ctx = static_cast<const phi::CustomContext*>(
paddle::experimental::DeviceContextPool::Instance().Get(hpu_place));
auto input_ids_cpu = input_ids.copy_to(paddle::CPUPlace(), true);
auto block_tables_cpu = block_tables.copy_to(paddle::CPUPlace(), true);
auto seq_lens_encoder_cpu =
seq_lens_encoder.copy_to(paddle::CPUPlace(), true);
@@ -178,6 +185,7 @@ std::vector<paddle::Tensor> PrepareBlockMetadata(

if (enc_count > 0) {
int total_batch = find_bucket(enc_count, batch_step, max_batches);
auto input_ids_cpu = input_ids.copy_to(paddle::CPUPlace(), true);

int max_buckets = (max_enc_len + block_size - 1) / block_size;
int max_prompt_len = max_buckets * block_size;
@@ -238,22 +246,22 @@ std::vector<paddle::Tensor> PrepareBlockMetadata(
} else if (dec_count > 0) {
int total_batch = find_bucket(dec_count, batch_step, max_batches);

auto input_ids_column_0 =
paddle::experimental::slice(input_ids, {1}, {0}, {1}, {}, {});
auto input_ids_cpu = input_ids_column_0.copy_to(paddle::CPUPlace(), true);

auto src_padded = paddle::full(
{total_batch}, 0, paddle::DataType::INT64, paddle::CPUPlace());
pad_fill<int64_t>(const_cast<int64_t*>(input_ids_cpu.data<int64_t>()),
reinterpret_cast<int64_t*>(src_padded.data<int64_t>()),
valid_batches_dec,
max_seq_len,
1);
valid_batches_dec);

auto seq_lens_padded = paddle::full(
{total_batch}, 0, paddle::DataType::INT32, paddle::CPUPlace());
pad_fill<int32_t>(
const_cast<int32_t*>(seq_lens_decoder_cpu.data<int32_t>()),
reinterpret_cast<int32_t*>(seq_lens_padded.data<int32_t>()),
valid_batches_dec,
1,
1);
valid_batches_dec);

std::shared_ptr<phi::DenseTensor> seq_lens_padded_hpu =
std::make_shared<phi::DenseTensor>();
81 changes: 80 additions & 1 deletion backends/intel_hpu/custom_ops/python/paddlenlp_ops/layers.py
@@ -92,12 +92,13 @@ def __init__(self, ln_scales, qkv_weights, epsilon, head_dim, num_head):
self.head_dim = head_dim
self.num_head = num_head

def forward(self, i, src, rotary_embs):
def forward(self, i, src, rotary_embs, residual):
query_states, kv_states = fused_rms_qkv_rope_t(
src,
self.ln_scales[i],
self.qkv_weights[i],
rotary_embs,
residual,
self.epsilon,
self.head_dim,
self.num_head,
@@ -208,6 +209,64 @@ def forward(
return out_linear_out


class Fused_Block_Attention(paddle.nn.Layer):
def __init__(
self,
ln_scales,
qkv_weights,
epsilon,
head_dim,
num_head,
scaling_factor,
linear_weights,
):
super().__init__()
self.ln_scales = ln_scales
self.qkv_weights = qkv_weights
self.epsilon = epsilon
self.head_dim = head_dim
self.num_head = num_head
self.scaling_factor = scaling_factor
self.linear_weights = linear_weights

def forward(
self,
i,
src,
residual,
rotary_embs,
k_caches,
v_caches,
block_groups,
block_list,
block_mapping,
block_bias,
block_indices,
block_offsets,
):
out_linear_out = fused_block_attention(
src,
residual,
rotary_embs,
k_caches,
v_caches,
block_groups,
block_list,
block_mapping,
block_bias,
block_indices,
block_offsets,
self.ln_scales[i],
self.qkv_weights[i],
self.linear_weights[i],
self.epsilon,
self.head_dim,
self.num_head,
self.scaling_factor,
)
return out_linear_out


class Fused_Mlp(paddle.nn.Layer):
def __init__(self, proj_weight, up_weight, down_weight):
super().__init__()
@@ -244,6 +303,26 @@ def forward(self, i, x):
return fused_rms_mlp_out


class Fused_Rms_Mlp_Res(paddle.nn.Layer):
def __init__(self, ln_scales, epsilon, proj_weight, down_weight):
super().__init__()
self.ln_scales = ln_scales
self.epsilon = epsilon
self.proj_weight = proj_weight
self.down_weight = down_weight

def forward(self, i, x, residual):
fused_rms_mlp_out = fused_rms_mlp_res(
x,
self.ln_scales[i],
self.proj_weight[i],
self.down_weight[i],
residual,
self.epsilon,
)
return fused_rms_mlp_out


class Prepare_Block_Metadata(paddle.nn.Layer):
def __init__(self, block_size):
super().__init__()
225 changes: 225 additions & 0 deletions backends/intel_hpu/custom_ops/tests/test_fused_block_attention.py
@@ -0,0 +1,225 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddlenlp_ops

paddle.device.set_device("intel_hpu:1")

# paddle.seed(102)


class TestFusedBlockAttention:
def __init__(self):
self.head_dim = 128
self.num_head = 32
self.kv_num_heads = 32
self.hidden_size = self.num_head * self.head_dim

self.epsilon = 1e-06

self.use_neox = True
self.position_offset = 0
self.rope_theta = 10000

def init_decode_params(self):
self.test_name = "TestFusedBlockAttentionDecode"
self.batch_size = 16
self.seq_len = 1
self.block_size = 128
self.num_of_block = 32
self.total_block_num = 20
position_id = paddle.to_tensor([80])
self.position_ids = paddle.expand(
position_id, shape=[self.batch_size, self.seq_len]
)

def create_tensors(self):
self.k_cache = (
paddle.rand(
[
self.total_block_num,
self.block_size,
self.kv_num_heads,
self.head_dim,
],
dtype=paddle.float32,
)
* 1000
)
self.k_cache = self.k_cache.to(paddle.bfloat16)
self.k_cache_test = self.k_cache.clone()
self.v_cache = (
paddle.rand(
[
self.total_block_num,
self.block_size,
self.kv_num_heads,
self.head_dim,
],
dtype=paddle.float32,
)
* 1000
)
self.v_cache = self.v_cache.to(paddle.bfloat16)
self.v_cache_test = self.v_cache.clone()

self.input_ids = paddle.zeros(
[self.batch_size, self.seq_len], dtype=paddle.bfloat16
)
self.src = paddle.rand(
[self.batch_size, self.seq_len, self.hidden_size], dtype=paddle.float32
).to(paddle.bfloat16)
self.residual = paddle.rand(
[self.batch_size, self.seq_len, self.hidden_size], dtype=paddle.float32
).to(paddle.bfloat16)
self.residual_test = self.residual.clone()

self.ln_scales = paddle.rand([self.hidden_size], dtype=paddle.bfloat16)
self.qkv_weights = paddle.rand(
[self.hidden_size * 3, self.hidden_size], dtype=paddle.float32
)
self.qkv_weights = self.qkv_weights.to(paddle.bfloat16)

self.linear_weights = paddle.rand(
[self.hidden_size, self.hidden_size], dtype=paddle.float32
).to(paddle.bfloat16)

self.head_dim_shape_tensor = paddle.ones(self.head_dim, dtype="int8")
self.new_rope = paddlenlp_ops.fused_get_rotary_embedding(
self.input_ids,
self.position_ids,
self.head_dim_shape_tensor,
self.position_offset,
self.rope_theta,
self.use_neox,
).to(paddle.bfloat16)

self.block_indices = paddle.randint(
0,
self.total_block_num,
[
self.batch_size,
],
dtype=paddle.int32,
)
self.block_offsets = paddle.randint(
0,
self.block_size,
[
self.batch_size,
],
dtype=paddle.int32,
)

self.block_groups = paddle.randint(
0,
self.batch_size,
[
self.num_of_block,
],
dtype=paddle.int32,
)
self.block_list = paddle.randint(
0,
self.num_of_block,
[
self.num_of_block,
],
dtype=paddle.int32,
)
self.block_mapping = paddle.randint(
0, 2, [self.num_of_block, self.batch_size], dtype=paddle.int32
).to(paddle.bfloat16)
self.block_bias = paddle.rand(
[self.num_of_block, self.block_size], dtype=paddle.bfloat16
)

def run_test(self):
query_states, key_value_states = paddlenlp_ops.fused_rms_qkv_rope_t(
self.src,
self.ln_scales,
self.qkv_weights,
self.new_rope.transpose([0, 1, 3, 2, 4]),
self.residual,
self.epsilon,
self.head_dim,
self.num_head,
)
key_states = key_value_states[0].squeeze(1)
value_states = key_value_states[1].squeeze(1)

self.k_cache.index_put_((self.block_indices, self.block_offsets), key_states)
self.v_cache.index_put_((self.block_indices, self.block_offsets), value_states)

out_linear_out_ref = paddlenlp_ops.fused_flatpa_proj(
query_states,
self.k_cache,
self.v_cache,
self.block_groups,
self.block_list,
self.block_mapping,
self.block_bias,
self.linear_weights,
scaling_factor=self.head_dim**-0.5,
)

out_linear_out = paddlenlp_ops.fused_block_attention(
self.src,
self.residual_test,
self.new_rope.transpose([0, 1, 3, 2, 4]),
self.k_cache_test,
self.v_cache_test,
self.block_groups,
self.block_list,
self.block_mapping,
self.block_bias,
self.block_indices,
self.block_offsets,
self.ln_scales,
self.qkv_weights,
self.linear_weights,
self.epsilon,
self.head_dim,
self.num_head,
scaling_factor=self.head_dim**-0.5,
)

assert (
(out_linear_out_ref == out_linear_out).all().item()
), f"Test failed for {self.test_name} fused_block_attention out_linear_out"
assert (
(self.k_cache == self.k_cache_test).all().item()
), f"Test failed for {self.test_name} fused_block_attention k_cache"
assert (
(self.v_cache == self.v_cache_test).all().item()
), f"Test failed for {self.test_name} fused_block_attention v_cache"
assert (
(self.residual == self.residual_test).all().item()
), f"Test failed for {self.test_name} fused_block_attention residual"

# ===============summary==============
print(f"Test Pass for {self.test_name} testcase")


class test_case_decode(TestFusedBlockAttention):
def __init__(self):
super().__init__()
self.init_decode_params()
self.create_tensors()


if __name__ == "__main__":
test_1 = test_case_decode()
test_1.run_test()
50 changes: 49 additions & 1 deletion backends/intel_hpu/custom_ops/tests/test_fused_rms_mlp.py
@@ -34,6 +34,9 @@ def init_data(
x = paddle.rand(
[batch_size, seqence_len, hidden_size], dtype=paddle.float32
).to(paddle.bfloat16)
residual = paddle.rand(
[batch_size, seqence_len, hidden_size], dtype=paddle.float32
).to(paddle.bfloat16)

ln_scales = paddle.rand([hidden_size], dtype=paddle.bfloat16)
gate_weight = paddle.normal(
@@ -49,7 +52,16 @@ def init_data(

epsilon = 1e-06

return x, ln_scales, proj_weight, gate_weight, up_weight, down_weight, epsilon
return (
x,
ln_scales,
proj_weight,
gate_weight,
up_weight,
down_weight,
residual,
epsilon,
)


def ref_rms_mlp(
@@ -90,8 +102,11 @@ def __init__(self):
self.gate_weight,
self.up_weight,
self.down_weight,
self.residual,
self.epsilon,
) = init_data()
self.x = self.x + self.residual
self.residual = self.x

def forward(self):
mlp_out_ref = ref_rms_mlp(
@@ -115,8 +130,11 @@ def __init__(self):
_,
_,
self.down_weight,
self.residual,
self.epsilon,
) = init_data()
self.x = self.x + self.residual
self.residual = self.x

def forward(self):
fused_rms_mlp_out = paddlenlp_ops.fused_rms_mlp(
@@ -129,6 +147,32 @@ def forward(self):
return fused_rms_mlp_out


class fusedRmsMlpResOP(paddle.nn.Layer):
def __init__(self):
super().__init__()
(
self.x,
self.ln_scales,
self.proj_weight,
_,
_,
self.down_weight,
self.residual,
self.epsilon,
) = init_data()

def forward(self):
fused_rms_mlp_out = paddlenlp_ops.fused_rms_mlp_res(
self.x,
self.ln_scales,
self.proj_weight,
self.down_weight,
self.residual,
self.epsilon,
)
return fused_rms_mlp_out


def run_profile(my_profile_func):
prof = profiler.Profiler(
targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.CUSTOM_DEVICE],
@@ -144,11 +188,15 @@ def run_profile(my_profile_func):
def run_accuracy_check():
ref_rms_mlp = refRmsMlpOP()
fused_rms_mlp = fusedRmsMlpOP()
fused_rms_mlp_residual = fusedRmsMlpResOP()

golden_res = ref_rms_mlp()
fused_rms_res = fused_rms_mlp()
fused_rms_mlp_residual_res = fused_rms_mlp_residual()

print((fused_rms_res == golden_res).all())
print((fused_rms_res == fused_rms_mlp_residual_res).all())
print((ref_rms_mlp.residual == fused_rms_mlp_residual.residual).all())


def main():
5 changes: 5 additions & 0 deletions backends/intel_hpu/custom_ops/tests/test_rms_qkv_rope.py
@@ -90,6 +90,8 @@ def create_tensors(self):
self.src = paddle.rand(
[self.batch_size, self.seq_len, self.hidden_size], dtype=paddle.bfloat16
)
self.residual = paddle.zeros_like(self.src, dtype=paddle.bfloat16)

self.ln_scales = paddle.rand([self.hidden_size], dtype=paddle.bfloat16)
self.qkv_weights = paddle.rand(
[self.hidden_size * 3, self.hidden_size], dtype=paddle.float32
@@ -201,6 +203,7 @@ def run_test(self):
self.ln_scales,
self.qkv_weights,
self.new_rope.transpose([0, 1, 3, 2, 4]),
self.residual,
self.epsilon,
self.head_dim,
self.num_head,
@@ -217,6 +220,8 @@ def run_test(self):

# ===============summary==============
print(f"Test Pass for {self.test_name} testcase")
print((self.src == self.residual).all().item())
# print(self.residual.data_ptr() == residual.data_ptr())


class test_case_padding(TestFusedRmsQkvRope):
54 changes: 52 additions & 2 deletions backends/intel_hpu/kernels/hpu_funcs.h
@@ -294,6 +294,14 @@ class HpuFusedOperator : public HpuOperator {
inputs, outputs, params, guid, node_name);
}

template <typename T>
inline void AddNodeScatter(std::vector<synTensor> inputs,
std::vector<synTensor> outputs,
std::string node_name) {
std::string guid = "scatter_nd_onnx_fwd_" + guid_dtype<T>();
AddNode_IO(inputs, outputs, guid, node_name);
}

template <typename T>
inline void AddNodeSilu(std::vector<synTensor> inputs,
std::vector<synTensor> outputs,
@@ -302,13 +310,55 @@ class HpuFusedOperator : public HpuOperator {
AddNode_IO(inputs, outputs, guid, node_name);
}

template <typename T>
inline void AddNodeConcat(std::vector<synTensor> inputs,
std::vector<synTensor> outputs,
synConcatenateParams params,
std::string node_name) {
std::string guid = "concat";
AddNode_IOP<synConcatenateParams>(inputs, outputs, params, guid, node_name);
}

inline void AddNodeSplit(std::vector<synTensor> inputs,
std::vector<synTensor> outputs,
synSplitParams params,
std::string node_name) {
std::string guid = "split";
AddNode_IOP(inputs, outputs, params, guid, node_name);
AddNode_IOP<synSplitParams>(inputs, outputs, params, guid, node_name);
}

inline void AddNodeSlice(std::vector<synTensor> inputs,
std::vector<synTensor> outputs,
synSliceParamsV2 params,
std::string node_name) {
std::string guid = "slice";
AddNode_IOP<synSliceParamsV2>(inputs, outputs, params, guid, node_name);
}

inline void AddNodeSqueeze(std::vector<synTensor> inputs,
std::vector<synTensor> outputs,
synSqueezeParams params,
std::string node_name) {
std::string guid = "squeeze";
AddNode_IOP<synSqueezeParams>(inputs, outputs, params, guid, node_name);
}

template <typename T>
inline void AddNodeRmsNorm(std::vector<synTensor> inputs,
std::vector<synTensor> outputs,
ns_LayerNormKernel::Params params,
std::string node_name) {
std::string guid = "rms_norm_ex_fwd_" + guid_dtype<T>();
AddNode_IOP<ns_LayerNormKernel::Params>(
inputs, outputs, params, guid, node_name);
}

template <typename T>
inline void AddNodeRope(std::vector<synTensor> inputs,
std::vector<synTensor> outputs,
ns_RoPESt2::ParamsV2 params,
std::string node_name) {
std::string guid = "rotary_pos_embedding_fwd_" + guid_dtype<T>();
AddNode_IOP<ns_RoPESt2::ParamsV2>(inputs, outputs, params, guid, node_name);
}
};

2 changes: 1 addition & 1 deletion backends/intel_hpu/kernels/swiglu_kernel.cc
@@ -80,7 +80,7 @@ class SwiGlu : public HpuFusedOperator {
std::vector<synTensor> split_in = {cast_x};
std::vector<synTensor> split_out = {split_x, split_y};
std::string node_name = guid_ + "split";
AddNodeSplit<synSplitParams>(split_in, split_out, params, node_name);
AddNodeSplit(split_in, split_out, params, node_name);
} else {
split_x = cast_x;
split_y = cast_y;